// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Shadows
{
    /// <summary>
    /// Shadow caster mixin that replaces the regular position transform with a paraboloid projection
    /// computed in the light's view space, and writes the range-normalized distance to the light
    /// as the pixel depth.
    /// </summary>
    shader ShadowMapCasterParaboloidProjection : TransformationBase, PositionStream4, Texturing
    {
        // Offset added to the light-space z before it is turned into the clip depth.
        // NOTE(review): presumably keeps geometry lying right on the hemisphere boundary from being clipped - confirm.
        static const float ClippingEpsilon = 0.03;

        cbuffer PerView.ShadowCaster
        {
            // x = Near; y = 1/(Far-Near)
            float2 DepthParameters;
        }

        // Used to write the distance from an object to the light to the depth buffer
        stage stream float PixelDepth : SV_DEPTH;

        /// <summary>
        /// Overrides the position transform: stores the paraboloid-projected position in ShadingPosition
        /// and the normalized distance to the light in DepthVS (consumed later by PSMain).
        /// </summary>
        stage override void PostTransformPosition()
        {
            streams.ShadingPosition = ComputeParaboloidProjection(streams.PositionWS, streams.DepthVS);
        }

        /// <summary>
        /// Returns the paraboloid-projected position for an arbitrary world position.
        /// The depth output of the projection is not needed here and is discarded.
        /// </summary>
        stage override float4 ComputeShadingPosition(float4 world)
        {
            float dummy;
            return ComputeParaboloidProjection(world, dummy);
        }

        /// <summary>
        /// Projects a world position onto the paraboloid around the light.
        /// </summary>
        /// <param name="world">Position to project; it is transformed by Transformation.View into light view space.</param>
        /// <param name="depth">Receives the distance to the light mapped through the depth range: (distance - Near) / (Far - Near).</param>
        /// <returns>
        /// float4(x, y, z, 1) where x/y are the paraboloid coordinates and z is the light-space z
        /// (plus ClippingEpsilon) mapped through the same depth range; z is only meaningful for clipping
        /// because PSMain overrides the depth that is actually written.
        /// </returns>
        float4 ComputeParaboloidProjection(float4 world, out float depth)
        {
            // Project into light view space
            float4 lightSpace = mul(world, Transformation.View);

            // Store length and normalize
            // NOTE(review): a position exactly at the light origin gives a zero length and a division by zero here.
            float distanceToLight = length(lightSpace.xyz);
            float3 intermediate = lightSpace.xyz / distanceToLight;

            // Project x/y coordinates on parabola
            // NOTE(review): the divisor vanishes for intermediate.z == -1, so this form assumes the rendered
            // hemisphere is the +z one in light view space - confirm against the receiver shader and the
            // handedness of the light view matrix.
            intermediate.xy /= 1.0f + intermediate.z;

            // 2 different depth values
            // The first one is the distance to the point light; it ends up in the depth buffer
            // The second one is used for clipping (along the light's z-axis)
            float2 depthValues = float2(distanceToLight, lightSpace.z + ClippingEpsilon);

            // Apply depth range: (value - Near) / (Far - Near).
            // Both parameters must be used, otherwise the stored depth is not normalized to [Near, Far]
            // (the value at Far would be Far/(Far-Near) > 1 instead of 1).
            depthValues = (depthValues - DepthParameters.x) * DepthParameters.y;

            // Send projected depth to pixel shader
            depth = depthValues.x;

            return float4(intermediate.xy, depthValues.y, 1);
        }

        /// <summary>
        /// Runs the base pixel stage, then overrides the output depth with the normalized distance
        /// to the light that was computed in PostTransformPosition.
        /// </summary>
        stage override void PSMain()
        {
            base.PSMain();

            streams.PixelDepth = streams.DepthVS;
        }
    };
}
