// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Lens flare artifact shader: accumulates <c>TapCount</c> zoomed and distorted
    /// samples of <c>Texture0</c>, fades each one out near the texture edges, tints it,
    /// and outputs the sum scaled by <c>Amount</c>.
    /// </summary>
    internal shader FlareArtifactShader<int TapCount> : ImageEffectShader
    {
        // Global multiplier applied to the accumulated taps before output.
        float Amount;

        // Per tap: x scales the UV around the center (zoom),
        // y is the exponent applied to the distance-to-center for the distortion.
        float2 ZoomOffsetsDistortions[TapCount];

        // Per tap: color multiplier used when aberration is fully enabled.
        float3 ColorAberrations[TapCount];

        // Blend factor between no tint (white) and the per-tap aberration color.
        float AberrationStrength = 0;

        stage override float4 Shading()
        {
            // Position of this pixel relative to the texture center; identical for every tap.
            float2 centeredUV = streams.TexCoord - float2(0.5, 0.5);
            float radius = sqrt(dot(centeredUV, centeredUV));

            // Width (in UV units) of the band over which a tap fades out near an edge.
            float fadeWidth = 0.1;

            float3 accumulated = float3(0.0, 0.0, 0.0);

            [unroll]
            for (int tap = 0; tap < TapCount; tap++)
            {
                // NOTE: Keep this local copy of the array element. Reading it through a
                // local variable prevents glLinkProgram from freezing on some android devices.
                float2 zoomAndDistortion = ZoomOffsetsDistortions[tap];

                // Zoom: scale the centered coordinates, then move back to [0, 1] space.
                float2 zoomedUV = centeredUV * zoomAndDistortion.x + 0.5;

                // Distortion: scale again around the center by a factor derived from the
                // pixel's distance to the center.
                float distortionFactor = sin(pow(radius, zoomAndDistortion.y));
                float2 sampleUV = distortionFactor * (zoomedUV - 0.5) + 0.5;

                float3 tapColor = Texture0.Sample(LinearSampler, sampleUV).rgb;

                // Soft vignette-like fade so taps do not end with a hard cut at the edges:
                // distance to the nearest edge on each axis, normalized by the fade width.
                float2 edgeDistance = (0.5 - abs(sampleUV - 0.5)) / fadeWidth;
                float2 edgeFade = lerp(float2(0.0, 0.0), float2(1.0, 1.0), edgeDistance);
                float fade = saturate(edgeFade.x * edgeFade.y);
                tapColor *= fade * fade;

                // Discard taps that sample outside the texture to avoid bleeding
                // (a clamp-to-border sampler could be used instead).
                bool outsideTexture = sampleUV.x < 0 || sampleUV.x > 1 || sampleUV.y < 0 || sampleUV.y > 1;
                if (outsideTexture)
                {
                    tapColor = float3(0.0, 0.0, 0.0);
                }

                // Chromatic tint of this tap, blended in by the aberration strength.
                float3 tint = lerp(float3(1, 1, 1), ColorAberrations[tap], AberrationStrength);
                tapColor *= tint;

                accumulated += tapColor;
            }

            return float4(accumulated * Amount, 1.0);
        }
    };
}
