// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Resolve pass of the Screen Space Local Reflections effect.
    /// For every pixel it averages <c>ResolveSamples</c> scene color taps, each fetched at the
    /// hit location stored in the ray trace result of a nearby pixel, and weights them by the hit mask.
    /// When <c>ReduceHighlights</c> is set, each tap is divided by (1 + luminance) before averaging
    /// and the average is divided by (1 - luminance) afterwards.
    /// </summary>
    shader SSLRResolvePass<int ResolveSamples, bool ReduceHighlights> : ImageEffectShader, SSLRCommon, NormalPack, Math, BRDFMicrofacet
    {
        // Neighbor offsets, expressed in texels of the ray trace result (they are scaled by Texture4TexelSize).
        // Only the first ResolveSamples entries are read.
        static const float2 Offsets[8] =
        {
            float2( 0,  0),
            float2( 2, -2),
            float2(-2, -2),
            float2( 0,  2),
            float2(-2,  0),
            float2( 0, -2),
            float2( 2,  0),
            float2( 2,  2),
        };

        /// <summary>
        /// Computes the resolved reflection color for the current pixel.
        /// </summary>
        /// <remarks>
        /// Inputs mapping:
        ///   Texture0 - Scene Color (with blurred mip maps chain or without)
        ///   Texture1 - Depth
        ///   Texture2 - World Space Normals
        ///   Texture3 - Specular Color + Roughness
        ///   Texture4 - Ray Trace result (xy = hit UV, z = hit mask)
        /// </remarks>
        /// <returns>
        /// 0 for rejected pixels; otherwise the averaged reflection color in rgb and the averaged
        /// hit mask in alpha, with every component floored at 1e-5.
        /// </returns>
        override stage float4 Shading()
        {
            float2 texCoord = streams.TexCoord;

            // Nothing to resolve if the ray trace result holds no hit for this pixel
            float localHitMask = Texture4.SampleLevel(LinearSampler, texCoord, 0).z;
            if (localHitMask <= 0.001)
                return 0;

            // Material roughness is stored in the alpha channel of the specular buffer
            float surfaceRoughness = Texture3.SampleLevel(PointSampler, texCoord, 0).a;

            // Reconstruct the view space position from depth
            float sceneDepth = SampleZ(texCoord);
            float3 viewPos = ComputeViewPosition(texCoord, sceneDepth);

            // Skip pixels that are too far away or too rough
            if (viewPos.z > 100.0f || surfaceRoughness > RoughnessFade)
                return 0;

            // Surface normal and direction towards the camera
            float3 surfaceNormal = DecodeNormal(Texture2.SampleLevel(PointSampler, texCoord, 0).rgb);
            float3 worldPos = ComputeWorldPosition(texCoord, sceneDepth);
            float3 toCamera = normalize(CameraPosWS.xyz - worldPos);

            // Per-pixel, per-frame random vector remapped with * 2 - 1; it drives the matrix applied to the sample offsets
            float2 jitter = RandN2(texCoord, TemporalTime).xy * 2.0 - 1.0;
            float2x2 jitterMatrix = float2x2(jitter.x, jitter.y, -jitter.y, jitter.x);

            // Cone tangent used to pick the color mip level
            float cosView = saturate(dot(surfaceNormal, toCamera));
            float maxConeTangent = surfaceRoughness * 5 * (1.0 - BRDFBias);
            float coneBlend = pow(cosView, 1.5) * sqrt(surfaceRoughness);
            float coneTangent = lerp(0.0, maxConeTangent, coneBlend);

            // Accumulate the resolve samples: rgb = weighted color, a = sum of hit masks
            float4 accumulated = 0.0;
            for (int sampleIndex = 0; sampleIndex < ResolveSamples; sampleIndex++)
            {
                // The local pixel is resolved using intersections found by neighboring pixels,
                // so first locate the neighbor for this sample.
                float2 neighborOffset = mul(jitterMatrix, Offsets[sampleIndex] * Texture4TexelSize);
                float2 neighborTexCoord = texCoord + neighborOffset;

                // Intersection recorded by that neighbor
                float4 neighborTrace = Texture4.SampleLevel(LinearSampler, neighborTexCoord, 0);
                float2 hitUv = neighborTrace.xy;
                float hitMask = neighborTrace.z;

                // Choose the color mip from the cone footprint at the hit
                float footprintRadius = coneTangent * length(hitUv - texCoord);
                float colorMip = clamp(log2(footprintRadius * TraceSizeMax), 0.0, MaxColorMiplevel);

                float3 hitColor = Texture0.SampleLevel(LinearSampler, hitUv, colorMip).rgb;
                if (ReduceHighlights)
                    hitColor /= 1 + Luminance(hitColor);

                accumulated += float4(hitColor, 1) * hitMask;
            }

            // Average, apply the counterpart of the per-sample highlight reduction, then scale by the averaged mask
            accumulated /= ResolveSamples;
            if (ReduceHighlights)
                accumulated.rgb /= 1 - Luminance(accumulated.rgb);
            accumulated.rgb *= accumulated.a;

            return max(1e-5, accumulated);
        }
    };

    // Effect that mixes in SSLRResolvePass, passing the ResolveSamples and ReduceHighlights
    // values from SSLRKeys as the shader's generic arguments.
    effect SSLRResolvePassEffect
    {
        using params SSLRKeys;

        mixin SSLRResolvePass<SSLRKeys.ResolveSamples, SSLRKeys.ReduceHighlights>;
    }
}