// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.Shadows
{
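    /// <summary>
    /// Base shader for shadow map filtering. Declares the abstract shadow and thickness filter
    /// entry points and provides shared helpers for offsetting sample positions and sampling thickness.
    /// </summary>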
    shader ShadowMapFilterBase<MemberName PerLighting> : ShadowMapCommon<PerLighting>, Texturing, Math
    {
        /// <summary>
        /// Calculates the shadow factor based on the position and the shadow map distance.
        /// Abstract: the actual filtering is supplied by the shader that implements this method.
        /// </summary>
        // NOTE(review): "position" is presumably the shadow map UV coordinate and "positionDepth" the
        // depth of the shaded point in shadow space - confirm against the implementing shaders.
        abstract float FilterShadow(float2 position, float positionDepth);
                
        /// <summary>
        /// Moves a world space pixel position into the surface (against the mesh normal) and projects
        /// the moved position into shadow space. The offset length is the world space distance spanned
        /// by the filter kernel, measured at the origin of the shadow cascade UV space.
        /// Used to mitigate shadow mapping artifacts in the thickness calculation.
        /// </summary>
        void CalculateAdjustedShadowSpacePixelPosition(float filterRadiusInPixels,  // Sampling kernel radius; gets scaled by the shadow map texel size.
                                                       float3 pixelPositionWS,
                                                       float3 meshNormalWS,
                                                       float4x4 worldToShadowCascadeUV,    // Transforms from world space to shadow cascade UV.
                                                       float4x4 inverseWorldToShadowCascadeUV,    // Transforms from shadow cascade UV to world space.
                                                       out float3 adjustedPixelPositionWS,
                                                       out float3 adjustedPixelPositionShadowSpace)
        {
            // TODO: PERFORMANCE: The offset length can be calculated once for each directional light on the CPU. For perspective shadows, this is a bit more complex.
            // TODO: PERFORMANCE: Can we do the offset in shadow map space (by projecting the normal vector)?

            // Extent of the filter kernel in shadow map UV units.
            // TODO: Does "ShadowMapTextureTexelSize" contain the texel size relative to the light's viewport or relative to the whole atlas?
            const float2 kernelExtentUV = ShadowMapTextureTexelSize.xy * filterRadiusInPixels;

            // Bring both ends of the kernel diagonal (at depth 0) back into world space.
            const float4 kernelOriginWS = Project(float4(0.0, 0.0, 0.0, 1.0), inverseWorldToShadowCascadeUV);
            const float4 kernelCornerWS = Project(float4(kernelExtentUV, 0.0, 1.0), inverseWorldToShadowCascadeUV);

            const float kernelDiagonalLengthWS = distance(kernelCornerWS.xyz, kernelOriginWS.xyz);

            // Shrink the position into the surface to avoid SSS artifacts.
            // TODO: Do we even need an offset on faces that face the light?
            const float3 normalOffsetWS = meshNormalWS * kernelDiagonalLengthWS;
            adjustedPixelPositionWS = pixelPositionWS - normalOffsetWS;

            // Project the moved position into shadow space (the light's post-projection space).
            const float4 adjustedShadowCoordinate = Project(float4(adjustedPixelPositionWS, 1.0), worldToShadowCascadeUV);
            adjustedPixelPositionShadowSpace = adjustedShadowCoordinate.xyz;
        }
        
        /// <summary>
        /// Perspective variant of the position adjustment: moves a world space pixel position into the
        /// surface (against the mesh normal) and projects the moved position into shadow space.
        /// Unlike the non-perspective variant, the offset length is measured at the pixel's own shadow
        /// space location: it is the world space distance between the pixel and the point one filter
        /// kernel extent away from it in shadow map UV (at the same shadow space depth).
        /// Used to mitigate shadow mapping artifacts in the thickness calculation.
        /// </summary>
        void CalculateAdjustedShadowSpacePixelPositionPerspective(float filterRadiusInPixels,  // Sampling kernel radius; gets scaled by the shadow map texel size.
                                                                  float3 pixelPositionWS,
                                                                  float3 meshNormalWS,
                                                                  float4x4 worldToShadowCascadeUV,    // Transforms from world space to shadow cascade UV.
                                                                  float4x4 inverseWorldToShadowCascadeUV,    // Transforms from shadow cascade UV to world space.
                                                                  out float3 adjustedPixelPositionWS,
                                                                  out float3 adjustedPixelPositionShadowSpace)
        {
            // TODO: PERFORMANCE: The offset length can be calculated once for each directional light on the CPU. For perspective shadows, this is a bit more complex.
            // TODO: PERFORMANCE: Can we do the offset in shadow map space (by projecting the normal vector)?

            // Where the unmodified pixel lands in shadow space.
            const float4 pixelShadowCoordinate = Project(float4(pixelPositionWS, 1.0), worldToShadowCascadeUV);

            // Step one kernel extent away from the pixel in shadow map UV, keeping its depth.
            // TODO: Does "ShadowMapTextureTexelSize" contain the texel size relative to the light's viewport or relative to the whole atlas?
            // TODO: Calculate two coordinates (that center the pixel's shadow coordinate)?
            const float2 kernelCornerUV = pixelShadowCoordinate.xy + ShadowMapTextureTexelSize.xy * filterRadiusInPixels;
            const float4 kernelCornerWS = Project(float4(kernelCornerUV, pixelShadowCoordinate.z, 1.0), inverseWorldToShadowCascadeUV);

            const float kernelDiagonalLengthWS = distance(kernelCornerWS.xyz, pixelPositionWS);

            // Shrink the position into the surface to avoid SSS artifacts.
            // TODO: Do we even need an offset on faces that face the light?
            const float3 normalOffsetWS = meshNormalWS * kernelDiagonalLengthWS;
            adjustedPixelPositionWS = pixelPositionWS - normalOffsetWS;

            // Project the moved position into shadow space (the light's post-projection space).
            const float4 adjustedShadowCoordinate = Project(float4(adjustedPixelPositionWS, 1.0), worldToShadowCascadeUV);
            adjustedPixelPositionShadowSpace = adjustedShadowCoordinate.xyz;
        }

        /// <summary>
        /// Calculates the thickness of the object at this pixel using the shadow map.
        /// Abstract: the actual filtering is supplied by the shader that implements this method.
        /// </summary>
        // NOTE(review): "depthRanges" is presumably (Z-near, Z-far) of the light, matching how
        // SampleThickness uses depthRanges.y as Z-far - confirm against the callers.
        abstract float FilterThickness(float3 pixelPositionWS,
                                       float3 meshNormalWS,
                                       float2 depthRanges,
                                       float4x4 worldToShadowCascadeUV,    // Transforms from world space to shadow cascade UV.
                                       float4x4 inverseWorldToShadowCascadeUV,  // Transforms from shadow cascade UV to world space.
                                       bool isOrthographic);    // Presumably true for orthographic shadow projections - verify against callers.
        
        /// <summary>
        /// Samples the shadow map at the given shadow space coordinate and returns the distance
        /// between the occluder stored in the shadow map and the shaded pixel.
        /// For orthographic projections this is the absolute depth difference scaled by depthRanges.y;
        /// otherwise the shadow map sample is transformed back to world space and the world space
        /// distance to the pixel is returned.
        /// </summary>
        // TODO: Maybe implement separate linear and perspective functions?
        float SampleThickness(float3 shadowSpaceCoordinate,
                              float3 pixelPositionWS,
                              float2 depthRanges,
                              float4x4 inverseWorldToShadowCascadeUV,
                              bool isOrthographic)
        {
            // Depth of the occluder as seen from the light (mip level 0, red channel).
            const float occluderDepth = ShadowMapTexture.SampleLevel(LinearBorderSampler, shadowSpaceCoordinate.xy, 0).r;

            float thickness;

            // The thickness is measured from the light's point of view.
            if(isOrthographic)
            {
                // The depth values are linear here, so their difference multiplied by Z-far (depthRanges.y)
                // yields the thickness in world/view space. This relies on Z-near being 0.0 for
                // directional light shadow maps.
                // TODO: Better use max(thickness, 0.0) instead of abs()?
                const float depthDelta = occluderDepth - shadowSpaceCoordinate.z;
                thickness = abs(depthDelta) * depthRanges.y;
            }
            else
            {
                // Reconstruct the occluder's world space position and measure the distance to the pixel.
                // TODO: PERFORMANCE: Instead of doing a matrix multiplication, correctly linearize the depth
                // (using depthRanges.x as Z-near and depthRanges.y as Z-far) and calculate the DISTANCE, not just the depth.
                const float4 occluderPositionWS = Project(float4(shadowSpaceCoordinate.xy, occluderDepth, 1.0), inverseWorldToShadowCascadeUV);
                thickness = distance(occluderPositionWS.xyz, pixelPositionWS);
            }

            return thickness;
        }
    };
}
