// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Shadows
{
    /// <summary>
    /// Shadow receiver for point lights that use six shadow-map faces per light.
    /// Selects the face along the dominant axis of the light-to-surface vector and computes the shadow factor
    /// (and the thickness written to <c>streams.thicknessWS</c>) for that face.
    /// </summary>
    /// <remarks>
    /// Face indices per light, as selected in <c>ComputeShadow</c>: 0/1 = +Z/-Z, 2/3 = +X/-X, 4/5 = +Y/-Y.
    /// </remarks>
    internal shader ShadowMapReceiverPointCubeMap<int TLightCount> : ShadowMapGroup<PerDraw.Lighting>, ShadowMapFilterBase<PerDraw.Lighting>, PositionStream4, ShaderBaseStream, LightStream, Texturing, NormalStream
    {
        cbuffer PerDraw.Lighting
        {
            // One matrix per face and light, indexed by [lightIndex * 6 + faceIndex].
            float4x4 WorldToShadow[TLightCount*6];
            float4x4 InverseWorldToShadow[TLightCount*6];
            // Per-light bias subtracted from the distance along the dominant axis before the depth is projected.
            float DepthBiases[TLightCount];
            // Per-light scale of the normal offset applied to the receiver position.
            float OffsetScales[TLightCount];
            // Per-light projection terms: depth = x + y / (biased distance along the dominant axis).
            float2 DepthParameters[TLightCount];
        };

        // TODO: Deduplicate
        /// <summary>
        /// Computes the offset that pushes the receiver position along its normal before the shadow lookup.
        /// The offset is 2 * ShadowMapTextureTexelSize.x * offsetScale * saturate(1 - nDotL) along "normal",
        /// so it is zero when nDotL is 1 or greater and largest when nDotL is 0 or less.
        /// </summary>
        float3 GetShadowPositionOffset(float offsetScale, float nDotL, float3 normal)
        {
            float normalOffsetScale = saturate(1.0f - nDotL);
            return 2.0f * ShadowMapTextureTexelSize.x * offsetScale * normalOffsetScale * normal;
        }

        /// <summary>
        /// Computes the thickness (used for SSS) at "positionWS" through FilterThickness, using the
        /// WorldToShadow / InverseWorldToShadow matrices stored at "cascadeIndex" (lightIndex * 6 + faceIndex).
        /// Returns 0 when the thickness computation is disabled.
        /// </summary>
        float ComputeThickness(float3 positionWS, int cascadeIndex)
        {
            // Calculate thickness for SSS:
            float thickness = 0.0;

            // Named differently from the enclosing method so the flag does not shadow "ComputeThickness".
            const bool thicknessEnabled = true; // TODO: This should be a mixin parameter or something!
            if(thicknessEnabled)
            {
                // TODO: I don't know if the shadow map filtering can be done for cube maps in the same way as for directional lights or spot lights.
                thickness = FilterThickness(positionWS,
                                            streams.meshNormalWS,
                                            float2(0.0f, 1.0f), //DepthRanges[lightIndex*6+faceIndex],    // TODO: Currently not needed for perspective shadow maps.
                                            WorldToShadow[cascadeIndex],
                                            InverseWorldToShadow[cascadeIndex],
                                            false);
            }

            return thickness;
        }

        /// <summary>
        /// Computes the shadow factor of the point light "lightIndex" at "positionWS".
        /// Also writes the thickness for the selected face to <c>streams.thicknessWS</c>.
        /// Returns 1 without sampling the shadow map when the projected depth falls outside [0, 1].
        /// </summary>
        override float3 ComputeShadow(float3 positionWS, int lightIndex)
        {
            // Vector from the light to the shaded position.
            float3 lightPosition = LightPointGroup<TLightCount>.Lights[lightIndex].PositionWS.xyz;
            float3 lightDelta = positionWS.xyz - lightPosition;

            // The dominant axis is the same for the normalized and the unnormalized vector, so the
            // components are compared directly. This avoids a length() and a division per pixel, and
            // keeps the comparison free of NaNs when the position coincides with the light.
            float3 lightDeltaAbs = abs(lightDelta);
            float longestAxis = max(lightDeltaAbs.x, max(lightDeltaAbs.y, lightDeltaAbs.z));

            int faceIndex;
            float lightSpaceZ;

            // Select the base (positive side) face index for either X, Y or Z facing
            [flatten]
            if(lightDeltaAbs.x == longestAxis)
            {
                lightSpaceZ = lightDelta.x;
                faceIndex = 2;
            }
            else if(lightDeltaAbs.y == longestAxis)
            {
                lightSpaceZ = lightDelta.y;
                faceIndex = 4;
            }
            else // Z is the dominant axis
            {
                lightSpaceZ = lightDelta.z;
                faceIndex = 0;
            }

            // Apply offset for the negative side of a direction (+1):
            // sign() is -1 for negative values, so -min(0, sign) is 1 there and 0 otherwise.
            float lightSpaceZDirection = sign(lightSpaceZ);
            faceIndex += int(-min(0.0, lightSpaceZDirection));

            // Index of this light's selected face in the per-face matrix arrays.
            int cascadeIndex = lightIndex * 6 + faceIndex;

            // Compute the thickness before modifying "positionWS":
            streams.thicknessWS = ComputeThickness(positionWS, cascadeIndex);

            // Apply normal scaled bias
            positionWS += GetShadowPositionOffset(OffsetScales[lightIndex], streams.NdotL, streams.normalWS);

            // Map to texture space
            float4 projectedPosition = mul(float4(positionWS,1), WorldToShadow[cascadeIndex]);
            projectedPosition /= projectedPosition.w;

            // Apply bias in view space. The distance along the dominant axis is taken from the
            // un-offset position (it was captured before the normal offset above).
            lightSpaceZ = abs(lightSpaceZ);
            lightSpaceZ -= DepthBiases[lightIndex];

            // Project view space depth into the same space as the shadow map
            float depth = DepthParameters[lightIndex].x + (DepthParameters[lightIndex].y / lightSpaceZ);

            // Outside the depth range covered by the shadow map: skip the lookup.
            if(depth < 0 || depth > 1)
                return 1;

            // Compare distance to light to value inside of the shadow map
            float shadow = FilterShadow(projectedPosition.xy, depth);

            return shadow;
        }
    };
}
