// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Replicates lens flare artifacts: sums several rescaled and mirrored copies of the flare
    /// texture (Texture0), then adds weighted taps of the bright pass (Texture1).
    /// </summary>
    internal shader FlareReplicate : ImageEffectShader
    {
        // Amount of blending: weight of the bright-pass tap sampled with a -0.05 UV scale.
        float Amount;

        // Halo factor: weight of the two halo taps of the bright pass (the second one is halved).
        float HaloFactor;

        stage override float4 Shading()
        {
            float2 texCoord = streams.TexCoord;
            float3 color = float3(0.0, 0.0, 0.0);

            // Flares at their original scale.
            color += softBorderTap(texCoord);

            // Same flares, downscaled (UV scale > 1 shrinks the image toward the center).
            color += softBorderTap(scaleAroundCenter(texCoord, 2.5));
            color += softBorderTap(scaleAroundCenter(texCoord, 4.0));

            // Symmetry with scaling (negative UV scale mirrors through the center).
            color += softBorderTap(scaleAroundCenter(texCoord, -4.5));
            color += softBorderTap(scaleAroundCenter(texCoord, -8.0));

            // Scaled copies of the original bright pass, plus the double halo.
            color += Texture1.Sample(LinearSampler, scaleAroundCenter(texCoord, -1.0)).rgb * HaloFactor;
            color += Texture1.Sample(LinearSampler, scaleAroundCenter(texCoord, -0.05)).rgb * Amount;
            color += Texture1.Sample(LinearSampler, scaleAroundCenter(texCoord, 0.1)).rgb * HaloFactor * 0.5;

            return float4(color, 1.0);
        }

        // Scales a texture coordinate around the center point (0.5, 0.5).
        float2 scaleAroundCenter(float2 uv, float scale)
        {
            return (uv - 0.5) * scale + 0.5;
        }

        // Samples Texture0 at uv, fading the result toward black near the [0, 1] borders
        // and returning black for coordinates outside that range.
        float3 softBorderTap(float2 uv)
        {
            float fadeWidth = 0.18;

            // Per-axis distance to the nearest border, expressed in units of the fade width.
            float2 edgeFade = (0.5 - abs(uv - 0.5)) / fadeWidth;
            float weight = saturate(edgeFade.x * edgeFade.y);

            // The fetch stays outside any branch; rejection is applied to the fetched value.
            float3 tapped = Texture0.Sample(LinearSampler, uv).rgb * weight;

            // Outside the texture both fade factors can be negative, making their product
            // positive, so out-of-range coordinates are rejected explicitly.
            bool outside = any(uv < 0.0) || any(uv > 1.0);
            return outside ? float3(0.0, 0.0, 0.0) : tapped;
        }
    };
}
