// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Screen Space Local Reflections shader for the temporal pass.
    /// Reprojects the previous frame's resolved reflections onto the current frame,
    /// clamps that history to the value range of the current 3x3 neighborhood and
    /// blends it with the current reflections.
    /// </summary>
    shader SSLRTemporalPass : ImageEffectShader, Texturing
    {
        // Blend weight handed to lerp() in Shading(): 0 yields the current frame only,
        // 1 yields the (clamped) history only.
        stage float TemporalResponse;
        // Factor applied to the neighborhood min/max range, around its midpoint,
        // before the history sample is clamped to it.
        stage float TemporalScale;
        stage float4x4 IVP; // Current frame inverse view*projection matrix
        stage float4x4 prevVP; // Previous frame view*projection matrix

        // Converts clip-space XY to texture coordinates:
        // x maps -1..1 to 0..1, y maps 1..-1 to 0..1 (Y axis is flipped).
        float2 ClipToUv(float2 clipPos)
        {
            return clipPos * float2(0.5, -0.5) + float2(0.5, 0.5);
        }

        // Converts texture coordinates to clip-space XY (inverse of ClipToUv):
        // x maps 0..1 to -1..1, y maps 0..1 to 1..-1 (Y axis is flipped).
        float2 UvToClip(float2 uv)
        {
            return uv * float2(2, -2) + float2(-1, 1);
        }

        // Rebuilds a position from texture coordinates and a raw depth value by
        // transforming the clip-space point with IVP and applying the perspective divide.
        float3 ComputeWorldPosition(float2 uv, float rawDepth)
        {
            float4 clipPos = float4(UvToClip(uv), rawDepth, 1);
            float4 pos = mul(clipPos, IVP);
            return pos.xyz / pos.w;
        }

        // Reads the raw depth (red channel of Texture2, point sampled at mip 0)
        // and rebuilds the corresponding position.
        float3 SampleWorldPosition(float2 uv)
        {
            float rawDepth = Texture2.SampleLevel(PointSampler, uv, 0).r;
            return ComputeWorldPosition(uv, rawDepth);
        }

        // Returns the temporally filtered reflection color for the current pixel.
        override stage float4 Shading()
        {
            // Inputs Mapping:
            // Texture0 - Resolved reflections
            // Texture1 - Previous frame resolved reflections
            // Texture2 - Depth

            float2 uv = streams.TexCoord;

            // Reconstruct previous frame screen space position
            float3 posWS = SampleWorldPosition(uv);
            float4 prevSS = mul(float4(posWS, 1), prevVP);
            prevSS.xy /= prevSS.w;

            float2 prevUV = ClipToUv(prevSS.xy);

            float4 current = Texture0.SampleLevel(LinearSampler, uv, 0);
            float4 previous = Texture1.SampleLevel(LinearSampler, prevUV, 0);

            // One-texel offsets along each axis of Texture0
            float2 du = float2(Texture0TexelSize.x, 0.0);
            float2 dv = float2(0.0, Texture0TexelSize.y);

            // Gather the 8 neighbors around the pixel. The center tap of the 3x3
            // window is 'current', so it is not fetched a second time.
            float4 currentTopLeft = Texture0.SampleLevel(LinearSampler, uv - dv - du, 0);
            float4 currentTopCenter = Texture0.SampleLevel(LinearSampler, uv - dv, 0);
            float4 currentTopRight = Texture0.SampleLevel(LinearSampler, uv - dv + du, 0);
            float4 currentMiddleLeft = Texture0.SampleLevel(LinearSampler, uv - du, 0);
            float4 currentMiddleRight = Texture0.SampleLevel(LinearSampler, uv + du, 0);
            float4 currentBottomLeft = Texture0.SampleLevel(LinearSampler, uv + dv - du, 0);
            float4 currentBottomCenter = Texture0.SampleLevel(LinearSampler, uv + dv, 0);
            float4 currentBottomRight = Texture0.SampleLevel(LinearSampler, uv + dv + du, 0);

            // Per-channel value range of the 3x3 neighborhood
            float4 currentMin = min(currentTopLeft, min(currentTopCenter, min(currentTopRight, min(currentMiddleLeft, min(current, min(currentMiddleRight, min(currentBottomLeft, min(currentBottomCenter, currentBottomRight))))))));
            float4 currentMax = max(currentTopLeft, max(currentTopCenter, max(currentTopRight, max(currentMiddleLeft, max(current, max(currentMiddleRight, max(currentBottomLeft, max(currentBottomCenter, currentBottomRight))))))));

            // Grow or shrink the range around its midpoint by TemporalScale
            float4 center = (currentMin + currentMax) * 0.5f;
            currentMin = (currentMin - center) * TemporalScale + center;
            currentMax = (currentMax - center) * TemporalScale + center;

            // Constrain the history to the (scaled) range of the current neighborhood
            previous = clamp(previous, currentMin, currentMax);

            return lerp(current, previous, TemporalResponse);
        }
    };
}