// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// First pass of the Lambertian pre-filtering shader based on Spherical Harmonics (variant that does not use compute shaders).
    /// For the current texel coordinates, accumulates over the six cube faces the weighted contribution of the radiance
    /// to the single spherical harmonics coefficient selected by <c>CoefficientIndex</c>.
    /// </summary>
    shader LambertianPrefilteringSHNoComputePass1<int THarmonicsOrder> : SphericalHarmonicsBase<THarmonicsOrder>, ImageEffectShader, Texturing
    {
        // Cube map holding the input radiance.
        TextureCube RadianceMap;

        // Index of the single spherical harmonics coefficient this pass evaluates.
        int CoefficientIndex;

        // Cosine kernel factors (A0 to A4), one value per group of coefficients in the table below.
        static const float A0 = 1.0;
        static const float A1 = 2.0 / 3.0;
        static const float A2 = 1.0 / 4.0;
        static const float A3 = 0.0;
        static const float A4 = -1.0 / 24.0;

        // Per-coefficient lookup table: each factor is repeated for every coefficient of its group
        // (1, 3, 5, 7 and 9 entries, 25 in total).
        static const float A[5 * 5] =
        {
            A0,
            A1, A1, A1,
            A2, A2, A2, A2, A2,
            A3, A3, A3, A3, A3, A3, A3,
            A4, A4, A4, A4, A4, A4, A4, A4, A4
        };

        // Returns in xyz the accumulated weighted contribution of the six cube faces to the selected coefficient,
        // and in w the sum of the weights that were applied (the same weight for each of the six faces).
        stage override float4 Shading()
        {
            // Remap the texture coordinates from [0, 1] to [-1, 1].
            float2 faceUV = streams.TexCoord * 2.0 - 1.0;

            // The weight depends only on the position on the face, so it is computed once and shared by all faces.
            // NOTE(review): presumably the solid-angle term of the texel - confirm.
            float lengthSquared = 1.0f + dot(faceUV, faceUV);
            float texelWeight = 4.0f / (sqrt(lengthSquared) * lengthSquared);

            float3 accumulated = 0;

            [unroll]
            for (int face = 0; face < 6; face++)
            {
                // Direction going through this texel on the current face.
                float3 direction = normalize(uvToDirectionVS(faceUV.x, faceUV.y, face));

                // Evaluate the SH bases for that direction (the values are read back from streams.SHBaseValues below).
                EvaluateSHBases(direction);

                float3 sampledRadiance = RadianceMap.Sample(PointSampler, direction).xyz;

                // Kernel factor times SH basis value, both for the selected coefficient.
                float kernelTimesBasis = A[CoefficientIndex] * streams.SHBaseValues[CoefficientIndex];

                accumulated += kernelTimesBasis * sampledRadiance * texelWeight;
            }

            return float4(accumulated, texelWeight * 6);
        }

        // Converts the coordinates (u, v) on the cube face 'viewIndex' into an un-normalized direction.
        // Faces are indexed in the order +X, -X, +Y, -Y, +Z, -Z; any other index yields a null vector.
        float3 uvToDirectionVS(float u, float v, int viewIndex)
        {
            float3 direction = 0;

            if (viewIndex == 0)
            {
                direction = float3(1, -v, -u); // face X
            }
            else if (viewIndex == 1)
            {
                direction = float3(-1, -v, u); // face -X
            }
            else if (viewIndex == 2)
            {
                direction = float3(u, 1, v); // face Y
            }
            else if (viewIndex == 3)
            {
                direction = float3(u, -1, -v); // face -Y
            }
            else if (viewIndex == 4)
            {
                direction = float3(u, -v, 1); // face Z
            }
            else if (viewIndex == 5)
            {
                direction = float3(-u, -v, -1); // face -Z
            }

            return direction;
        }
    };
}
