// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Images
{
    shader AmbientOcclusionRawAOShader <int SamplesCount, bool IsOrthographic>  : ImageEffectShader, Camera
    {
        // Projection parameters.
        // NOTE(review): the trailing comment below describes z-plane values, but within this shader
        // .xy is used as a scale and .zw as an offset (reconstructCSPosition, perspective path), and
        // .z is used as the divisor of the orthographic disk radius (Shading). Confirm the actual
        // layout against the code that fills this parameter.
        float4   ProjInfo;    // .x = zN * zF, .y = zN - zF, .z = zF
        // Render-target information; .xy is used to convert between texture coordinates and pixel coordinates.
        float4   ScreenInfo;  // .x = Width, .y = Height, .z = Aspect

        // Numerator of the screen-space sampling-disk radius computed in Shading.
        float    ParamProjScale = 1;
        // Multiplier applied to the normalised occlusion sum before it is subtracted from 1 in Shading.
        float    ParamIntensity = 1;
        // Subtracted from dot(v, normal) in sampleAO before the occlusion term is clamped at zero.
        float    ParamBias = 0.01f;
        // Multiplied with ParamRadiusSquared in Shading to build the normalisation factor of the occlusion sum.
        float    ParamRadius = 1;
        // Upper bound of the squared sample distance in the falloff term of sampleAO.
        // Presumably expected to equal ParamRadius * ParamRadius — confirm against the code that sets it.
        float    ParamRadiusSquared = 1;

        // Converts a depth-buffer value into the linear depth used by the rest of this shader.
        //   Orthographic: ZProjection.x + depth * ZProjection.y   (original note: near + z * (far - near))
        //   Perspective:  ZProjection.y / (depth - ZProjection.x)
        // ZProjection is not declared in this file; presumably it comes from the Camera mixin — confirm.
        stage float reconstructCSZ(float depth)
        {
            // IsOrthographic is a generic parameter of the shader, so only one side is ever relevant for a given permutation.
            return IsOrthographic
                ? (ZProjection.x + depth * ZProjection.y)
                : (ZProjection.y / (depth - ZProjection.x));
        }

        // Rebuilds a 3D position from a pixel coordinate S and a linear depth z.
        //   S - pixel coordinate (callers in this shader pass the pixel centre, i.e. integer pixel + 0.5).
        //   z - linear depth as returned by reconstructCSZ; it is written unchanged to the result's z component.
        stage float3 reconstructCSPosition(float2 S, float z)
        {
            float3 reconstructed;

            if (IsOrthographic)
            {
                // Map the pixel coordinate to the [-1, 1] range, then scale by ProjInfo.xy.
                float2 normalized = S.xy / ScreenInfo.xy;
                normalized = normalized * 2 - 1;
                reconstructed = float3(normalized * ProjInfo.xy, z);
            }
            else
            {
                // Affine transform of the pixel coordinate (scale ProjInfo.xy, offset ProjInfo.zw), scaled by depth.
                float2 scaled = (S.xy * ProjInfo.xy + ProjInfo.zw) * z;
                reconstructed = float3(scaled, z);
            }

            return reconstructed;
        }

        // Derives a unit normal from the screen-space derivatives of a reconstructed position.
        // The cross product is taken as (d/dy) x (d/dx); callers adjust the sign of the result themselves.
        stage float3 reconstructCSNormal(float3 position)
        {
            float3 alongY = ddy(position);
            float3 alongX = ddx(position);
            float3 faceNormal = cross(alongY, alongX);
            return normalize(faceNormal);
        }

        // Computes the unnormalised occlusion contribution of one sample taken on a spiral around the shaded pixel.
        //   screenPosition             - integer pixel coordinate of the shaded pixel (currently not read by this function).
        //   viewPosition, viewNormal   - reconstructed position and normal of the shaded pixel.
        //   diskRadius                 - radius of the sampling disk; it is applied directly to streams.TexCoord.
        //   i                          - sample index; Shading iterates it over [0, SamplesCount).
        //   randomPatternRotationAngle - angular offset added to every sample of this pixel.
        // Returns f^3 * max((dot(v, n) - ParamBias) / (epsilon + |v|^2), 0), with f = max(ParamRadiusSquared - |v|^2, 0)
        // and v the vector from the shaded pixel to the sampled position.
        stage float sampleAO(int2 screenPosition, float3 viewPosition, float3 viewNormal, float diskRadius, int i, float randomPatternRotationAngle)
        {
            //*****************************
            //  Sample Offset
            // alpha grows linearly with the sample index; 43.9822971503 is approximately 7 * 2 * PI.
            float alpha = (i + 0.5) * 0.675f / SamplesCount;
            float angle = 43.9822971503f * alpha + randomPatternRotationAngle;

            float2 offset = float2(cos(angle), sin(angle));
            float ssRadius = alpha * diskRadius;

            //*****************************
            //  Depth
            float2 samplePos = streams.TexCoord + offset * ssRadius;

            // saturate() can yield exactly 1.0, which would map to texel index Width/Height, one past the
            // last valid texel. Clamp to the last texel so that samples leaving the screen on the right or
            // bottom edge reuse the border texel, just like samples leaving on the left or top edge do.
            int2 lastTexel = ((int2)ScreenInfo.xy) - int2(1, 1);
            int2 samplePosInt = min((int2)(saturate(samplePos) * ScreenInfo.xy), lastTexel);

            // Load returns a four-component value; only the first channel is used, as in Shading.
            float depth = Texture0.Load(int3(samplePosInt, 0)).x;
            float linearDepth = reconstructCSZ(depth);

            //*****************************
            // View Position (reconstructed at the texel centre, X negated as for the shaded pixel in Shading)
            float3 position = reconstructCSPosition(samplePosInt + float2(0.5, 0.5), linearDepth);
            position.x *= -1;

            //*****************************
            // Vector from the shaded pixel to the sample, with its Z component negated
            float3 v = position - viewPosition;
            v.z *= -1;

            //*****************************
            // Ambient Occlusion
            float distSq = dot(v, v);
            float vn = dot(v, viewNormal);

            // Keeps the division finite when the sample coincides with the shaded pixel.
            const float epsilon = 0.01;

            // Falloff: reaches zero once the squared distance reaches ParamRadiusSquared.
            float f = max(ParamRadiusSquared - distSq, 0.0);

            return f * f * f * max((vn - ParamBias) / (epsilon + distSq), 0.0);
        }

        // Estimates the ambient-occlusion visibility factor of the current pixel from the depth
        // texture bound as Texture0 and returns it replicated into all four channels.
        stage override float4 Shading()
        {
            //*****************************
            // Linear depth of this pixel, reconstructed from the depth texture
            float rawDepth = Texture0.SampleLevel(Sampler, streams.TexCoord, 0).x;
            float viewZ = reconstructCSZ(rawDepth);

            //*****************************
            // Position of this pixel, reconstructed at the pixel centre (+0.5), with X negated
            int2 pixel = streams.TexCoord.xy * ScreenInfo.xy;
            float3 csPosition = reconstructCSPosition(pixel + float2(0.5, 0.5), viewZ);
            csPosition.x *= -1;

            //*****************************
            // Normal of this pixel from screen-space derivatives, with X and Y negated
            float3 csNormal = reconstructCSNormal(csPosition.xyz);
            csNormal.xy *= -1;

            //*****************************
            // Per-pixel rotation of the sample pattern (hash function used in the HPG12 AlchemyAO paper,
            // here with an additional depth term). '^' binds looser than '+', so the expression XORs the
            // two sums below; naming them only makes that existing precedence explicit.
            int viewZInt = (int)viewZ;
            int hashLeft = 15 * viewZInt + 3 * pixel.x;
            int hashRight = 2 * pixel.y + pixel.x * pixel.y;
            float rotationAngle = ((hashLeft ^ hashRight) & 0x0000FFFF) * 10;

            //*****************************
            // Sampling-disk radius: a fixed value in orthographic mode, inversely proportional to depth otherwise
            float diskRadius = IsOrthographic
                ? (ParamProjScale / ProjInfo.z)
                : (ParamProjScale / viewZ);

            //*****************************
            // Accumulate the raw occlusion of every sample
            float occlusionSum = 0.0;
            for (int sampleIndex = 0; sampleIndex < SamplesCount; sampleIndex++)
            {
                occlusionSum += sampleAO(pixel, csPosition, csNormal, diskRadius, sampleIndex, rotationAngle);
            }

            // Normalise by (ParamRadiusSquared * ParamRadius)^2, then turn the averaged occlusion into a visibility factor.
            float radiusProduct = ParamRadiusSquared * ParamRadius;
            occlusionSum /= radiusProduct * radiusProduct;
            float visibility = max(0.0, 1.0 - occlusionSum * 5 * ParamIntensity / SamplesCount);

            // Fade the effect out for small depths: no occlusion at viewZ <= 0.25, full effect from viewZ >= 0.75.
            float depthFade = saturate(viewZ * 2.0 - 0.5);
            visibility = lerp(1, visibility, depthFade);

            //*****************************
            // Bilateral box-filter over the 2x2 pixel quad, skipped across depth edges
            // (the difference that this makes is subtle)
            if (abs(ddx(viewZ)) < 0.02)
            {
                visibility -= ddx(visibility) * ((pixel.x & 1) - 0.5);
            }
            if (abs(ddy(viewZ)) < 0.02)
            {
                visibility -= ddy(visibility) * ((pixel.y & 1) - 0.5);
            }

            //************************
            // visibility now holds the light intensity factor (0 to 1) that can be applied to the ambient light of the pixel

            /************************
            // Debug: visualize a value as color bands (fed with the occlusion sum).
            // Turn the opening of this block comment into a line comment to enable it.
            //************************
            viewZ = occlusionSum;
            float4 color = Texture0.Sample(Sampler, streams.TexCoord);
            color.r = ((float)(viewZ % 4)) / 4.0;
            color.g = ((float)((viewZ / 4) % 4)) / 4.0;
            color.b = ((float)((viewZ / 16) % 4)) / 4.0;
            return color; //*/

            return float4(visibility, visibility, visibility, visibility);
        }
    };
}
