// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.Shadows
{
    /// <summary>
    /// Directional-light shadow receiver: picks the cascade that covers the shaded point and evaluates the shadow factor from it.
    /// </summary>
    /// <remarks>
    /// TCascadeCount: Number of cascades per light.
    /// TLightCount: Number of lights; together with TCascadeCount it sizes the CascadeDepthSplits array.
    /// TBlendCascades: When set, the shadow (and thickness) is blended with the next cascade close to the end of the current cascade.
    /// TDepthRangeAuto: When not set, the last cascade fades towards "no shadow" close to its end split.
    /// TCascadeDebug: Flag to enable debug mode (1 color per cascade).
    /// TComputeTransmittance: When set, a thickness value is also evaluated from the cascade.
    ///                        streams.thicknessWS is always written: it receives that thickness when the flag is set, 999.0 otherwise.
    /// </remarks>
    internal shader ShadowMapReceiverDirectional<int TCascadeCount,
                                                 int TLightCount,
                                                 bool TBlendCascades,
                                                 bool TDepthRangeAuto,
                                                 bool TCascadeDebug,
                                                 bool TComputeTransmittance> :
        ShadowMapReceiverBase<PerView.Lighting, TCascadeCount, TLightCount>,
        Transformation  // Required for "WorldInverseTranspose".
    {
        cbuffer PerView.Lighting // TODO: Use a proper cbuffer for this?
        {
            // One split depth per cascade, stored light after light (TCascadeCount entries per light).
            float CascadeDepthSplits[TCascadeCount * TLightCount];
        };

        // Computes the shadow factor for the light "lightIndex" at "position" and writes streams.thicknessWS.
        // The cascade is chosen by comparing streams.DepthVS against the split depths of that light.
        override float3 ComputeShadow(float3 position, int lightIndex)
        {
            // Index of the first split belonging to this light.
            int firstSplit = lightIndex * TCascadeCount;

            // Width (as a fraction of the cascade's depth range) of the band in which fading/blending happens.
            float fadeBand = 0.2;

            // Cascade selection: the highest split that the view-space depth exceeds decides the cascade.
            // NOTE(review): the previous comment here said "Only support a single light per group", yet the
            // lookup below is offset by lightIndex - confirm whether that limitation still applies.
            int cascadeIndex = 0;
            [unroll]
            for (int split = 0; split < TCascadeCount - 1; split++)
            {
                [flatten]
                if (streams.DepthVS > CascadeDepthSplits[firstSplit + split])
                {
                    cascadeIndex = split + 1;
                }
            }

            // Defaults used when nothing below overrides them.
            float3 shadow = 1.0;
            float tempThickness = 999.0;

            // Sampling position, moved by the per-light offset.
            float3 shadowPosition = position.xyz + GetShadowPositionOffset(OffsetScales[lightIndex], streams.NdotL, streams.normalWS);

            // If we are within the cascades
            if (cascadeIndex < TCascadeCount)
            {
                shadow = ComputeShadowFromCascade(shadowPosition, cascadeIndex, lightIndex);

                if (TComputeTransmittance)
                {
                    tempThickness = ComputeThicknessFromCascade(streams.PositionWS.xyz,
                                                                streams.meshNormalWS,  // Vertex normal on purpose, not the normal-mapped one.
                                                                cascadeIndex,
                                                                lightIndex,
                                                                true);
                }

                // Depth range covered by the selected cascade: [previousSplit, nextSplit].
                // The first cascade has no previous split, so its range starts at 0.
                float nextSplit = CascadeDepthSplits[firstSplit + cascadeIndex];
                float previousSplit = 0.0;
                if (cascadeIndex > 0)
                {
                    previousSplit = CascadeDepthSplits[firstSplit + cascadeIndex - 1];
                }
                float splitSize = nextSplit - previousSplit;

                // Remaining distance to the end of the cascade, relative to the cascade's size.
                float splitDist = (nextSplit - streams.DepthVS) / splitSize;

                if (splitDist < fadeBand)
                {
                    // 1 at the start of the band, 0 at (or past) the end of the cascade.
                    float lerpAmt = smoothstep(0.0, fadeBand, splitDist);
                    bool isLastCascade = (cascadeIndex == TCascadeCount - 1);

                    if (isLastCascade && !TDepthRangeAuto)
                    {
                        // Last cascade: fade the shadow out towards "fully lit".
                        shadow = lerp(1.0f, shadow, lerpAmt);

                        if (TComputeTransmittance)
                        {
                            tempThickness = lerp(0.0, tempThickness, lerpAmt);
                        }
                    }
                    else if (!isLastCascade && TBlendCascades)
                    {
                        // Inner cascade: cross-fade into the result of the following cascade.
                        float nextShadow = ComputeShadowFromCascade(shadowPosition, cascadeIndex + 1, lightIndex);
                        shadow = lerp(nextShadow, shadow, lerpAmt);

                        if (TComputeTransmittance)
                        {
                            float nextThickness = ComputeThicknessFromCascade(streams.PositionWS.xyz,
                                                                              streams.meshNormalWS,  // Vertex normal on purpose, not the normal-mapped one.
                                                                              cascadeIndex + 1,
                                                                              lightIndex,
                                                                              true);

                            tempThickness = lerp(nextThickness, tempThickness, lerpAmt);
                        }
                    }
                }
            }

            streams.thicknessWS = tempThickness;

            // Debug mode: tint the shadow with one color per cascade.
            if (TCascadeDebug)
            {
                ////                                   GREEN          BLUE           PURPLE         RED            WHITE
                static const float3 colors[5] = { float3(0,1,0), float3(0,0,1), float3(1,0,1), float3(1,0,0), float3(1,1,1)};
                return colors[cascadeIndex] * shadow;
            }

            // Regular output: the shadow factor itself.
            return shadow;
        }
    };
}
