// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Screen Space Local Reflections shader for the Ray Trace Pass.
    /// For every pixel a jittered reflection ray is marched through the depth buffer in UV space;
    /// the output stores the UV of the hit (xy) and a fade mask (z).
    /// </summary>
    shader SSLRRayTracePass : ImageEffectShader, SSLRCommon, NormalPack, Math
    {
        cbuffer Data
        {
            // Width (in UV units) of the band along the screen border over which hits are faded out
            // (see RayAttenBorder).
            stage float EdgeFadeFactor;

            // Matrix applied to world-space positions to obtain clip-space positions
            // (presumably the view-projection matrix).
            stage float4x4 VP;
        };

        // Projects a world-space position into clip space (-1:1 from bottom/left to up/right).
        // The perspective divide is applied, so the result is undefined when the projected w is 0.
        float3 ProjectWorldToClip(float3 wsPos)
        {
            float4 uv = mul(float4(wsPos, 1), VP);
            uv /= uv.w;
            return uv.xyz;
        }

        // Projects a world-space position into UV space (0:1 from top/left to bottom/right).
        // xy: screen UV, z: projected depth (unchanged from clip space).
        float3 ProjectWorldToUv(float3 wsPos)
        {
            float3 pos = ProjectWorldToClip(wsPos);
            return float3(ClipToUv(pos.xy), pos.z);
        }

        // Rotates the tangent-space vector H.xyz into the space of the normal N.
        // A tangent frame (T, B, N) is built around N, using Z as the helper axis unless N is
        // almost parallel to it, in which case X is used. H.w is passed through untouched.
        float4 TangentToWorld(float3 N, float4 H)
        {
            float3 UpVector = abs(N.z) < 0.999 ? float3(0.0, 0.0, 1.0) : float3(1.0, 0.0, 0.0);
            float3 T = normalize( cross( UpVector, N ) );
            float3 B = cross( N, T );

            return float4((T * H.x) + (B * H.y) + (N * H.z), H.w);
        }

        // Brian Karis, Epic Games "Real Shading in Unreal Engine 4"
        // Generates a GGX-distributed half vector from the random pair Xi.
        // Returns xyz: half vector in tangent space (z along the normal), w: D * cos(theta) (the pdf).
        float4 ImportanceSampleGGX(float2 Xi, float Roughness)
        {
            float m = Roughness * Roughness;
            float m2 = m * m;

            float Phi = 2 * PI * Xi.x;

            float CosTheta = sqrt((1.0 - Xi.y) / (1.0 + (m2 - 1.0) * Xi.y));
            // Clamped so that SinTheta never becomes exactly zero.
            float SinTheta = sqrt(max(1e-5, 1.0 - CosTheta * CosTheta));

            float3 H;
            H.x = SinTheta * cos(Phi);
            H.y = SinTheta * sin(Phi);
            H.z = CosTheta;

            float d = (CosTheta * m2 - CosTheta) * CosTheta + 1;
            float D = m2 / (PI * d * d);
            float pdf = D * CosTheta;

            return float4(H, pdf);
        }

        // Attenuation factor in [0, 1] based on the distance of `pos` (UV space) to the nearest
        // border of the unit square: 1 when further away than `value`, otherwise a linear ramp
        // down to 0 at (and beyond) the border.
        float RayAttenBorder(float2 pos, float value)
        {
            float borderDist = min(1.0 - max(pos.x, pos.y), min(pos.x, pos.y));
            return saturate(borderDist > value ? 1.0 : borderDist / value);
        }

        // Traces one reflection ray for the current pixel.
        // Returns float4(hitUV, hitMask, 0), or 0 when the pixel is rejected or the ray finds no hit.
        override stage float4 Shading()
        {
            // Inputs Mapping:
            // Texture0 - Scene Color
            // Texture1 - Depth
            // Texture2 - World Space Normals
            // Texture3 - Specular Color + Roughness

            float2 uv = streams.TexCoord;

            // Sample material roughness
            float4 specularRoughnessBuffer = Texture3.SampleLevel(PointSampler, uv, 0);
            float roughness = specularRoughnessBuffer.a;

            // Get view space position
            float depth = SampleZ(uv);
            float3 positionVS = ComputeViewPosition(uv, depth);

            // Reject invalid pixels (too far away or too rough to reflect)
            if (positionVS.z > 100.0f || roughness > RoughnessFade)
                return 0;

            // Calculate view space normal vector
            float4 normalsBuffer = Texture2.SampleLevel(PointSampler, uv, 0);
            float3 normalWS = DecodeNormal(normalsBuffer.rgb);
            float3 normalVS = mul(normalWS, (float3x3)V);

            // Calculate normalized view space reflection vector (not jittered) and reject
            // pixels close to the camera whose reflection does not point far enough into the scene.
            // This is done before the jittering below so rejected pixels skip that work.
            float3 reflectVS = normalize(reflect(positionVS, normalVS));
            if (positionVS.z < 1.0 && reflectVS.z < 0.4)
                return 0;

            // Randomize the reflection a little: pick a GGX half vector around the normal.
            // BRDFBias pulls Xi.y towards 0, i.e. towards the un-jittered normal.
            float2 Xi = RandN2(uv, TemporalTime);
            Xi.y = lerp(Xi.y, 0.0, BRDFBias);

            float4 H = TangentToWorld(normalWS, ImportanceSampleGGX(Xi, roughness));

            // World space reflection ray
            float3 positionWS = ComputeWorldPosition(uv, depth);
            float3 viewWS = normalize(positionWS - CameraPosWS.xyz);
            float3 reflectWS = reflect(viewWS, H.xyz);

            // Project the ray segment into UV space; the start is pushed along the normal
            // to avoid self intersection.
            float3 startWS = positionWS + normalWS * WorldAntiSelfOcclusionBias;
            float3 startUV = ProjectWorldToUv(startWS);
            float3 endUV = ProjectWorldToUv(startWS + reflectWS);

            float3 rayUV = endUV - startUV;

            // A ray without any screen-space extent (or an invalid projection) cannot be marched:
            // the normalization below would divide by zero and feed inf/NaN into the loop.
            // The negated comparison also rejects NaN.
            float rayExtent = max2(abs(rayUV.xy));
            if (!(rayExtent > 0.0))
                return 0;

            // Rescale the ray so that its dominant screen axis advances one depth texel per unit
            float screenStep = Texture1TexelSize.x;
            rayUV *= screenStep / rayExtent;

            // March in steps of two texels, starting one step away from the origin
            float3 rayStep = rayUV * 2;
            float3 currOffset = startUV + rayStep;

            // Calculate number of samples needed to reach the closest boundary of the [0, 1] volume
            float3 samplesToEdge = ((sign(rayStep.xyz) * 0.5 + 0.5) - currOffset.xyz) / rayStep.xyz;
            samplesToEdge.x = min(samplesToEdge.x, min(samplesToEdge.y, samplesToEdge.z)) * 1.05f;
            float numSamples = min(MaxTraceSamples, samplesToEdge.x);

            // Nothing to march (the start is already outside): same result as running zero iterations
            if (numSamples <= 0)
                return 0;

            // Stretch the step so the sample budget covers the whole distance to the edge
            rayStep *= samplesToEdge.x / numSamples;

            // Calculate depth difference error (max accepted penetration behind the surface)
            float depthDiffError = 1.3f * abs(rayStep.z);

            // Ray trace
            float currSampleIndex = 0;
            [loop]
            while (currSampleIndex < numSamples)
            {
                // Sample depth buffer and calculate depth difference
                float currSample = SampleZ(currOffset.xy);
                float depthDiff = currOffset.z - currSample;

                // Check intersection
                if (depthDiff >= 0)
                {
                    if (depthDiff < depthDiffError)
                    {
                        // Ray is just behind the surface: hit found
                        break;
                    }
                    else
                    {
                        // Overshot too far: step back and continue with half the step size
                        currOffset -= rayStep;
                        rayStep *= 0.5;
                    }
                }

                // Move forward
                currOffset += rayStep;
                currSampleIndex++;
            }

            // Check if has valid result after ray tracing
            if (currSampleIndex >= numSamples || currOffset.z > 0.999)
            {
                // All samples done but no result
                return 0;
            }

            float2 hitUV = currOffset.xy;

            // Fade rays close to screen edge
            const float fadeStart = 0.9f;
            const float fadeEnd = 1.0f;
            const float fadeDiffRcp = 1.0f / (fadeEnd - fadeStart);
            float2 boundary = abs(hitUV - float2(0.5f, 0.5f)) * 2.0f;
            float fadeOnBorder = 1.0f - saturate((boundary.x - fadeStart) * fadeDiffRcp);
            fadeOnBorder *= 1.0f - saturate((boundary.y - fadeStart) * fadeDiffRcp);
            fadeOnBorder = smoothstep(0.0f, 1.0f, fadeOnBorder);
            fadeOnBorder *= RayAttenBorder(hitUV, EdgeFadeFactor);

            // Fade rays on high roughness
            float roughnessFade = saturate((RoughnessFade - roughness) * 20);

            // Output: xy: hitUV, z: hitMask, w: unused
            return float4(hitUV, fadeOnBorder * roughnessFade, 0);
        }
    };
}