// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// First pass of the Lambertian pre-filtering compute shader based on Spherical Harmonics:
    /// every thread projects one cube-map texel on the SH basis and every thread group writes out its partial sums.
    /// </summary>
    shader LambertianPrefilteringSHPass1<int TBlockSize, int THarmonicsOrder> : SphericalHarmonicsBase<THarmonicsOrder>, ComputeShaderBase, Texturing
    {
        // Cube map holding the input radiance.
        TextureCube<float4> RadianceMap;

        // Output buffer receiving, for each thread group, CoefficientsCount partially summed SH coefficients.
        RWBuffer<float4> OutputBuffer;

        // Cosine kernel factors, one value per SH band (bands 0 to 4).
        static const float A0 = 1.0;
        static const float A1 = 2.0/3.0;
        static const float A2 = 1.0/4.0;
        static const float A3 = 0.0;
        static const float A4 = -1.0/24.0;

        // Cosine kernel factor of every coefficient: the factor of band l is repeated (2l + 1) times.
        static const float A[5*5] =
        {
            A0,                                     // band 0
            A1, A1, A1,                             // band 1
            A2, A2, A2, A2, A2,                     // band 2
            A3, A3, A3, A3, A3, A3, A3,             // band 3
            A4, A4, A4, A4, A4, A4, A4, A4, A4      // band 4
        };

        // Group-shared scratch area: one set of weighted SH coefficients per thread of the block.
        groupshared float4 PartialSHCoeffs[TBlockSize][TBlockSize][CoefficientsCount];

        // Projects the radiance of one texel per thread on the SH basis, then reduces the values of the
        // whole thread group into a single set of coefficients that is appended to OutputBuffer.
        override void Compute()
        {
            // Texel processed by this thread: xy is the position inside the face, z is the face index
            const int3 texel = int3(streams.GroupThreadId.xy + streams.GroupId.xy * TBlockSize, streams.GroupId.z);

            // Map the texel center to the [-1, 1] range
            float texelScale = 1 / float(TBlockSize * streams.ThreadGroupCount.x);
            float u = ((texel.x+0.5) * texelScale) * 2.0f - 1.0f;
            float v = ((texel.y+0.5) * texelScale) * 2.0f - 1.0f;

            // Per-texel weight: 4 / (1 + u^2 + v^2)^(3/2)
            float sqrLength = 1.0f + u * u + v * v;
            float weight = 4.0f / (sqrt(sqrLength) * sqrLength);

            // Turn (u, v, face) into a direction and fetch the weighted radiance along it
            float3 dirVS = normalize(uvToDirectionVS(u, v, texel.z));
            float3 radiance = RadianceMap.SampleLevel(Texturing.PointSampler, dirVS, 0).xyz;
            radiance *= weight;

            // Evaluate the SH basis functions for the direction (results are read from streams.SHBaseValues below)
            EvaluateSHBases(dirVS);

            // Publish this thread's contribution; the weight travels in .w so that it gets summed too
            [unroll]
            for (int i = 0; i < CoefficientsCount; ++i)
            {
                PartialSHCoeffs[streams.GroupThreadId.x][streams.GroupThreadId.y][i] = float4(A[i] * streams.SHBaseValues[i] * radiance, weight);
            }
            GroupMemoryBarrierWithGroupSync();

            // First reduction: each thread with GroupThreadId.y == 0 accumulates, over the first array index,
            // all the entries whose second index equals its GroupThreadId.x
            if (streams.GroupThreadId.y == 0)
            {
                for (int line = 1; line < TBlockSize; ++line)
                {
                    [unroll]
                    for (int j = 0; j < CoefficientsCount; ++j)
                    {
                        PartialSHCoeffs[0][streams.GroupThreadId.x][j] += PartialSHCoeffs[line][streams.GroupThreadId.x][j];
                    }
                }
            }
            GroupMemoryBarrierWithGroupSync();

            // Second reduction and output are both performed by the first thread of the group only
            if (IsFirstThreadOfGroup())
            {
                // Fold the TBlockSize partial results of the first reduction into entry [0][0]
                for (int partial = 1; partial < TBlockSize; ++partial)
                {
                    [unroll]
                    for (int k = 0; k < CoefficientsCount; ++k)
                    {
                        PartialSHCoeffs[0][0][k] += PartialSHCoeffs[0][partial][k];
                    }
                }

                // Each group owns CoefficientsCount consecutive entries of the output buffer,
                // groups being ordered by x, then y, then face index
                int2 groupCount = streams.ThreadGroupCount.xy;
                int indexBias = CoefficientsCount * (streams.GroupId.x + streams.GroupId.y * groupCount.x + groupCount.x * groupCount.y * texel.z);

                [unroll]
                for (int n = 0; n < CoefficientsCount; ++n)
                {
                    OutputBuffer[indexBias + n] = PartialSHCoeffs[0][0][n];
                }
            }
        }

        // Returns the (non-normalized) direction matching the coordinates (u, v) on the cube face 'viewIndex'.
        // Returns 0 when 'viewIndex' is not in the [0, 5] range.
        float3 uvToDirectionVS(float u, float v, int viewIndex)
        {
            float3 direction = 0;

            if (viewIndex == 0)
            {
                direction = float3(1, -v, -u); // face X
            }
            else if (viewIndex == 1)
            {
                direction = float3(-1, -v, u); // face -X
            }
            else if (viewIndex == 2)
            {
                direction = float3(u, 1, v); // face Y
            }
            else if (viewIndex == 3)
            {
                direction = float3(u, -1, -v); // face -Y
            }
            else if (viewIndex == 4)
            {
                direction = float3(u, -v, 1); // face Z
            }
            else if (viewIndex == 5)
            {
                direction = float3(-u, -v, -1); // face -Z
            }

            return direction;
        }
    };
}
