// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// One-dimensional, depth-aware blur pass for the ambient occlusion buffer.
    /// Each tap is weighted by <c>Weights</c> and attenuated by the difference between its
    /// linear depth and the linear depth of the center texel, so the blur does not bleed
    /// across depth discontinuities. <c>IsVertical</c> selects the blur axis; running the shader
    /// once per axis gives a separable two-dimensional blur.
    /// </summary>
    internal shader AmbientOcclusionBlurShader<int BlurCount, bool IsVertical, float BlurScale, float EdgeSharpness, bool IsOrthographic> : ImageEffectShader, Camera
    {
        // Per-tap blur weights. Index 0 is the center tap; indices 1..BlurCount-1 are each
        // applied to the two mirrored taps on either side of the center.
        stage float  Weights[BlurCount];

        // Converts a depth value sampled from the depth texture into a linear depth.
        // ZProjection is not declared in this shader (presumably inherited from the Camera mixin).
        stage float reconstructCSZ(float depth)
        {
            if (IsOrthographic)
            {
                // Orthographic: near + z * (far - near)
                return ZProjection.x + depth * ZProjection.y;
            }

            // Perspective projection.
            return ZProjection.y / (depth - ZProjection.x);
        }

        // Computes the blurred value for the current texel.
        // Texture0 supplies the values being blurred, Texture1 supplies the depth used for edge detection.
        // Returns the normalized weighted sum in rgb and 1 in alpha.
        stage override float4 Shading()
        {
            // Avoids a division by zero when every weight ends up being zero.
            const float epsilon = 0.0001;

            // Linear depth at or beyond which the texel is returned without blurring.
            // NOTE(review): presumably skips far geometry / background - confirm the intended value and units.
            const float maxBlurDepth = 300.0;

            // Constant added to every neighbor tap weight before edge attenuation.
            // NOTE(review): empirical value; its rationale is not visible in this file.
            const float neighborWeightBias = 0.3;

            // Direction in texel size: (float2(1,0) or float2(0,1)) * texel size
            float2 direction = (IsVertical ? float2(0, 1) : float2(1, 0)) * Texture0TexelSize;

            // Add center
            float totalWeight = Weights[0];
            float3 sum = Texture0.Sample(LinearSampler, streams.TexCoord).rgb * totalWeight;

            float linearDepth = reconstructCSZ(Texture1.Sample(LinearSampler, streams.TexCoord).x);
            if (linearDepth >= maxBlurDepth)
            {
                // Too far away: output the center sample only.
                sum /= (totalWeight + epsilon);
                return float4(sum, 1);
            }

            // Mirrored samples on both sides of the center, using bilinear filtering
            [unroll]
            for (int i = 1; i < BlurCount; i++)
            {
                // Handle both directions (j = -1, then j = +1)
                [unroll]
                for (int j = -1; j <= 1; j += 2)
                {
                    // Both the value and the depth are fetched at the same coordinate, so compute it once.
                    float2 tapCoord = streams.TexCoord + direction * j * i * BlurScale;

                    // Only the red channel of the neighbor taps is used. The original code relied on an
                    // implicit float3 -> float truncation here; the explicit swizzle gives the same result.
                    float value = Texture0.Sample(LinearSampler, tapCoord).r;
                    float linearDepthOther = reconstructCSZ(Texture1.Sample(LinearSampler, tapCoord).x);

                    // Edge-aware attenuation: the weight falls off linearly with the depth difference
                    // and is clamped so that it never becomes negative.
                    float weight = neighborWeightBias + Weights[i];
                    weight *= max(0.0, 1.0 - EdgeSharpness * abs(linearDepth - linearDepthOther));

                    sum += value * weight;
                    totalWeight += weight;
                }
            }

            sum /= (totalWeight + epsilon);
            return float4(sum, 1);
        }
    };
}
