﻿// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Materials
{
    // Directional light attenuation function for the hair material.
    //
    // The attenuation is pow(saturate(N.L), HardnessReciprocal), where the normal is first
    // blended towards the negated light direction by BoundaryShift and then re-normalized.
    //
    // Generic parameters:
    //   NormalMode          - selects which normal feeds N.L:
    //                           0    = mesh normal only,
    //                           1    = normal map normal only,
    //                           else = both, keeping the smaller of the two N.L values.
    //   THardnessReciprocal - link key of the HardnessReciprocal material parameter.
    //   TBoundaryShift      - link key of the BoundaryShift material parameter.
    class MaterialHairLightAttenuationFunctionDirectional<int NormalMode, LinkType THardnessReciprocal, LinkType TBoundaryShift> : IMaterialHairLightAttenuationFunction,
        NormalStream, LightStream  // These are required for accessing the normals and light direction.
    {
        cbuffer PerMaterial
        {
            [Link("THardnessReciprocal")]
            stage float HardnessReciprocal; // == 1.0 / hardness; used as the exponent in Compute().
            [Link("TBoundaryShift")]
            stage float BoundaryShift;      // Range: [0.0 ... 0.5]; lerp factor from the normal towards -lightDirectionWS.
        }

        // Blends 'normal' towards the negated light direction by BoundaryShift, re-normalizes
        // the result and returns its dot product with the light direction.
        //
        // The blended vector can have zero length: for example with BoundaryShift == 0.5 and
        // 'normal' equal to streams.lightDirectionWS the lerp yields exactly (0, 0, 0).
        // A plain normalize() would then divide by zero, so the squared length is clamped
        // before taking the reciprocal square root. In that degenerate case the function
        // returns 0, which is the value the expression converges to as the normal approaches
        // the light direction.
        float ShiftedNdotL(float3 normal)
        {
            const float minSquaredLength = 1e-12f;  // Far below the squared length of any meaningful blend.

            const float3 shifted = lerp(normal, -streams.lightDirectionWS, BoundaryShift);
            const float inverseLength = rsqrt(max(dot(shifted, shifted), minSquaredLength));

            return dot(shifted * inverseLength, streams.lightDirectionWS);
        }

        // Returns the (unclamped) N.L term for the normal source selected by NormalMode.
        float CalculateNdotL(void)
        {
            if(NormalMode == 0) // Mesh normals:
            {
                return ShiftedNdotL(streams.meshNormalWS);
            }
            else if(NormalMode == 1) // Normal map normals:
            {
                // If no normal map is present, this will be equal to the mesh normal.
                return ShiftedNdotL(streams.normalWS);
            }
            else    // Mesh & normal map normals:
            {
                const float normalMapNdotL = ShiftedNdotL(streams.normalWS);
                const float meshNdotL = ShiftedNdotL(streams.meshNormalWS);

                // Alternate approach:
                //return max(normalMapNdotL, meshNdotL);

                // More conservative approach:
                return min(normalMapNdotL, meshNdotL);
            }
        }

        // Returns the light attenuation factor: N.L clamped to [0, 1] and raised to the
        // power of HardnessReciprocal.
        float Compute(void)
        {
            const float saturatedNdotL = saturate(CalculateNdotL());
            return pow(saturatedNdotL, HardnessReciprocal);
        }
    };
}
