// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

// Set to 1 to visualize debug colors for the different CoC levels (0 disables them).
#define DEBUG_COC_LEVEL_COLOR 0

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Takes several blurred levels as input and, depending on the pixel CoC,
    /// outputs an interpolation between 2 of these levels.
    /// Level 0 is the original sharp image. The last level is the blurriest version.
    /// Expects as input:
    /// - Texture0: a [CoC, Linear Depth] buffer
    /// - Texture1: a blurred version of the CoC buffer
    /// - Texture2 ~ Texture9: the different blur levels, from sharpest to blurriest (Texture2 == no blur)
    /// </summary>
    ///
    /// <typeparam name="TLevelCount">Total number of layers used, including the original non-blurred image.</typeparam>

    shader CombineLevelsFromCoCShader<int TLevelCount> : ImageEffectShader
    {
        // The CoC corresponding to each level of blur.
        // The level selection in Shading() treats CoCLevelValues[i] as the lower bound and
        // CoCLevelValues[i + 1] as the upper bound of a range, so the values must be ascending.
        stage float CoCLevelValues[TLevelCount];

        // Samples every blur level at the current texture coordinate and returns the
        // interpolation of the 2 levels whose CoC range contains the CoC of the pixel.
        // The alpha channel of the returned color is always 1.
        stage override float4 Shading()
        {
#if DEBUG_COC_LEVEL_COLOR
            // Some debug colors to visualize each layer
            float3 debugColors[8] =
            {
                float3(1.0, 1.0, 1.0),
                float3(0.5, 0.5, 1.0),
                float3(0.5, 1.0, 0.5),
                float3(1.0, 0.5, 0.5),
                // Set more colors here
                float3(1.0, 0.0, 0.0),
                float3(1.0, 0.0, 0.0),
                float3(1.0, 0.0, 0.0),
                float3(1.0, 0.0, 0.0)
            };
#endif

            // Fetch all our levels (at most 8: Texture2 to Texture9).
            // Entries at index TLevelCount and above are never written nor read.
            float3 colorLevels[8];

            // Note: the fetches are manually unrolled (instead of indexing an array of textures
            // inside a loop) until better HLSL2GLSL support.
            if (TLevelCount >= 1)
                colorLevels[0] = Texture2.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 2)
                colorLevels[1] = Texture3.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 3)
                colorLevels[2] = Texture4.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 4)
                colorLevels[3] = Texture5.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 5)
                colorLevels[4] = Texture6.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 6)
                colorLevels[5] = Texture7.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 7)
                colorLevels[6] = Texture8.Sample(LinearSampler, streams.TexCoord).rgb;
            if (TLevelCount >= 8)
                colorLevels[7] = Texture9.Sample(LinearSampler, streams.TexCoord).rgb;

#if DEBUG_COC_LEVEL_COLOR
            // Tints each level with its debug color
            [unroll]
            for (int k = 0; k < TLevelCount; k++)
            {
                colorLevels[k] *= debugColors[k];
            }
#endif

            // Gets the CoC of the current pixel
            float CoC = abs(Texture0.Sample(LinearSampler, streams.TexCoord).x);

            // If the pixel is not in focus, use a blurred version of the CoC to avoid sharp transitions.
            // sign() is 1 for a positive blurred CoC and 0 for a null one, so the lerp picks the
            // blurred CoC whenever it is non-zero (this assumes the blurred CoC is never negative).
            float blurredCoC = Texture1.Sample(LinearSampler, streams.TexCoord).x;
            CoC = lerp(CoC, blurredCoC, sign(blurredCoC));

            // Start from the sharp level: it is the result for any CoC that does not exceed
            // the first level value (including the usual "CoC == 0" in-focus case).
            float3 result = colorLevels[0];

            // We now find the 2 levels closest to the pixel CoC.
            // We go up the levels, starting at the sharpest pair. Each pair whose lower bound is
            // below our CoC overrides the previous result, so the last pair that matches wins.
            // For a CoC inside (rangeMin, rangeMax] this is the lerp between these 2 levels.
            // For a CoC above the last level value the lerp factor saturates to 1, which yields
            // the blurriest level instead of leaving the pixel black.
            [unroll]
            for (int i = 0; i < TLevelCount - 1; i++)
            {
                // Current range we consider
                float rangeMin = CoCLevelValues[i];
                float rangeMax = CoCLevelValues[i + 1];

                // Two consecutive levels may share the same CoC value:
                // keep the divisor strictly positive to avoid producing NaN/Inf.
                float rangeSize = max(rangeMax - rangeMin, 0.000001);

                // We calculate the lerp factor between the 2 levels.
                float lerpFactor = clamp((CoC - rangeMin) / rangeSize, 0.0, 1.0); // try smoothstep()?

                // We keep the lerp result only if our CoC is above the lower bound of the pair.
                // A select is used rather than a multiply-accumulate so that values computed
                // for the other pairs can never leak into the result.
                if (CoC > rangeMin)
                {
                    result = lerp(colorLevels[i], colorLevels[i + 1], lerpFactor);
                }
            }

            return float4(result, 1.0);
        }
    };
}
