// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Compute shader performing radiance GGX pre-filtering of a cube map.
    /// </summary>
    /// <remarks>
    /// Each thread group handles one output texel (GroupId.xy) of one cube face (GroupId.z).
    /// Each thread of the group evaluates a single importance sample (indexed by GroupThreadId.x),
    /// and the per-thread results are summed through group-shared memory.
    /// The pair-wise reduction halves the active range at every step, so it only accumulates
    /// every sample when TNbOfSamples is a power of two.
    /// </remarks>
    shader RadiancePrefilteringGGXShader<int TNbOfSamples> : Math, ComputeShaderBase
    {
        // The input texture containing the radiance, and the size used to derive the solid angle of one of its texels
        int RadianceMapSize;
        TextureCube<float4> RadianceMap;

        // The output cube map (one array slice per face) containing the filtered radiance.
        RWTexture2DArray<float4> FilteredRadiance;

        // Shared memory holding one (weighted color, weight) entry per sample of the group, summed in place
        groupshared float4 PrefilteredSamples[TNbOfSamples];

        // The number of mipmap available
        stage float MipmapCount;

        // The roughness of the GGX distribution
        stage float Roughness;

        // Compute the pre-filtered environment map value for the direction associated to the current thread group
        override void Compute()
        {
            int2 pixel = streams.GroupId.xy;
            int face = streams.GroupId.z;
            int threadId = streams.GroupThreadId.x;

            // Calculate the uv of the texel center in [0, 1] (one thread group per output texel)
            float u = (pixel.x + 0.5) / float(streams.ThreadGroupCount.x);
            float v = (pixel.y + 0.5) / float(streams.ThreadGroupCount.y);

            // Calculate the direction of the texel in the cubemap.
            // R is used at the same time as the normal, the view and the reflection direction.
            float3 R = normalize(CubemapUtils.ConvertTexcoordsNoFlip(float2(u, v), face));

            // Perform one sampling: this thread evaluates sample number 'threadId' out of TNbOfSamples
            var xi = Hammersley.GetSamplePlane(threadId, TNbOfSamples);
            var H = ImportanceSamplingGGX.GetSample(xi, Roughness, R);

            // Light direction obtained by reflecting R around the sampled half vector
            float3 L = 2 * dot(R, H) * H - R;
            float NoL = saturate(dot(R, L));
            float NoH = saturate(dot(R, H));

            // Probability density of the generated direction L.
            // The sampled quantity is the half vector H, so the density is D(NoH) * NoH / (4 * VoH);
            // since N = V = R here, NoH = VoH and the expression reduces to D(NoH) / 4.
            // NOTE(review): assumes NormalDistributionGGX expects (alpha, NoH) - confirm against BRDFMicrofacet.
            float pdf = BRDFMicrofacet.NormalDistributionGGX(Roughness * Roughness, NoH) / 4;

            // Choose the source mip level by comparing the solid angle covered by the sample (omegaS)
            // with the solid angle covered by one texel of the radiance map (omegaP)
            float omegaS = 1.0 / (TNbOfSamples * pdf);
            float omegaP = 4.0 * Math.PI / (6.0 * RadianceMapSize * RadianceMapSize);
            float mipLevel = clamp(0.5 * log2(omegaS / omegaP), 0, MipmapCount);

            // Samples pointing below the horizon (NoL <= 0) do not contribute
            float3 prefilteredColor = 0;
            float weight = 0;
            if (NoL > 0)
            {
                weight = NoL;
                prefilteredColor = RadianceMap.SampleLevel(Texturing.LinearSampler, L, mipLevel).rgb * weight;
            }

            // Store the result in group-shared memory
            PrefilteredSamples[threadId] = float4(prefilteredColor, weight);
            GroupMemoryBarrierWithGroupSync();

            // Perform the sums among the group (parallel reduction, the active range is halved at each step)
            for (int s = TNbOfSamples / 2; s > 0; s >>= 1)
            {
                if (threadId < s)
                    PrefilteredSamples[threadId] += PrefilteredSamples[threadId + s];

                GroupMemoryBarrierWithGroupSync();
            }

            // Let the first thread store the final result in the output texture
            if (IsFirstThreadOfGroup())
            {
                float4 total = PrefilteredSamples[0];

                // Normalize by the accumulated weight; write zero instead of NaN when no sample contributed
                float4 filtered = 0;
                if (total.w > 0)
                    filtered = total / total.w;

                // UAV texture arrays are addressed with integer coordinates (x, y, slice)
                FilteredRadiance[int3(pixel, face)] = filtered;
            }
        }
    };
}
