// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    /// <summary>
    /// Screen-space outline effect: darkens the input color wherever the depth and a
    /// depth-derived normal differ strongly between diagonally opposite neighbour samples.
    /// </summary>
    internal shader OutlineEffect : ImageEffectShader
    {
        // Offset added to texture coordinates to reach neighbouring samples
        // (.x = horizontal step, .y = vertical step).
        // NOTE(review): presumably 1/width and 1/height of the render target - confirm on the C# side.
        stage float2 ScreenDiffs;

        // Far and near plane distances used to linearise the sampled depth.
        stage float zFar;
        stage float zNear;

        // Scale applied to the reconstructed normal before neighbours are compared.
        stage float NormalWeight;
        // Scale applied to the linear depth before neighbours are compared.
        stage float DepthWeight;
        // Linear depth below which the normal contribution is zeroed out.
        stage float NormalNearCutoff;

        // Depth texture sampled for both the depth and the normal reconstruction.
        stage Texture2D DepthTexture;

        // Approximates a unit normal at `texcoords` from depth differences.
        //   depth     - depth value already sampled at `texcoords`
        //   texcoords - texture coordinates of the sample being evaluated
        // Returns the normalised cross product of the two difference vectors, with z negated.
        float3 normal_from_depth(float depth, float2 texcoords)
        {
            // Steps towards the +V and +U neighbours.
            const float2 stepV = float2(0.0, ScreenDiffs.y);
            const float2 stepU = float2(ScreenDiffs.x, 0.0);

            float depthV = DepthTexture.SampleLevel(PointSampler, texcoords + stepV, 0.0).x;
            float depthU = DepthTexture.SampleLevel(PointSampler, texcoords + stepU, 0.0).x;

            // Difference vectors from the centre sample to each neighbour.
            float3 edgeV = float3(stepV, depthV - depth);
            float3 edgeU = float3(stepU, depthU - depth);

            float3 n = cross(edgeV, edgeU);
            return normalize(float3(n.xy, -n.z));
        }

        // Builds the value compared between neighbours at texture coordinate `tc`.
        // Returns: xyz = weighted normal (zero when closer than NormalNearCutoff),
        //          w   = weighted linear depth.
        float4 fetchNormalDepth(float2 tc)
        {
            float rawDepth = DepthTexture.SampleLevel(PointSampler, tc, 0.0).x;

            // Remap a [0, 1] value onto [-1, 1], then linearise; the result is zNear
            // for rawDepth == 0 and zFar for rawDepth == 1.
            float ndcDepth = 2.0 * rawDepth - 1.0;
            float viewDepth = 2.0 * zNear * zFar / (zFar + zNear - ndcDepth * (zFar - zNear));

            // step() yields 0 while viewDepth < NormalNearCutoff, so very close
            // geometry contributes no normal term.
            float3 weightedNormal = step(NormalNearCutoff, viewDepth) * normal_from_depth(rawDepth, tc) * NormalWeight;

            return float4(weightedNormal, DepthWeight * viewDepth);
        }

        // Samples the four diagonal neighbours of the current pixel and darkens the
        // input color in proportion to how much normal and depth change across them.
        // Output alpha is always 1.
        stage override float4 Shading()
        {
            float2 uv = streams.TexCoord;
            float4 sceneColor = Texture0.Sample(PointSampler, uv);

            // The two diagonal directions of the sampling cross.
            float2 mainDiag = float2(ScreenDiffs.x, ScreenDiffs.y);
            float2 antiDiag = float2(-ScreenDiffs.x, ScreenDiffs.y);

            float4 cornerNN = fetchNormalDepth(uv - mainDiag);
            float4 cornerPP = fetchNormalDepth(uv + mainDiag);
            float4 cornerNP = fetchNormalDepth(uv + antiDiag);
            float4 cornerPN = fetchNormalDepth(uv - antiDiag);

            // Absolute change along each diagonal, summed.
            float4 delta = abs(cornerNN - cornerPP) + abs(cornerNP - cornerPN);

            // Collapse the normal change to a scalar; it counts for 0.4 of its raw sum.
            float normalChange = dot(delta.xyz, float3(1.0, 1.0, 1.0));
            float edgeStrength = clamp(delta.w + normalChange * 0.4, 0.0, 1.0);

            float keep = 1.0 - edgeStrength;
            return float4(sceneColor.xyz * keep, 1.0);
        }
    };
}
