// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
// Temporal anti-aliasing resolve pass.
//
// For each pixel the shader:
//   1. picks the velocity sample position, moving it to the nearest-depth corner
//      neighbor when that neighbor is closer than the center (velocity dilation),
//   2. reads the 3x3 neighborhood of the current color and the reprojected history color,
//   3. clips the history color against the neighborhood color bounds,
//   4. blends a filtered current color with the clipped history color.
//
// All color math is done in a tonemapped space (c / (1 + c)) and reverted at the end.
shader TemporalAntiAliasShader : ImageEffectShader
{
    // Distance, in texels, of the four diagonal depth taps used for velocity dilation.
    static const int OFFSET_LENGTH = 2;

    cbuffer PerDraw
    {
        float  u_BlendWeightMin;     // default = 1.0 / 8.0  - scales the history amount and the lower clamp of the blend weight
        float  u_BlendWeightMax;     // default = 0.5        - upper clamp of the blend weight
        float  u_HistoryBlurAmp;     // default = 2.0        - multiplier applied to the velocity (in pixels) to get the history blur amount
        float  u_LumaContrastFactor; // default = 128.0      - scales the neighborhood luma contrast in the sharpness term
        float  u_VelocityDecay;      // default = 0.5        - decay applied to the velocity weight stored in the output alpha
        float  u_WeightCenter;       // reconstruction filter weight of the center pixel (4)
        float  u_WeightLowCenter;    // low-pass filter weight of the center pixel (4)
        float4 u_Weight1;            // reconstruction filter weights of pixels 0, 1, 2, 3
        float4 u_Weight2;            // reconstruction filter weights of pixels 5, 6, 7, 8
        float4 u_WeightLow1;         // low-pass filter weights of pixels 0, 1, 2, 3
        float4 u_WeightLow2;         // low-pass filter weights of pixels 5, 6, 7, 8
    }

    // Texture0: color
    // Texture1: depth
    // Texture2: velocity
    // Texture3: color previous frame (blurred)
    //
    // Returns the resolved color in rgb; alpha stores a velocity-based weight that is
    // read back from Texture3 (as historyColor.a) when this output becomes the history.
    stage override float4 Shading()
    {
        // Position of the current pixel: xy = texture coordinate, z = depth.
        // Only .xy is consumed below.
        float centerDepth = Texture1.SampleLevel(PointSampler, streams.TexCoord, 0).r;
        float3 currentUV = float3(streams.TexCoord, centerDepth);

        // Position used to sample the velocity map. It starts at the current pixel and
        // may be moved to a closer neighbor below. (.z is written but never read.)
        float3 historyUV = currentUV;

        //--------------------------------------------------------------------------
        // Find the offset to the position with minimum depth in neighborhood
        // for dilation of foreground velocity map
        //--------------------------------------------------------------------------
        int2 offsets[] = 
        {
            {-OFFSET_LENGTH, -OFFSET_LENGTH},
            { OFFSET_LENGTH, -OFFSET_LENGTH},
            {-OFFSET_LENGTH,  OFFSET_LENGTH},
            { OFFSET_LENGTH,  OFFSET_LENGTH}
        };

        // The selection logic below treats .x/.y/.z/.w as the depths at offsets[0..3].
        float4 neighbor4Depths = Texture1.GatherRed(PointSampler,
                                                       streams.TexCoord,
                                                       offsets[0],
                                                       offsets[1],
                                                       offsets[2],
                                                       offsets[3]);

        // neighborDepthOffset starts as the bottom-right corner (w).
        // neighborDepthOffsetX holds the X sign of the closer corner of the top row (x vs y).
        float2 neighborDepthOffset = float2(OFFSET_LENGTH, OFFSET_LENGTH);
        float neighborDepthOffsetX = OFFSET_LENGTH;

        if(neighbor4Depths.x < neighbor4Depths.y)
        {
            neighborDepthOffsetX = -OFFSET_LENGTH;
        }
        if(neighbor4Depths.z < neighbor4Depths.w)
        {
            neighborDepthOffset.x = -OFFSET_LENGTH;
        }
        float depthXY = min(neighbor4Depths.x, neighbor4Depths.y);
        float depthZW = min(neighbor4Depths.z, neighbor4Depths.w);
        if(depthXY < depthZW)
        {
            // The top row holds the minimum: use its Y and its X choice.
            neighborDepthOffset.y = -OFFSET_LENGTH;
            neighborDepthOffset.x = neighborDepthOffsetX;
        }

        // Only move the velocity sample position when a neighbor is strictly closer
        // (smaller depth) than the center pixel.
        float depthXYZW = min(depthXY, depthZW);
        if(centerDepth > depthXYZW)
        {
            historyUV.xy += neighborDepthOffset * Texture0TexelSize;
            historyUV.z = depthXYZW;
        }

        // neighbor 3x3 pixels
        // 012
        // 345
        // 678
        float4 currentNeighborColor0 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2(-1, -1));
        float4 currentNeighborColor1 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2( 0, -1));
        float4 currentNeighborColor2 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2( 1, -1));
        float4 currentNeighborColor3 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2(-1,  0));
        float4 currentNeighborColor4 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2( 0,  0));
        float4 currentNeighborColor5 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2( 1,  0));
        float4 currentNeighborColor6 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2(-1,  1));
        float4 currentNeighborColor7 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2( 0,  1));
        float4 currentNeighborColor8 = Texture0.SampleLevel(PointSampler, currentUV.xy, 0, int2( 1,  1));

        // Apply tonemapping: c * 1 / (c + 1) = c / (1 + c), per channel
        currentNeighborColor0.rgb *= SimpleTonemap(currentNeighborColor0.rgb);
        currentNeighborColor1.rgb *= SimpleTonemap(currentNeighborColor1.rgb);
        currentNeighborColor2.rgb *= SimpleTonemap(currentNeighborColor2.rgb);
        currentNeighborColor3.rgb *= SimpleTonemap(currentNeighborColor3.rgb);
        currentNeighborColor4.rgb *= SimpleTonemap(currentNeighborColor4.rgb);
        currentNeighborColor5.rgb *= SimpleTonemap(currentNeighborColor5.rgb);
        currentNeighborColor6.rgb *= SimpleTonemap(currentNeighborColor6.rgb);
        currentNeighborColor7.rgb *= SimpleTonemap(currentNeighborColor7.rgb);
        currentNeighborColor8.rgb *= SimpleTonemap(currentNeighborColor8.rgb);

        // Fetch the velocity at the (possibly dilated) position chosen by the depth check
        float2 velocityUV     = Texture2.Sample(PointSampler, historyUV.xy).xy; // [-1, 1]
        // NOTE(review): X is flipped here, presumably to match the velocity buffer's
        // sign convention - confirm against the pass that writes Texture2.
        velocityUV.x = -velocityUV.x;

        float2 velocityPixels = velocityUV / Texture0TexelSize;                        // [-1, 1] * (RTWidth, RTHeight)
    
        // Fetch history color (bilinear filter) at the reprojected position
        float4 historyColor = Texture3.Sample(LinearSampler, currentUV.xy + velocityUV);
    
        // Apply tonemapping
        historyColor.rgb *= SimpleTonemap(historyColor.rgb);
    
        //--------------------------------------------------------------------------
        // Find min/max color bounds on current neighbor pixels
        //--------------------------------------------------------------------------
        // Find minmax on cross shaped pixels (1)
        // x1x
        // 345
        // x7x
        float4 neighborMin = min(min(min(currentNeighborColor1, currentNeighborColor3), min(currentNeighborColor4, currentNeighborColor5)), currentNeighborColor7);        
        float4 neighborMax = max(max(max(currentNeighborColor1, currentNeighborColor3), max(currentNeighborColor4, currentNeighborColor5)), currentNeighborColor7);        
    
        // Find minmax on 3x3 pixels (2): corners first, then merged with the cross
        // 012
        // 345
        // 678
        float4 neighborMin2 = min(min(currentNeighborColor0, currentNeighborColor2), min(currentNeighborColor6, currentNeighborColor8));
        float4 neighborMax2 = max(max(currentNeighborColor0, currentNeighborColor2), max(currentNeighborColor6, currentNeighborColor8));
        neighborMin2 = min(neighborMin2, neighborMin);
        neighborMax2 = max(neighborMax2, neighborMax);

        // Blend (1) and (2): average of the cross bounds and the full 3x3 bounds
        neighborMin = neighborMin * 0.5 + neighborMin2 * 0.5;
        neighborMax = neighborMax * 0.5 + neighborMax2 * 0.5;

        // luminance range of current neighbor pixels
        float currentLumaMin = Luma(neighborMin.rgb);
        float currentLumaMax = Luma(neighborMax.rgb);
        float currentLumaContrast = currentLumaMax - currentLumaMin;


        // Apply LPF to current color
        float4 currentLPFColor =
            currentNeighborColor0 * u_WeightLow1.x +
            currentNeighborColor1 * u_WeightLow1.y +
            currentNeighborColor2 * u_WeightLow1.z +
            currentNeighborColor3 * u_WeightLow1.w +
            currentNeighborColor4 * u_WeightLowCenter +
            currentNeighborColor5 * u_WeightLow2.x +
            currentNeighborColor6 * u_WeightLow2.y +
            currentNeighborColor7 * u_WeightLow2.z +
            currentNeighborColor8 * u_WeightLow2.w;


        //--------------------------------------------------------------------------
        // Clip history color towards current LPF color
        //
        // The history color is moved along line(historyColor-currentLPFColor) to
        // the point where that line enters AABB(neighborMin-neighborMax).
        //--------------------------------------------------------------------------
        historyColor.rgb = IntersectAABBWithLine(historyColor.rgb, 
                                                 currentLPFColor.rgb, 
                                                 neighborMin.rgb, 
                                                 neighborMax.rgb);


        //--------------------------------------------------------------------------
        // Apply reconstruction filter current color
        // Use Blackman-Harris 3.3
        //--------------------------------------------------------------------------
        float4 currentColor =
            currentNeighborColor0 * u_Weight1.x +
            currentNeighborColor1 * u_Weight1.y +
            currentNeighborColor2 * u_Weight1.z +
            currentNeighborColor3 * u_Weight1.w +
            currentNeighborColor4 * u_WeightCenter +
            currentNeighborColor5 * u_Weight2.x +
            currentNeighborColor6 * u_Weight2.y +
            currentNeighborColor7 * u_Weight2.z +
            currentNeighborColor8 * u_Weight2.w;

        // Sharpening of current filtered color: pull it back towards the unfiltered
        // center pixel when the velocity is high or the neighborhood contrast is low.
        // historyBlur is already in [0, 1], so it needs no further saturate below.
        const float historyBlur = saturate((abs(velocityPixels.x) + abs(velocityPixels.y)) * u_HistoryBlurAmp);
        const float sharpness = saturate(historyBlur * 0.5 + rcp(1.0 + currentLumaContrast * u_LumaContrastFactor));
        currentColor.rgb = lerp(currentColor.rgb, currentNeighborColor4.rgb, sharpness);

        //--------------------------------------------------------------------------
        // Compute blend weight from luminance and velocity amounts
        //--------------------------------------------------------------------------
        const float historyAmount = (1.0f + historyBlur) * u_BlendWeightMin;
        float historyLuma = Luma(historyColor.rgb);
        // From here on historyLuma is the distance from the history luma to the nearest
        // bound of the current luma range.
        historyLuma = min(abs(currentLumaMin - historyLuma), abs(currentLumaMax - historyLuma));
        const float historyFactor = historyLuma * historyAmount * (1.0 + historyBlur * historyAmount * 8.0);
        // The max(0.001, ...) keeps the denominator away from zero.
        float blendWeight = saturate(historyFactor * rcp(max(0.001f, historyLuma + currentLumaContrast)));
    
        //--------------------------------------------------------------------------
        // Clamp blend weight by velocity amounts and blend weight of previous frame
        //--------------------------------------------------------------------------
        const float velocityLength = sqrt(dot(velocityPixels, velocityPixels));
        const float prevBlendWeight = historyColor.a;
        const float velocityDiff = abs(prevBlendWeight - velocityLength) / max(1.0, max(prevBlendWeight, velocityLength));
        blendWeight = clamp(blendWeight, velocityDiff * u_BlendWeightMin, u_BlendWeightMax);

        //--------------------------------------------------------------------------
        // Blend filtered current color and history color
        // (blendWeight is the share of the current color)
        //--------------------------------------------------------------------------
        float4 outputColor = float4(0.0f, 0.0f, 0.0f, 0.0f);
        outputColor.rgb = lerp(historyColor.rgb, currentColor.rgb, blendWeight);

        // Save alpha channel for velocityUV weighting
        // NOTE(review): rcp(u_VelocityDecay) is not guarded; this assumes the caller
        // never sets u_VelocityDecay to 0 - confirm on the CPU side.
        outputColor.a = max(historyColor.a * u_VelocityDecay, velocityLength * rcp(u_VelocityDecay));

        // Revert tonemapping: c * 1 / (1 - c) = c / (1 - c), per channel
        outputColor.rgb *= SimpleTonemapInv(outputColor.rgb);

        // Avoid NaN : transform to 0
        // -min(-x, 0) is max(x, 0) for ordinary values, so negatives are clamped to 0 too.
        // NOTE(review): this relies on min() returning the non-NaN operand - confirm for
        // every target shader backend.
        outputColor.rgb = -min(-outputColor.rgb, 0.0);

        return outputColor;
    }

    //------------------------------------------------------------------------------
    //
    // Utility functions
    //
    //------------------------------------------------------------------------------

    // Keeps a value away from zero so it can be used as a divisor.
    // Values strictly inside (-0.001, 0.001) are replaced by +0.001; note that the
    // sign of small negative values is not preserved. Other values pass through.
    float nonzero(float a)
    {
        const float CLAMP_MIN= 0.001f;
        return  a > -CLAMP_MIN && a < CLAMP_MIN ? CLAMP_MIN : a;
    }

    // Component-wise nonzero() for a float3.
    float3 nonzero3(float3 a)
    {
        return  float3( nonzero(a.x), nonzero(a.y), nonzero(a.z) );
    }

    // Returns the point on the segment startLine -> endLine where it enters the box.
    //
    // The box is built from minAABB/maxAABB and then expanded so that it always
    // contains endLine. The entry parameter is clamped to [0, 1], so the result is
    // startLine when the start is already at or past the entry point, and endLine at most.
    float3 IntersectAABBWithLine(float3 startLine, float3 endLine, float3 minAABB, float3 maxAABB)
    {
        // Expand the box to include endLine (also tolerates swapped min/max inputs).
        float3 minPos = min(endLine, min(minAABB, maxAABB));
        float3 maxPos = max(endLine, max(minAABB, maxAABB));
        float3 centerAABB = (maxPos + minPos) * 0.5;
        // nonzero3 prevents a division by zero in rcp() for near-degenerate axes.
        float3 dir = nonzero3(endLine - startLine);
        float3 invDir = rcp(dir);
        // Work in box-centered space: org is the start point, scaleAABB the half extents.
        float3 org = startLine - centerAABB;
        float3 scaleAABB = maxPos - centerAABB;

        // Slab test: per axis, the line parameters at the two opposite faces.
        float3 pos0 = (scaleAABB - org) * invDir;
        float3 pos1 = ((-scaleAABB) - org) * invDir;
        // Entry parameter = largest of the per-axis near hits, clamped to the segment.
        float intersectPos = saturate(max(max(min(pos0.x, pos1.x), min(pos0.y, pos1.y)), min(pos0.z, pos1.z)));
        return lerp(startLine, endLine, intersectPos);
    }

    // Weighted sum of the rgb channels with weights (0.299, 0.587, 0.114).
    float Luma(float3 rgbColor)
    {
        return dot(rgbColor, float3(0.299, 0.587, 0.114));
    }


    // Returns the per-channel scale factor 1 / (c + 1). Callers multiply the color by
    // it to get the tonemapped color c / (1 + c).
    float3 SimpleTonemap(float3 linearColor) 
    {
        return rcp(linearColor + 1.0f);
    }


    // Returns the per-channel scale factor 1 / (1 - c). Callers multiply the tonemapped
    // color by it to get c / (1 - c), which undoes SimpleTonemap.
    float3 SimpleTonemapInv(float3 tonemappedColor) 
    {
        return rcp(1.0f - tonemappedColor);
    }
};
