// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.Shadows
{
    /// <summary>
    /// Performs percentage closer filtering.
    /// </summary>
    shader ShadowMapFilterPcf<MemberName PerLighting, int TFilterSize> : ShadowMapFilterBase<PerLighting>
    {
        /// <summary>
        /// Splits a shadow map coordinate into the base coordinate and the fractional offset
        /// that the Get3x3/5x5/7x7FilterKernel functions take as input.
        /// </summary>
        /// <param name="position">Coordinate to filter around; it is multiplied by ShadowMapTextureSize to obtain texel units.</param>
        /// <param name="base_uv">Receives the rounded texel coordinate, moved back by half a texel and scaled by ShadowMapTextureTexelSize.</param>
        /// <param name="st">Receives the fractional part of the half-texel-shifted texel coordinate (in [0, 1) per axis).</param>
        void CalculatePCFKernelParameters(float2 position, out float2 base_uv, out float2 st)    // TODO: Make "st"!
        {
            // Work in texel units (1 unit - 1 texel), shifted by half a texel so that floor() rounds to the nearest integer.
            float2 shiftedTexelPosition = position * ShadowMapTextureSize + 0.5;
            float2 roundedTexelPosition = floor(shiftedTexelPosition);

            // What is left after rounding is the sub-texel fraction.
            st = shiftedTexelPosition - roundedTexelPosition;

            // Step back by half a texel, then convert from texel units using the texel size.
            base_uv = (roundedTexelPosition - 0.5) * ShadowMapTextureTexelSize;
        }

        /// <summary>
        /// Builds the 2x2 tap kernel used when TFilterSize is 3.
        /// </summary>
        /// <param name="base_uv">Base coordinate as produced by CalculatePCFKernelParameters.</param>
        /// <param name="st">Fractional offset as produced by CalculatePCFKernelParameters.</param>
        /// <param name="kernel">Receives the taps. Each element holds a texture coordinate (XY) and a weight (Z).</param>
        /// <returns>The sum of all tap weights (16), to be used for normalization.</returns>
        float Get3x3FilterKernel(float2 base_uv, float2 st, out float3 kernel[4])
        {
            // Per-axis weights of the two taps. They add up to 4 on each axis, hence 4 * 4 = 16 in total.
            float2 weight0 = (3 - 2 * st);
            float2 weight1 = (1 + 2 * st);

            // Per-axis tap coordinates: an offset in texels, scaled by the texel size and added to the base coordinate.
            float2 tap0 = base_uv + ((2 - st) / weight0 - 1) * ShadowMapTextureTexelSize;
            float2 tap1 = base_uv + (st / weight1 + 1) * ShadowMapTextureTexelSize;

            // Row 0 (Y index 0):
            kernel[0] = float3(tap0.x, tap0.y, weight0.x * weight0.y);
            kernel[1] = float3(tap1.x, tap0.y, weight1.x * weight0.y);

            // Row 1 (Y index 1):
            kernel[2] = float3(tap0.x, tap1.y, weight0.x * weight1.y);
            kernel[3] = float3(tap1.x, tap1.y, weight1.x * weight1.y);

            return 16.0;
        }

        /// <summary>
        /// Builds the 3x3 tap kernel used when TFilterSize is 5.
        /// </summary>
        /// <param name="base_uv">Base coordinate as produced by CalculatePCFKernelParameters.</param>
        /// <param name="st">Fractional offset as produced by CalculatePCFKernelParameters.</param>
        /// <param name="kernel">Receives the taps. Each element holds a texture coordinate (XY) and a weight (Z).</param>
        /// <returns>The sum of all tap weights (144), to be used for normalization.</returns>
        float Get5x5FilterKernel(float2 base_uv, float2 st, out float3 kernel[9])
        {
            // Per-axis weights of the three taps. They add up to 12 on each axis, hence 12 * 12 = 144 in total.
            float2 weight0 = (4 - 3 * st);
            float2 weight1 = 7;
            float2 weight2 = (1 + 3 * st);

            // Per-axis tap coordinates: an offset in texels, scaled by the texel size and added to the base coordinate.
            float2 tap0 = base_uv + ((3 - 2 * st) / weight0 - 2) * ShadowMapTextureTexelSize;
            float2 tap1 = base_uv + ((3 + st) / weight1) * ShadowMapTextureTexelSize;
            float2 tap2 = base_uv + (st / weight2 + 2) * ShadowMapTextureTexelSize;

            // Row 0 (Y index 0):
            kernel[0] = float3(tap0.x, tap0.y, weight0.x * weight0.y);
            kernel[1] = float3(tap1.x, tap0.y, weight1.x * weight0.y);
            kernel[2] = float3(tap2.x, tap0.y, weight2.x * weight0.y);

            // Row 1 (Y index 1):
            kernel[3] = float3(tap0.x, tap1.y, weight0.x * weight1.y);
            kernel[4] = float3(tap1.x, tap1.y, weight1.x * weight1.y);
            kernel[5] = float3(tap2.x, tap1.y, weight2.x * weight1.y);

            // Row 2 (Y index 2):
            kernel[6] = float3(tap0.x, tap2.y, weight0.x * weight2.y);
            kernel[7] = float3(tap1.x, tap2.y, weight1.x * weight2.y);
            kernel[8] = float3(tap2.x, tap2.y, weight2.x * weight2.y);

            return 144.0;
        }
        
        /// <summary>
        /// Builds the 4x4 tap kernel used when TFilterSize is 7.
        /// </summary>
        /// <param name="base_uv">Base coordinate as produced by CalculatePCFKernelParameters.</param>
        /// <param name="st">Fractional offset as produced by CalculatePCFKernelParameters.</param>
        /// <param name="kernel">Receives the taps. Each element holds a texture coordinate (XY) and a weight (Z).</param>
        /// <returns>The sum of all tap weights (2704), to be used for normalization.</returns>
        float Get7x7FilterKernel(float2 base_uv, float2 st, out float3 kernel[16])
        {
            // Per-axis weights of the four taps. Each one is negative for "st" in [0, 1) and they add up to -52 per axis,
            // so every X * Y product is positive and the total is 52 * 52 = 2704.
            float2 weight0 = (5 * st - 6);
            float2 weight1 = (11 * st - 28);
            float2 weight2 = -(11 * st + 17);
            float2 weight3 = -(5 * st + 1);

            // Per-axis tap coordinates: an offset in texels, scaled by the texel size and added to the base coordinate.
            float2 tap0 = base_uv + ((4 * st - 5) / weight0 - 3) * ShadowMapTextureTexelSize;
            float2 tap1 = base_uv + ((4 * st - 16) / weight1 - 1) * ShadowMapTextureTexelSize;
            float2 tap2 = base_uv + (-(7 * st + 5) / weight2 + 1) * ShadowMapTextureTexelSize;
            float2 tap3 = base_uv + (-st / weight3 + 3) * ShadowMapTextureTexelSize;

            // Row 0 (Y index 0):
            kernel[0] = float3(tap0.x, tap0.y, weight0.x * weight0.y);
            kernel[1] = float3(tap1.x, tap0.y, weight1.x * weight0.y);
            kernel[2] = float3(tap2.x, tap0.y, weight2.x * weight0.y);
            kernel[3] = float3(tap3.x, tap0.y, weight3.x * weight0.y);

            // Row 1 (Y index 1):
            kernel[4] = float3(tap0.x, tap1.y, weight0.x * weight1.y);
            kernel[5] = float3(tap1.x, tap1.y, weight1.x * weight1.y);
            kernel[6] = float3(tap2.x, tap1.y, weight2.x * weight1.y);
            kernel[7] = float3(tap3.x, tap1.y, weight3.x * weight1.y);

            // Row 2 (Y index 2):
            kernel[8] = float3(tap0.x, tap2.y, weight0.x * weight2.y);
            kernel[9] = float3(tap1.x, tap2.y, weight1.x * weight2.y);
            kernel[10] = float3(tap2.x, tap2.y, weight2.x * weight2.y);
            kernel[11] = float3(tap3.x, tap2.y, weight3.x * weight2.y);

            // Row 3 (Y index 3):
            kernel[12] = float3(tap0.x, tap3.y, weight0.x * weight3.y);
            kernel[13] = float3(tap1.x, tap3.y, weight1.x * weight3.y);
            kernel[14] = float3(tap2.x, tap3.y, weight2.x * weight3.y);
            kernel[15] = float3(tap3.x, tap3.y, weight3.x * weight3.y);

            return 2704.0;
        }

        /// <summary>
        /// Performs a single comparison sample of the shadow map at mip level zero.
        /// </summary>
        /// <param name="position">Texture coordinate to sample at.</param>
        /// <param name="positionDepth">Reference value the sampled shadow map value is compared against.</param>
        /// <returns>The result of the comparison sample.</returns>
        float SampleTextureAndCompare(float2 position, float positionDepth)
        {
            float comparisonResult = ShadowMapTexture.SampleCmpLevelZero(LinearClampCompareLessEqualSampler, position, positionDepth);
            return comparisonResult;
        }

        /// <summary>
        /// Computes the filtered shadow term by taking weighted comparison samples around the given coordinate.
        /// The kernel is selected by TFilterSize (3, 5 or 7); any other value yields 0.
        /// </summary>
        /// <param name="position">Shadow map coordinate to filter around.</param>
        /// <param name="positionDepth">Reference depth every tap is compared against.</param>
        /// <returns>The weighted sum of the comparison results divided by the kernel's normalization factor.</returns>
        float FilterShadow(float2 position, float positionDepth)
        {
            // TODO: handle bias

            float2 base_uv;
            float2 st;
            CalculatePCFKernelParameters(position, base_uv, st);

            // TODO: Apply gradient for initial offset in this way once gradient mapping has been added.
            //       Replace the two lines above with this and use the float2 parameter depthGradient, which contains
            //       the change in depth along the x and y axis over the size of the entire texture atlas.
            //base_uv -= float2(0.5, 0.5);
            //float2 initialOffset = base_uv - uv;
            //base_uv *= ShadowMapTextureTexelSize;
            //
            // Take offset to pixel center into account according to the depth gradient
            //positionDepth += dot(initialOffset * ShadowMapTextureTexelSize, depthGradient);

            // Accumulated (weight * comparison result) over all taps, and the sum of the weights to divide by.
            // With an unsupported TFilterSize no tap is taken and the result stays 0.
            float weightedSum = 0.0f;
            float normalizationFactor = 1.0f;

            if (TFilterSize == 3)
            {
                float3 kernel[4];
                normalizationFactor = Get3x3FilterKernel(base_uv, st, kernel);

                [unroll]    // Required: without this attribute the shader compiler refuses to compile the loop because of "divergent control flow".
                for(int tapIndex = 0; tapIndex < 4; ++tapIndex)
                {
                    weightedSum += kernel[tapIndex].z * SampleTextureAndCompare(kernel[tapIndex].xy, positionDepth);
                }
            }
            else if (TFilterSize == 5)
            {
                float3 kernel[9];
                normalizationFactor = Get5x5FilterKernel(base_uv, st, kernel);

                [unroll]    // Required: without this attribute the shader compiler refuses to compile the loop because of "divergent control flow".
                for(int tapIndex = 0; tapIndex < 9; ++tapIndex)
                {
                    weightedSum += kernel[tapIndex].z * SampleTextureAndCompare(kernel[tapIndex].xy, positionDepth);
                }
            }
            else if (TFilterSize == 7)
            {
                float3 kernel[16];
                normalizationFactor = Get7x7FilterKernel(base_uv, st, kernel);

                [unroll]    // Required: without this attribute the shader compiler refuses to compile the loop because of "divergent control flow".
                for(int tapIndex = 0; tapIndex < 16; ++tapIndex)
                {
                    weightedSum += kernel[tapIndex].z * SampleTextureAndCompare(kernel[tapIndex].xy, positionDepth);
                }
            }

            return weightedSum / normalizationFactor;
        }
        
        /// <summary>
        /// Returns the filter radius for the current TFilterSize, in pixels
        /// (presumably shadow map texels - confirm against the callers that consume it).
        /// </summary>
        /// <returns>2.5 for a filter size of 3, 3.5 for a filter size of 5, and 4.5 otherwise (intended for 7).</returns>
        float GetFilterRadiusInPixels(void)   
        {
            // TODO: Some of these filters are so wide, that they cause artifacts on thin objects like ears for example.

            // For the supported sizes these values equal "float(TFilterSize) / 2.0 + 1.0";
            // any other size falls through to the 7x7 radius.

            if (TFilterSize == 3)
            {
                return 2.5;
            }

            if (TFilterSize == 5)
            {
                return 3.5;
            }

            return 4.5;
        }

        /// <summary>
        /// Computes the filtered thickness by taking weighted SampleThickness() samples around the given shadow space position.
        /// The kernel is selected by TFilterSize (3, 5 or 7); any other value yields 0.
        /// </summary>
        /// <param name="adjustedPixelPositionWS">Adjusted pixel position, forwarded unchanged to SampleThickness().</param>
        /// <param name="adjustedPixelPositionShadowSpace">Adjusted pixel position in shadow space. XY selects the filter center, Z is passed along with every tap.</param>
        /// <param name="depthRanges">Forwarded unchanged to SampleThickness().</param>
        /// <param name="inverseWorldToShadowCascadeUV">Forwarded unchanged to SampleThickness().</param>
        /// <param name="isOrthographic">Forwarded unchanged to SampleThickness().</param>
        /// <param name="isDualParaboloid">Currently not used by this function.</param>
        /// <returns>The weighted sum of the thickness samples divided by the kernel's normalization factor.</returns>
        float SampleAndFilter(float3 adjustedPixelPositionWS, float3 adjustedPixelPositionShadowSpace, float2 depthRanges, float4x4 inverseWorldToShadowCascadeUV, bool isOrthographic, bool isDualParaboloid = false)
        {
            // Use the same kernel parameters as FilterShadow(). The kernel functions expect the base coordinate
            // to include the half texel step back that CalculatePCFKernelParameters() applies; computing it
            // by hand here without that step shifted every tap by half a texel relative to FilterShadow().
            float2 base_uv;
            float2 st;
            CalculatePCFKernelParameters(adjustedPixelPositionShadowSpace.xy, base_uv, st);
            
            // With an unsupported TFilterSize no sample is taken and the result stays 0.
            float thickness = 0.0;
            float normalizationFactor = 1.0;

            if (TFilterSize == 3)
            {
                float3 kernel[4];
                normalizationFactor = Get3x3FilterKernel(base_uv, st, kernel);
                
                [unroll]    // Required: without this attribute the shader compiler refuses to compile the loop because of "divergent control flow".
                for(int i=0; i<4; ++i)
                {
                    thickness += kernel[i].z * SampleThickness(float3(kernel[i].xy, adjustedPixelPositionShadowSpace.z), adjustedPixelPositionWS,
                                                               depthRanges, inverseWorldToShadowCascadeUV, isOrthographic);
                }
            } 
            else if (TFilterSize == 5)
            {
                float3 kernel[9];
                normalizationFactor = Get5x5FilterKernel(base_uv, st, kernel);
                
                [unroll]    // Required: without this attribute the shader compiler refuses to compile the loop because of "divergent control flow".
                for(int i=0; i<9; ++i)
                {
                    thickness += kernel[i].z * SampleThickness(float3(kernel[i].xy, adjustedPixelPositionShadowSpace.z), adjustedPixelPositionWS,
                                                               depthRanges, inverseWorldToShadowCascadeUV, isOrthographic);
                }
            }
            else if (TFilterSize == 7)
            {
                float3 kernel[16];
                normalizationFactor = Get7x7FilterKernel(base_uv, st, kernel);
                
                [unroll]    // Required: without this attribute the shader compiler refuses to compile the loop because of "divergent control flow".
                for(int i=0; i<16; ++i)
                {
                    thickness += kernel[i].z * SampleThickness(float3(kernel[i].xy, adjustedPixelPositionShadowSpace.z), adjustedPixelPositionWS,
                                                               depthRanges, inverseWorldToShadowCascadeUV, isOrthographic);
                }
            }

            return thickness / normalizationFactor;
        }

        /// <summary>
        /// Computes the filtered thickness for a pixel: first derives the adjusted pixel position
        /// (world space and shadow space) for the current filter radius, then filters around it with SampleAndFilter().
        /// </summary>
        /// <param name="pixelPositionWS">Pixel position passed to the position adjustment helpers.</param>
        /// <param name="meshNormalWS">Mesh normal passed to the position adjustment helpers.</param>
        /// <param name="depthRanges">Forwarded unchanged to SampleAndFilter().</param>
        /// <param name="worldToShadowCascadeUV">Transforms from world space to shadow cascade UV.</param>
        /// <param name="inverseWorldToShadowCascadeUV">Transforms from shadow cascade UV to world space.</param>
        /// <param name="isOrthographic">Selects the orthographic or the perspective position adjustment helper.</param>
        /// <returns>The value returned by SampleAndFilter() for the adjusted position.</returns>
        float FilterThickness(float3 pixelPositionWS,
                              float3 meshNormalWS,
                              float2 depthRanges,
                              float4x4 worldToShadowCascadeUV,    // Transforms from world space to shadow cascade UV.
                              float4x4 inverseWorldToShadowCascadeUV,  // Transforms from shadow cascade UV to world space.
                              bool isOrthographic)
        {
            // TODO: This filter is not great yet (quality wise), because we'd probably have to evaluate the scattering per sample to get smooth results. But that's too slow.

            // Both adjustment helpers take the same filter radius, so query it once.
            float filterRadius = GetFilterRadiusInPixels();

            // Adjusted world space coordinate and shadow map coordinate of the current pixel:
            float3 adjustedPositionWS;
            float3 adjustedPositionShadowSpace;

            // TODO: PERFORMANCE: Ideally we'd like to move this to "ShadowMapReceiverDirectional.sdsl",
            //                    because that way we can ensure that the adjusted pixel positions are calculated only once per pixel per light.
            //                    Sadly this is not that easy because the offset depends on the filter width, which is only available inside of "ShadowMapFilterPcf.sdsl".

            if (isOrthographic)  // TODO: The offset calculation only works for directional lights for now.
            {
                CalculateAdjustedShadowSpacePixelPosition(filterRadius, pixelPositionWS, meshNormalWS,
                                                          worldToShadowCascadeUV, inverseWorldToShadowCascadeUV,
                                                          adjustedPositionWS, adjustedPositionShadowSpace);
            }
            else
            {
                // NOTE: A simpler alternative is to offset the position by "-meshNormalWS * 0.01", transform it with
                //       worldToShadowCascadeUV, divide by w and keep pixelPositionWS as the world space position.
                //       The fixed offset is bad (TODO), so the dedicated perspective helper is used instead.
                CalculateAdjustedShadowSpacePixelPositionPerspective(filterRadius, pixelPositionWS, meshNormalWS,
                                                                     worldToShadowCascadeUV, inverseWorldToShadowCascadeUV,
                                                                     adjustedPositionWS, adjustedPositionShadowSpace);
            }

            // Now perform the actual filtering:
            return SampleAndFilter(adjustedPositionWS, adjustedPositionShadowSpace, depthRanges, inverseWorldToShadowCascadeUV, isOrthographic);
        }

    };
}
