// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.BRDF
{
    /// <summary>
    /// Utility shader providing the different variants of the functions
    /// (Fresnel, NDF and Visibility) involved in a Microfacet shading model.
    /// </summary>
    shader BRDFMicrofacet : Math
    {
        // References:
        // http://graphicrants.blogspot.jp/2013/08/specular-brdf-reference.html
        // TODO: Add reference to original papers here

        // -------------------------------------------------
        // Normal Distribution Functions
        // -------------------------------------------------
        // Expected parameters:
        //    alphaR = roughness^2 (Burley)
        //    nDotH  = saturate(dot(n, h))
        // -------------------------------------------------
        // TODO Add GGX Anisotropic
        
        /// <summary>
        /// Raise x to the power y after clamping the base to a minimum of 0.00001,
        /// so that the base handed to pow is always strictly positive.
        /// </summary>
        float ClampedPow(float x, float y)
        {
            float safeBase = max(x, 0.00001f);
            return pow(safeBase, y);
        }

        /// <summary>
        /// Evaluate the Blinn-Phong normal distribution function.
        /// </summary>
        float NormalDistributionBlinnPhong(float alphaR, float nDotH) 
        {
            // Squared roughness is floored at 0.1 so the exponent below cannot grow too large.
            // TODO: Find an acceptable limit with 16 bits floats targets
            float roughnessSq = max(alphaR * alphaR, 0.1);
            float exponent = 2 / roughnessSq - 2;
            float normalization = PI * roughnessSq;
            return ClampedPow(nDotH, exponent) / normalization;
        }

        /// <summary>
        /// Evaluate the Beckmann normal distribution function.
        /// </summary>
        float NormalDistributionBeckmann(float alphaR, float nDotH) 
        {
            // Squared roughness is floored at 0.1 to keep the exponent magnitude bounded.
            // TODO: Find an acceptable limit with 16 bits floats targets
            float roughnessSq = max(alphaR * alphaR, 0.1);
            // Squared cosine is floored at 0.0001 so the divisions below never hit zero.
            float cosSq = max(nDotH * nDotH, 0.0001);
            float cosPow4 = cosSq * cosSq;
            float exponent = (cosSq - 1) / (roughnessSq * cosSq);
            return exp(exponent) / (PI * roughnessSq * cosPow4);
        }

        /// <summary>
        /// Evaluate the GGX normal distribution function.
        /// </summary>
        float NormalDistributionGGX(float alphaR, float nDotH) 
        {
            float roughnessSq = alphaR * alphaR;
            // The denominator term is floored at 0.0001 to avoid a division by zero.
            float denomTerm = max(nDotH * nDotH * (roughnessSq - 1) + 1, 0.0001);
            float denominator = PI * denomTerm * denomTerm;
            return roughnessSq / denominator;
        }

        // -------------------------------------------------
        // Fresnel Functions
        // -------------------------------------------------
        // Expected parameters:
        //    f0 = fresnel specular color at angle 0
        //    vDotH  = saturate(dot(v, h))
        // -------------------------------------------------

        /// <summary>
        /// Fresnel term that applies no angular variation: the specular color f0 is returned as is.
        /// </summary>
        float3 FresnelNone(float3 f0)
        {
            float3 reflectance = f0;
            return reflectance;
        }

        /// <summary>
        /// Schlick approximation to Fresnel using f0 and a fixed f90 of 1.
        /// Forwards to the (f0, f90, lOrVDotH) overload.
        /// </summary>
        float3 FresnelSchlick(float3 f0, float lOrVDotH)
        {
            float3 grazingReflectance = float3(1.0f, 1.0f, 1.0f);
            return FresnelSchlick(f0, grazingReflectance, lOrVDotH);
        }

        /// <summary>
        /// Schlick approximation to Fresnel: blends from f0 towards f90 with a weight of (1 - lOrVDotH)^5.
        /// </summary>
        float3 FresnelSchlick(float3 f0, float3 f90, float lOrVDotH)
        {
            float oneMinusCos = 1 - lOrVDotH;
            float blendWeight = pow(oneMinusCos, 5);
            return f0 + (f90 - f0) * blendWeight;
        }

        // -------------------------------------------------
        // Geometric Shadowing Functions
        // -------------------------------------------------
        // We are using V (Visibility) instead of G (Geometric Shadowing function)
        // The formula for V is given by:
        //    V = G / (nDotL * nDotV)
        //
        // Expected parameters:
        //    alphaR = roughness^2 (Burley)
        //    nDotV  = max(dot(n, v), 1e-5f)
        //    nDotL  = saturate(dot(n, l))
        //    nDotH  = saturate(dot(n, h))
        // -------------------------------------------------
        
        /// <summary>
        /// Implicit geometric shadowing, expressed as a visibility term.
        /// </summary>
        float VisibilityImplicit(float nDotL, float nDotV)
        {
            // With G = nDotL * nDotV, the visibility V = G / (nDotL * nDotV) is constant.
            float visibility = 1.0f;
            return visibility;
        }

        /// <summary>
        /// Neumann geometric shadowing, expressed as a visibility term.
        /// </summary>
        float VisibilityNeumann(float nDotL, float nDotV)
        {
            // G = (nDotL * nDotV) / max(nDotL, nDotV), hence V = 1 / max(nDotL, nDotV)
            float largestCosine = max(nDotL, nDotV);
            return 1.0 / largestCosine;
        }

        /// <summary>
        /// Cook-Torrance geometric shadowing, expressed as a visibility term.
        /// </summary>
        float VisibilityCookTorrance(float nDotH, float vDotH, float nDotL, float nDotV)
        {
            // G = min(1, min(2 * nDotH * nDotV / vDotH, 2 * nDotH * nDotL / vDotH));
            float viewTerm = 2 * nDotH * nDotV / vDotH;
            float lightTerm = 2 * nDotH * nDotL / vDotH;
            float shadowing = min(1, min(viewTerm, lightTerm));

            // V = G / (nDotL * nDotV)
            return shadowing / (nDotL * nDotV);
        }

        /// <summary>
        /// Kelemen geometric shadowing, expressed as a visibility term.
        /// </summary>
        float VisibilityKelemen(float vDotH, float nDotL, float nDotV)
        {
            // G = nDotL * nDotV / (vDotH * vDotH), so the nDotL * nDotV factor cancels out in V.
            float vDotHSq = vDotH * vDotH;
            return 1.0f / vDotHSq;
        }

        /// <summary>
        /// Calculate the single-direction Beckmann shadowing term for the cosine nDotX
        /// (either nDotL or nDotV), using a rational approximation.
        /// </summary>
        float VisibilityBeckmann(float alphaR, float nDotX)
        {
            // c = nDotX / (alphaR * sqrt(1 - nDotX^2))
            float c = nDotX / (alphaR * sqrt(1 - nDotX * nDotX));

            // The rational approximation is only applied for c < 1.6; beyond that the term is 1.
            // The last denominator coefficient is 2.577 (it was previously mistyped as 2577,
            // which collapsed the result towards zero).
            return c < 1.6f ? (3.535f * c + 2.181f * c * c) / (1 + 2.276f * c + 2.577f * c * c) : 1.0f;
        }

        /// <summary>
        /// Smith-Beckmann geometric shadowing (to be paired with the Beckmann NDF), expressed as a visibility term.
        /// </summary>
        float VisibilitySmithBeckmann(float alphaR, float nDotL, float nDotV)
        {
            float lightTerm = VisibilityBeckmann(alphaR, nDotL);
            float viewTerm = VisibilityBeckmann(alphaR, nDotV);

            // V = G / (nDotL * nDotV), with G the product of both directional terms
            return (lightTerm * viewTerm) / (nDotL * nDotV);
        }

        /// <summary>
        /// Calculate the per-direction term sqrt(1 + alphaR^2 * (1 - nDotX^2) / nDotX^2)
        /// that VisibilitySmithGGXCorrelated evaluates for nDotL and nDotV.
        /// </summary>
        float VisibilityGGXCorrelated(float alphaR, float nDotX)
        {
            float roughnessSq = alphaR * alphaR;
            float cosSq = nDotX * nDotX;
            float sinSq = 1 - cosSq;
            return sqrt(1 + roughnessSq * sinSq / cosSq);
        }

        /// <summary>
        /// Smith-GGX Correlated geometric shadowing, expressed as a visibility term.
        /// </summary>
        /// <remarks>See Moving Frostbite to PBR. SmithGGX Correlated</remarks>
        float VisibilitySmithGGXCorrelated(float alphaR, float nDotL, float nDotV)
        {
            float lightTerm = VisibilityGGXCorrelated(alphaR, nDotL);
            float viewTerm = VisibilityGGXCorrelated(alphaR, nDotV);

            // TODO: Expand (nDotL * nDotV)
            return 2.0f / (lightTerm + viewTerm) / (nDotL * nDotV);
        }

        /// <summary>
        /// Calculate the single-direction Schlick-Beckmann term for the cosine nDotX,
        /// with the remapping k = alphaR * sqrt(2 / PI).
        /// </summary>
        float VisibilityhSchlickBeckmann(float alphaR, float nDotX)
        {
            float remap = alphaR * sqrt(2.0f / PI);
            float denominator = nDotX * (1 - remap) + remap;
            return nDotX / denominator;
        }

        /// <summary>
        /// Smith-Schlick-Beckmann geometric shadowing, expressed as a visibility term.
        /// </summary>
        float VisibilitySmithSchlickBeckmann(float alphaR, float nDotL, float nDotV)
        {
            float lightTerm = VisibilityhSchlickBeckmann(alphaR, nDotL);
            float viewTerm = VisibilityhSchlickBeckmann(alphaR, nDotV);

            // V = G / (nDotL * nDotV), with G the product of both directional terms
            return lightTerm * viewTerm / (nDotL * nDotV);
        }

        /// <summary>
        /// Calculate the single-direction Schlick-GGX term for the cosine nDotX,
        /// with the remapping k = alphaR / 2.
        /// </summary>
        float VisibilityhSchlickGGX(float alphaR, float nDotX)
        {
            float remap = alphaR * 0.5f;
            float denominator = nDotX * (1.0f - remap) + remap;
            return nDotX / denominator;
        }

        /// <summary>
        /// Smith-Schlick-GGX geometric shadowing, expressed as a visibility term.
        /// </summary>
        float VisibilitySmithSchlickGGX(float alphaR, float nDotL, float nDotV)
        {
            float lightTerm = VisibilityhSchlickGGX(alphaR, nDotL);
            float viewTerm = VisibilityhSchlickGGX(alphaR, nDotV);

            // V = G / (nDotL * nDotV), with G the product of both directional terms
            return lightTerm * viewTerm / (nDotL * nDotV);
        }

        /// <summary>
        /// Polynomial fit producing a scale and a bias applied to the specular color,
        /// evaluated from (1 - alphaR) and nDotV. Returns specularColor * scale + bias.
        /// </summary>
        /// <remarks>
        /// NOTE(review): judging by the name, this approximates the environment-lighting DFG term
        /// for GGX / Schlick / Smith-Schlick-GGX; the origin of the coefficients is not given here.
        /// </remarks>
        float3 EnvironmentLightingDFG_GGX_Schlick_SmithSchlickGGX( float3 specularColor, float alphaR, float nDotV )
        {
            float oneMinusAlpha = 1 - alphaR;
            float cosV = nDotV;

            // Coefficients of the bias fit: one polynomial in (1 - alphaR), one in nDotV
            float biasA1 = -0.1688;
            float biasA2 = 1.895;
            float biasV0 = 0.9903;
            float biasV1 = -4.853;
            float biasV2 = 8.404;
            float biasV3 = -5.069;
            float biasFromAlpha = biasA1 * oneMinusAlpha + biasA2 * oneMinusAlpha * oneMinusAlpha;
            float biasFromView = biasV0 + biasV1 * cosV + biasV2 * cosV * cosV + biasV3 * cosV * cosV * cosV;
            float bias = saturate( min( biasFromAlpha, biasFromView ) );

            // Coefficients of the delta fit, a polynomial in both variables
            float deltaC0 = 0.6045;
            float deltaC1 = 1.699;
            float deltaC2 = -0.5228;
            float deltaC3 = -3.603;
            float deltaC4 = 1.404;
            float deltaC5 = 0.1939;
            float deltaC6 = 2.661;
            float delta = saturate( deltaC0 + deltaC1 * oneMinusAlpha + deltaC2 * cosV + deltaC3 * oneMinusAlpha * oneMinusAlpha + deltaC4 * oneMinusAlpha * cosV + deltaC5 * cosV * cosV + deltaC6 * oneMinusAlpha * oneMinusAlpha * oneMinusAlpha );

            // The scale is computed from the bias before the fade below is applied
            float scale = delta - bias;

            // Fade the bias out as the green channel of the specular color drops below 1/50
            bias *= saturate( 50.0 * specularColor.y );
            return specularColor * scale + bias;
        }
    };
}
