// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.LightProbes
{
    /// <summary>
    /// Defines an environment light whose diffuse contribution comes from light probes.
    /// </summary>
    /// <remarks>
    /// For the shaded pixel, the tetrahedron ID is read from <c>LightProbeTetrahedronIds</c>, the spherical harmonics
    /// coefficients of the four probes of that tetrahedron are blended using barycentric weights, and the blended
    /// coefficients are evaluated along the (Z-flipped) world-space normal.
    /// <c>TOrder</c> is the spherical harmonics order: each probe stores <c>TOrder * TOrder</c> coefficients.
    /// </remarks>
    shader LightProbeShader<int TOrder> : EnvironmentLight, MaterialPixelShadingStream, NormalStream, Transformation, ShaderBaseStream, PositionStream4, SphericalHarmonicsUtils<TOrder>
    {
        cbuffer PerView.LightProbes
        {
            // Probes whose index is >= this value get a zero weight when blending (see PrepareEnvironmentLight)
            stage int IgnoredProbeStart;
        }
        rgroup PerView.LightProbes
        {
#ifdef STRIDE_MULTISAMPLE_COUNT
  #if STRIDE_MULTISAMPLE_COUNT > 1
            stage Texture2DMS<uint, STRIDE_MULTISAMPLE_COUNT> LightProbeTetrahedronIds;  // SV_Position => tetrahedron ID
  #else
            stage Texture2D<uint> LightProbeTetrahedronIds;  // SV_Position => tetrahedron ID
  #endif
#else
            stage Texture2DMS<uint> LightProbeTetrahedronIds;  // SV_Position => tetrahedron ID
#endif
            stage Buffer<uint4> LightProbeTetrahedronProbeIndices; // tetrahedron ID => 4 probe IDs
            stage Buffer<float4> LightProbeTetrahedronMatrices; // tetrahedron ID => world to tetrahedron space matrix (3 consecutive float4 rows per tetrahedron)
            stage Buffer<float3> LightProbeCoefficients; // probe ID => SH coefficients (TOrder * TOrder consecutive entries per probe)
        }

        /// <summary>
        /// Accumulates the weighted spherical harmonics coefficients of a single probe.
        /// </summary>
        /// <param name="sphericalColors">Accumulator; <c>coefficient * weight</c> is added to each of its entries.</param>
        /// <param name="lightprobeIndex">Index of the probe to read from <c>LightProbeCoefficients</c>.</param>
        /// <param name="weight">Blend weight of this probe; a weight of exactly zero skips the buffer reads.</param>
        void FetchLightProbe(inout float3 sphericalColors[TOrder * TOrder], uint lightprobeIndex, float weight)
        {
            // Early exit: weights are set to exactly 0 for ignored probes, so an exact comparison is intended here
            if (weight == 0.0f)
                return;

            // Coefficients of a probe are stored contiguously
            int lightprobeIndexStart = lightprobeIndex * TOrder * TOrder;
            for (int i = 0; i < TOrder * TOrder; ++i)
            {
                // TODO: Type inference doesn't work properly on generic Buffer<T>.Load();
                //       the explicit swizzle is presumably the workaround, confirm before removing it
                sphericalColors[i] += LightProbeCoefficients.Load(lightprobeIndexStart + i).rgb * weight;
            }
        }

        /// <summary>
        /// Computes <c>streams.envLightDiffuseColor</c> from the light probes of the tetrahedron covering the shaded pixel.
        /// </summary>
        override void PrepareEnvironmentLight()
        {
            base.PrepareEnvironmentLight();

            var ambientAccessibility = streams.matAmbientOcclusion;

            // NOTE(review): Z is negated before evaluating the SH, presumably to match the convention
            // the probe coefficients were generated with; confirm against the probe capture code
            var sampleDirection = streams.normalWS;
            sampleDirection = float3(sampleDirection.xy, -sampleDirection.z);

            // Look up which tetrahedron covers this pixel
            var shadingPosition = int3(streams.ShadingPosition.xy, 0);
#if STRIDE_MULTISAMPLE_COUNT == 1
            uint tetrahedronIndex = LightProbeTetrahedronIds.Load(shadingPosition);
#else
            // TODO: Use SV_SampleIndex
            uint tetrahedronIndex = LightProbeTetrahedronIds.Load(shadingPosition, 0);
#endif

            uint4 probeIndices = LightProbeTetrahedronProbeIndices.Load(tetrahedronIndex);
            float3x4 tetrahedronMatrix = float3x4(LightProbeTetrahedronMatrices.Load(tetrahedronIndex * 3 + 0),
                                                 LightProbeTetrahedronMatrices.Load(tetrahedronIndex * 3 + 1),
                                                 LightProbeTetrahedronMatrices.Load(tetrahedronIndex * 3 + 2));

            // The fourth column is subtracted from the world position before applying the 3x3 part
            float3 tetrahedronFactors3 = mul((float3x3)tetrahedronMatrix, streams.PositionWS.xyz - tetrahedronMatrix._14_24_34);

            // Protect ourselves against degenerate cases
            // TODO: Investigate why those happen (almost coplanar tetrahedron?)
            tetrahedronFactors3 = saturate(tetrahedronFactors3);

            // The fourth coordinate is implied by the other three. It must be clamped as well: when the first three
            // sum to more than 1 (the degenerate cases above) it would become negative, producing a negative probe
            // weight and a renormalization sum that can be zero or negative.
            float4 tetrahedronFactors4 = float4(tetrahedronFactors3, max(0.0f, 1.0f - tetrahedronFactors3.x - tetrahedronFactors3.y - tetrahedronFactors3.z));

            // Zero all the barycentric coordinates that reference probes past IgnoredProbeStart (the one far away)
            tetrahedronFactors4 = lerp(tetrahedronFactors4, 0.0f, probeIndices >= IgnoredProbeStart ? float4(1.0f, 1.0f, 1.0f, 1.0f) : float4(0.0f, 0.0f, 0.0f, 0.0f));

            // Renormalize barycentric coordinates (skipped when every weight is zero, which leaves the result black)
            var totalSum = tetrahedronFactors4.x + tetrahedronFactors4.y + tetrahedronFactors4.z + tetrahedronFactors4.w;
            if (totalSum > 0.0f)
                tetrahedronFactors4 /= totalSum;

            // Blend the SH coefficients of the four probes
            float3 sphericalColors[TOrder * TOrder];
            for (int i = 0; i < TOrder * TOrder; ++i)
                sphericalColors[i] = 0.0f;

            FetchLightProbe(sphericalColors, probeIndices.x, tetrahedronFactors4.x);
            FetchLightProbe(sphericalColors, probeIndices.y, tetrahedronFactors4.y);
            FetchLightProbe(sphericalColors, probeIndices.z, tetrahedronFactors4.z);
            FetchLightProbe(sphericalColors, probeIndices.w, tetrahedronFactors4.w);

            streams.envLightDiffuseColor = EvaluateSphericalHarmonics(sphericalColors, sampleDirection).rgb * ambientAccessibility * streams.matDiffuseSpecularAlphaBlend.x;
            // TEST:
            //streams.envLightDiffuseColor = LightProbeCube.Sample(Texturing.LinearSampler, sampleDirection).rgb * ambientAccessibility * streams.matDiffuseSpecularAlphaBlend.y;
            //streams.envLightDiffuseColor = tetrahedronFactors3;
            //streams.envLightDiffuseColor = float4(tetrahedronIndex.xxx * 0.2, 1.0);
        }
    };
}
