// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.Materials
{
    /// <summary>
    /// Computes the transmittance and reflectance of the material from its refractive index,
    /// its base color and the view angle, and stores them in the
    /// <c>matTransmittance</c> and <c>matReflectance</c> streams.
    /// </summary>
    shader MaterialTransmittanceReflectanceStream : MaterialPixelStream
    {
        cbuffer PerMaterial
        {
            // Refractive index of the material. Clamped to at least 1.0001 where it is used.
            stage float RefractiveIndex;
        }

        // Per-channel fraction of light transmitted through the material.
        stage stream float3  matTransmittance;
        // Per-channel fraction of light reflected by the material.
        stage stream float3  matReflectance;

        /// <summary>
        /// Resets the streams to their defaults: no transmittance, full reflectance.
        /// </summary>
        override void ResetStream()
        {
            base.ResetStream();

            streams.matTransmittance = 0.0f;
            streams.matReflectance = 1.0f;
        }

        /// <summary>
        /// Evaluates the Fresnel terms at the entering interface, applies absorption derived from
        /// the base color, and writes the resulting transmittance and reflectance to the streams.
        /// </summary>
        override void PrepareMaterialForLightingAndShading()
        {
            base.PrepareMaterialForLightingAndShading();

            // Cosine of the angle between view vector and surface normal
            const float cosTheta = streams.NdotV;
            const float sinTheta2 = 1 - cosTheta * cosTheta;    // Square of sinTheta

            // Keep eta strictly above 1 so that cosRefractedTheta below stays strictly positive
            // (it is used as a divisor in the absorption exponent).
            float eta = max(RefractiveIndex, 1.0001);

            // Snell's law: square of sinRefractedTheta. sinRefractedTheta itself is never needed.
            const float sinRefractedTheta2 = sinTheta2 / (eta * eta);
            const float cosRefractedTheta = sqrt(1 - sinRefractedTheta2);

            const float q0 = (eta * cosRefractedTheta - cosTheta);
            const float q1 = (eta * cosRefractedTheta + cosTheta);
            const float q2 = (eta * cosTheta - cosRefractedTheta);
            const float q3 = (eta * cosTheta + cosRefractedTheta);

            // Amplitude reflection coefficients for the two polarizations
            const float r0 = q0 / q1;
            const float r1 = q2 / q3;

            // Fresnel reflectance at the entering interface: average of the two polarization terms.
            // Each squared term lies in [0, 1], so the sum lies in [0, 2]; the 0.5 factor must be
            // applied BEFORE saturating, otherwise the reflectance can never exceed 0.5.
            // TODO: Test if this can be optimized with dot(float2(r0, r1), float2(r0, r1)) on target platforms
            const float R0 = saturate(0.5 * (r0 * r0 + r1 * r1));
            // Fresnel transmittance at the entering interface
            const float T0 = 1 - R0;

            // intermediate float3 values
            const float3 R = float3(R0, R0, R0);
            const float3 T = float3(T0, T0, T0);
            const float3 C = float3(cosRefractedTheta, cosRefractedTheta, cosRefractedTheta);

            // Coefficient to account for absorption. The exponent 1 / cos(refracted angle) grows as
            // the refracted ray becomes more oblique; the base color is floored at 0.001.
            const float3 K = pow(max(streams.matColorBase.rgb, 0.001), 1 / C);

            const float3 RK = R * K;     // intermediate value

            // Closed form of the geometric series T^2 * K * sum((R*K)^(2n)).
            // The denominator is floored to avoid a division by zero (and NaN) when R*K reaches 1.
            const float3 denominator = max(1 - RK * RK, 0.00001);

            float3 transmittance     = saturate(T * T * K / denominator);
            streams.matReflectance   = saturate(RK * transmittance + R);
            streams.matTransmittance = transmittance;
        }
    };
}
