// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
namespace Stride.Rendering.Lights
{
    /// <summary>
    /// Base shader for computing the texture projection of a light. Defines functions for computing the color of the projected texture at a given world space position.
    /// </summary>
    /// <remarks>
    /// PerLightGroup: Parameter used to uniquely identify this group of lights.
    /// TCascadeCountBase: The number of cascades of the current light.
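    /// TLightCountBase: The number of lights inside of this light group.
    /// TFlipMode: The flip mode applied to the projected texture coordinates (compared against the "FlipMode*" constants of this shader).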
    /// </remarks>
    internal shader TextureProjectionReceiverBase<MemberName PerLightGroup, int TCascadeCountBase, int TLightCountBase, int TFlipMode> :   // TODO: Rename to "TextureProjectionReceiver".
        TextureProjectionCommon<PerLightGroup>, // Required for accessing the texture that is projected.
        TextureProjectionFilterDefault<PerLightGroup>, // Defines "FilterProjectedTexture()".
        Texturing   // Required for the texture sampling.
    {
        // Enum values for the texture flip mode (compared against the "TFlipMode" generic parameter in "ModifyTextureCoordinate()"):
        // These values have to match the ones defined in "LightSpot.cs".
        const int FlipModeNone = 0;     // Texture coordinates are left untouched.
        const int FlipModeX = 1;        // Only the x coordinate is mirrored (x = 1 - x).
        const int FlipModeY = 2;        // Only the y coordinate is mirrored (y = 1 - y).
        const int FlipModeXY = 3;       // Both coordinates are mirrored.
        /////////////////////////////////////////////////////////////////

        // Per-light-group constants.
        // The matrix arrays hold one entry per (light, cascade) pair and are indexed with "cascadeIndex + lightIndex * TCascadeCountBase".
        // The float arrays hold one entry per light and are indexed with "lightIndex" only.
        cbuffer PerLightGroup
        {
            // Transforms a world space position into the clip space of the projector. Used to derive the projected texture coordinates.
            float4x4 WorldToProjectiveTextureUV[TCascadeCountBase * TLightCountBase];
            // Rows 0 and 1 are read as the plane's X and Y axes, row 2 as the plane normal and row 3 as the plane origin.
            float4x4 ProjectorPlaneMatrices[TCascadeCountBase * TLightCountBase];   // Contains the world matrix of the projector plane. Required for ray-plane intersection testing.
            // Mip map level that is passed to "FilterProjectedTexture()" when sampling the projected texture.
            float ProjectionTextureMipMapLevels[TLightCountBase];   // TODO: Not sure how to handle this in combination with cascades.
            // Width of the faded border inside the projection rectangle, passed to the rectangular mask.
            float TransitionAreas[TLightCountBase];
        };
        
        /*
        // Returns "1.0" if "point" is inside the box defined by "bottomLeft" and "topRight". Returns "0.0" otherwise.
        // This function was taken from here: http://stackoverflow.com/questions/12751080/glsl-point-inside-box-test
        float InsideOfRectangle(float2 p, float2 bottomLeft, float2 topRight)
        {
            float2 s = step(bottomLeft, p) - step(topRight, p);
            return(s.x * s.y);
        }
        */

        // Computes a soft "point inside rectangle" mask for "p".
        // The result is the product of a per-axis factor that ramps up from "bottomLeft" to "bottomLeft + transitionArea"
        // and ramps down again from "topRight - transitionArea" to "topRight".
        // The fade therefore happens inside the rectangle, not outside of it.
        // NOTE: "transitionArea" must be smaller than the distance between "bottomLeft" and "topRight" (in each dimension).
        // NOTE(review): "smoothstep()" with identical edges is undefined in GLSL; presumably callers keep "transitionArea" above zero - confirm.
        // More information: http://stackoverflow.com/questions/12751080/glsl-point-inside-box-test
        float insideOfRectangleSmooth(float2 p, float2 bottomLeft, float2 topRight, float transitionArea)
        {
            // Ramp from 0 to 1 while entering the rectangle at its lower edges:
            float2 lowerEdgeRamp = smoothstep(bottomLeft, bottomLeft + transitionArea, p);

            // Ramp from 0 to 1 while approaching the upper edges:
            float2 upperEdgeRamp = smoothstep(topRight - transitionArea, topRight, p);

            // The difference of both ramps is the per-axis coverage:
            float2 coverage = lowerEdgeRamp - upperEdgeRamp;

            return coverage.x * coverage.y;
        }

        // Computes the mask that limits the projected texture to the projector's frustum.
        // projectedTextureCoordinate: Texture coordinate of the projection. The unit square [0.0 ... 1.0] is treated as "inside".
        // clipSpaceZ: Any value whose sign tells front (>= 0) from back (< 0) of the projector.
        // transitionArea: Size of the faded border inside of the unit square.
        // Returns a factor that the projected color is multiplied with. It is zero behind the projector.
        float CalculateRectangularMask(float2 projectedTextureCoordinate, float clipSpaceZ, float transitionArea)
        {
            // Fade out the projection towards the edges of the frustum.
            // TODO: Move the mask to the light attenuation code? It would be cleaner but also more difficult, because we don't have the texture coordinates there.
            float frustumEdgeMask = insideOfRectangleSmooth(projectedTextureCoordinate, 0.0f, 1.0f, transitionArea);

            // Suppress the back projection, because we don't want any texture behind the light.
            // This is done without a branch: "step()" yields 1.0f if "clipSpaceZ" >= 0.0f and 0.0f otherwise.
            // TODO: PERFORMANCE: Profile performance difference between branching and masking the back projection.
            // TODO: Maybe we should move this to the light attenuation code.
            float frontSideMask = step(0.0f, clipSpaceZ);

            return frustumEdgeMask * frontSideMask;
        }
        
        // Mirrors "textureCoordinate" in place according to the "TFlipMode" generic parameter.
        // "FlipModeX" mirrors x, "FlipModeY" mirrors y, "FlipModeXY" mirrors both. Any other value leaves the coordinate untouched.
        void ModifyTextureCoordinate(inout float2 textureCoordinate)
        {
            bool mirrorX = (TFlipMode == FlipModeX) || (TFlipMode == FlipModeXY);
            bool mirrorY = (TFlipMode == FlipModeY) || (TFlipMode == FlipModeXY);

            if(mirrorX)
            {
                textureCoordinate.x = 1.0f - textureCoordinate.x;
            }

            if(mirrorY)
            {
                textureCoordinate.y = 1.0f - textureCoordinate.y;
            }
        }
        
        // Intersects a ray with an infinite plane.
        // Implemented according to "http://geomalgorithms.com/a06-_intersect-2.html".
        //
        // rayOrigin, rayDirection: The ray to test. The direction does not have to be normalized for the test itself.
        // planeNormal, planeOrigin: The plane to test against.
        // pointOfIntersection: Receives the point where the ray hits the plane. Set to zero if there is no intersection.
        // Returns "true" if the ray hits the plane at or in front of "rayOrigin", "false" otherwise.
        bool IntersectPlane(float3 rayOrigin, float3 rayDirection, float3 planeNormal, float3 planeOrigin, out float3 pointOfIntersection)
        {
            const float epsilon = 0.001f;

            // An "out" parameter has to be written on every path. Give it a defined value up front,
            // so that the early-outs below never hand an uninitialized value back to the caller.
            pointOfIntersection = float3(0.0f, 0.0f, 0.0f);

            float planeDotRayDirection = dot(planeNormal, rayDirection);

            // The ray is (almost) parallel to the plane.
            // It either misses the plane or lies completely inside of it. Both cases are treated as "no intersection":
            if(abs(planeDotRayDirection) < epsilon)
            {
                return(false);
            }

            // Signed distance (in multiples of "rayDirection") from the ray origin to the plane along the ray:
            float planeDotOriginOffset = dot(planeNormal, planeOrigin - rayOrigin);
            float intersectionDistance = planeDotOriginOffset / planeDotRayDirection;

            // A ray (as opposed to a line) only hits the plane if the intersection is not behind its origin:
            if(intersectionDistance >= 0.0f)	// TODO: Remove branch?
            {
                pointOfIntersection = rayOrigin + rayDirection * intersectionDistance;
                return(true);
            }

            return(false);
        }
        
        // Computes the reflection of the texture projector as seen from the world position "positionWS".
        // A ray is cast from "positionWS" along "reflectionWS" and intersected with the projector plane of the given cascade/light.
        // If the plane is hit, the projected texture is sampled at the hit point and masked towards the plane's borders.
        // Returns black if the ray does not hit the projector plane.
        // NOTE(review): Unlike "ComputeTextureProjectionFromCascade()", the flip mode is not applied to the texture coordinate here - confirm that this is intended.
        float3 ComputeSpecularTextureProjectionFromCascade(float3 positionWS, float3 reflectionWS, int cascadeIndex, int lightIndex)
        {
            // Unpack the projector plane from the rows of its world matrix:
            float4x4 planeMatrix = ProjectorPlaneMatrices[cascadeIndex + lightIndex * TCascadeCountBase];
            float3 planeAxisX = float3(planeMatrix._m00, planeMatrix._m01, planeMatrix._m02);
            float3 planeAxisY = float3(planeMatrix._m10, planeMatrix._m11, planeMatrix._m12);
            float3 planeNormal = float3(planeMatrix._m20, planeMatrix._m21, planeMatrix._m22);   // Z axis
            float3 planeOrigin = float3(planeMatrix._m30, planeMatrix._m31, planeMatrix._m32);   // Position/Origin

            float3 hitPoint;
            if(!IntersectPlane(positionWS, reflectionWS, planeNormal, planeOrigin, hitPoint))   // TODO: PERFORMANCE: Branch or just mask the result?
            {
                return 0.0f;
            }

            // The lengths of the plane axes are used as the size of the plane.
            // NOTE(review): Assumes that both axes have a non-zero length - confirm.
            float2 planeSize = float2(length(planeAxisX), length(planeAxisY));

            // Project the intersection point to plane space:
            float3 originToHitPoint = hitPoint - planeOrigin;
            float2 planeSpacePosition = float2(dot(originToHitPoint, planeAxisX), dot(originToHitPoint, planeAxisY)) / planeSize;

            // Normalize the plane space coordinates:
            float2 planeTextureCoordinate = planeSpacePosition * 0.5f / planeSize + 0.5f;

            // Pass the dot product instead of a clip space z, because all we need is a negative value to mask out the back side.
            float facing = dot(planeNormal, reflectionWS);
            float planeReflectionMask = CalculateRectangularMask(planeTextureCoordinate, facing, TransitionAreas[lightIndex]);

            return FilterProjectedTexture(planeTextureCoordinate, ProjectionTextureMipMapLevels[lightIndex]) * planeReflectionMask;
        }

        // Computes the color of the projected texture at the world position "positionWS".
        // The position is transformed into the projector's clip space, converted to a texture coordinate,
        // flipped according to the flip mode and finally used to sample the (masked) projected texture.
        float3 ComputeTextureProjectionFromCascade(float3 positionWS, int cascadeIndex, int lightIndex)
        {
            float4x4 worldToProjector = WorldToProjectiveTextureUV[cascadeIndex + lightIndex * TCascadeCountBase];
            float4 clipSpacePosition = mul(float4(positionWS, 1.0), worldToProjector);

            // W-divide because it's a projection matrix:
            float2 uv = clipSpacePosition.xy / clipSpacePosition.w;

            // Offset the clip space coordinates from [-1.0 ... 1.0] to [0.0 ... 1.0] and turn the y axis around:
            uv = uv * 0.5 + 0.5;
            uv.y = 1.0f - uv.y;

            ModifyTextureCoordinate(uv);

            // TODO: PERFORMANCE: Using a texture border and setting the wrapping mode to "clamp" would be faster. Not sure if that would be a possibility though.
            float mask = CalculateRectangularMask(uv, clipSpacePosition.z, TransitionAreas[lightIndex]);

            return FilterProjectedTexture(uv, ProjectionTextureMipMapLevels[lightIndex]) * mask;
        }
    };
}
