// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{

    /// <summary>
    /// A blur with weights applied along one direction. 
    /// Expects as input: 
    /// - Texture0: color buffer
    /// </summary>
    ///
    /// <typeparam name="TWeightCount">The number of weights along a direction.</typeparam>
    /// <typeparam name="TTotalNumber">Total number of taps. The value is always 2 * TWeightCount - 1.</typeparam>

    shader DepthAwareDirectionalBlurUtil<int TWeightCount, int TTotalNumber> : Texturing, ComputeColor
    {
        // Direction to apply the blur. (normalized vector)
        float2 Direction;

        // The radius of the blur to apply around the considered fragment.
        // It is scaled by Texture0TexelSize in Compute(), i.e. it is expressed in texels of Texture0.
        float Radius;

        // Weights of each tap (weights values are symmetric along each direction).
        // TapWeights[0] is the center tap; TapWeights[i] is shared by the two taps
        // located i steps away from the center, one on each side.
        float TapWeights[TWeightCount];

        // NOTE(review): not read by Compute() in this shader. Presumably a circle-of-confusion
        // reference value consumed by the host effect or by a derived shader - TODO confirm.
        float CoCReference;

        // Gets the blur result for the current pixel.
        //
        // Samples Texture0 at the current texture coordinate and at (TWeightCount - 1) taps on
        // each side of it along Direction, averages the alpha-premultiplied colors using
        // TapWeights, and returns the result converted back to non-premultiplied alpha.
        // Returns (0, 0, 0, 0) when the sum of the weights is not strictly positive.
        override float4 Compute()
        {
            // Offset between 2 consecutive taps.
            // The divisor is clamped to 1 so that an instantiation with TWeightCount == 1
            // (center tap only) does not contain a constant division by zero. For
            // TWeightCount >= 2 the value is exactly Radius / (TWeightCount - 1).
            float2 tapOffset = Radius / max(TWeightCount - 1, 1) * Texture0TexelSize;

            // Fills arrays with all the taps, ordered from the farthest backward tap (index 0)
            // to the farthest forward tap (index TTotalNumber - 1).
            float4 tapColor[TTotalNumber];         // All the taps colors (premultiplied alpha)
            float tapOriginalWeight[TTotalNumber]; // With their respective weight

            // Center tap
            int centerIndex = TWeightCount - 1;
            tapColor[centerIndex] = Texture0.Sample(LinearSampler, streams.TexCoord).xyzw;
            // Premultiply alpha
            tapColor[centerIndex].rgb *= tapColor[centerIndex].a;
            tapOriginalWeight[centerIndex] = TapWeights[0];

            // Treats all the taps in the 2 directions from the center
            [unroll]
            for (int i = 1; i < TWeightCount; i++)
            {
                // UV displacement of the i-th tap; the backward and forward taps are mirrored
                // around the center, so it is computed once and applied with both signs.
                float2 tapDelta = i * Direction * tapOffset;

                // Backwards
                int backwardIndex = centerIndex - i;
                tapColor[backwardIndex] = Texture0.Sample(LinearSampler, streams.TexCoord - tapDelta).xyzw;
                // Premultiply alpha
                tapColor[backwardIndex].rgb *= tapColor[backwardIndex].a;
                tapOriginalWeight[backwardIndex] = TapWeights[i];

                // Forwards
                int forwardIndex = centerIndex + i;
                tapColor[forwardIndex] = Texture0.Sample(LinearSampler, streams.TexCoord + tapDelta).xyzw;
                // Premultiply alpha
                tapColor[forwardIndex].rgb *= tapColor[forwardIndex].a;
                tapOriginalWeight[forwardIndex] = TapWeights[i];
            }

            // Calculate the final average color
            float4 resultColor = float4(0.0, 0.0, 0.0, 0.0);
            float totalWeight = 0.0;

            [unroll]
            for (int k = 0; k < TTotalNumber; k++)
            {
                float tapWeight = tapOriginalWeight[k];
                // You could change the weight on the fly here to discard some sample
                resultColor += tapColor[k].xyzw * tapWeight;
                totalWeight += tapWeight;
            }

            if (totalWeight > 0)
            {
                // Normalizes the final result
                resultColor /= totalWeight;
            }
            else
            {
                // No usable weight: output transparent black rather than dividing by zero.
                resultColor = float4(0.0, 0.0, 0.0, 0.0);
            }

            // Go back to non-premultiplied-alpha (skipped for fully transparent results
            // to avoid dividing by a zero alpha).
            if (resultColor.a > 0)
            {
                resultColor.rgb /= resultColor.a;
            }

            return resultColor;
        }
    };
}
