// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.Shadows
{
    /// <summary>
    /// Base shader for shadow-map receivers: holds the per-light / per-cascade shadow data and
    /// provides helpers that evaluate the shadow factor and the thickness for a given cascade.
    /// </summary>
    /// <remarks>
    /// PerLighting: Member name used for the constant buffer declared below; it is also forwarded
    ///              to the ShadowMapGroup and ShadowMapFilterBase mixins.
    /// TCascadeCountBase: Number of cascades per light. Per-cascade arrays are laid out light-major,
    ///                    i.e. they are indexed with (cascadeIndex + lightIndex * TCascadeCountBase).
    /// TLightCountBase: Number of lights; per-light arrays hold one entry per light.
    /// </remarks>
    internal shader ShadowMapReceiverBase<MemberName PerLighting, int TCascadeCountBase, int TLightCountBase> :
        MaterialPixelShadingStream,
        ShadowMapGroup<PerLighting>,
        ShadowMapFilterBase<PerLighting>,
        PositionStream4
    {
        cbuffer PerLighting // TODO: Use a proper cbuffer for this?
        {
            // Per-cascade data: one entry per (light, cascade) pair.
            float4x4 WorldToShadowCascadeUV[TCascadeCountBase * TLightCountBase];
            float4x4 InverseWorldToShadowCascadeUV[TCascadeCountBase * TLightCountBase];   // Only required for SSS.
            float4x4 ViewMatrices[TCascadeCountBase * TLightCountBase];   // Only required for SSS.
            float2 DepthRanges[TCascadeCountBase * TLightCountBase];   // x = z-near, y = z-far. Only required for SSS.

            // Per-light data: one entry per light.
            float DepthBiases[TLightCountBase];
            float OffsetScales[TLightCountBase];
        };

        /// <summary>
        /// Computes an offset along <c>normal</c> to apply to a position before the shadow lookup.
        /// The offset length is two shadow-map texel widths (ShadowMapTextureTexelSize.x) scaled by
        /// <c>offsetScale</c> and by saturate(1 - nDotL), so it is zero when nDotL is 1 or greater.
        /// </summary>
        float3 GetShadowPositionOffset(float offsetScale, float nDotL, float3 normal)
        {
            // Grows from 0 to 1 as nDotL goes from 1 down to 0 (clamped outside that range).
            float slopeFactor = saturate(1.0f - nDotL);

            // NOTE(review): ShadowMapTextureTexelSize is not declared in this file; presumably it comes from ShadowMapGroup.
            float offsetLength = 2.0f * ShadowMapTextureTexelSize.x * offsetScale * slopeFactor;

            return offsetLength * normal;
        }

        /// <summary>
        /// Evaluates the filtered shadow factor of a world-space position for one cascade of one light.
        /// </summary>
        float ComputeShadowFromCascade(float3 shadowPositionWS, int cascadeIndex, int lightIndex)
        {
            // Flat index of the (light, cascade) pair in the per-cascade arrays.
            const int cascadeSlot = cascadeIndex + lightIndex * TCascadeCountBase;

            float4 projected = mul(float4(shadowPositionWS, 1.0), WorldToShadowCascadeUV[cascadeSlot]);

            // The per-light depth bias is subtracted before the divide by w.
            projected.z -= DepthBiases[lightIndex];

            float3 shadowCoords = projected.xyz / projected.w;

            // NOTE(review): FilterShadow is not declared in this file; presumably provided by ShadowMapFilterBase.
            return FilterShadow(shadowCoords.xy, shadowCoords.z);
        }

        /// <summary>
        /// Evaluates the filtered thickness at a world-space pixel position for one cascade of one light
        /// by forwarding the cascade's depth range and shadow matrices to FilterThickness.
        /// </summary>
        // TODO: Naming: the original note here says this function is named "compute..." while a counterpart
        //       is named "calculate..."; that counterpart is not in this file - verify and unify the naming.
        float ComputeThicknessFromCascade(float3 pixelPositionWS,
                                          float3 meshNormalWS,
                                          int cascadeIndex,
                                          int lightIndex,
                                          bool isOrthographic)
        {
            // Flat index of the (light, cascade) pair in the per-cascade arrays.
            const int cascadeSlot = lightIndex * TCascadeCountBase + cascadeIndex;

            float2 cascadeDepthRange = DepthRanges[cascadeSlot];
            float4x4 worldToShadow = WorldToShadowCascadeUV[cascadeSlot];
            float4x4 shadowToWorld = InverseWorldToShadowCascadeUV[cascadeSlot];

            // NOTE(review): FilterThickness is not declared in this file; presumably provided by ShadowMapFilterBase.
            return FilterThickness(pixelPositionWS, meshNormalWS, cascadeDepthRange, worldToShadow, shadowToWorld, isOrthographic);
        }
    };
}
