// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
/// <summary>
/// Performs PN tessellation (preserves edges).
/// </summary>
/// <remarks>
/// InputControlPointCount: Macro - number of input control points.
/// </remarks>
// Fallback value: only applied when InputControlPointCount has not already been defined
// before this file is processed.
#ifndef InputControlPointCount
# define InputControlPointCount 3
#endif

shader TessellationPN : TessellationFlat
{
    // Center control point (b111) of the cubic patch, written once per patch in TessellateHullConstant.
    // NOTE(review): despite the "View" in the name, it is derived from PositionWS values in this shader.
    patchstream float3 f3ViewB111;

    // The two intermediate control points of the edge going from this control point to the next one,
    // written in TessellateHull. nextPositionVS1 is the one built around this control point,
    // nextPositionVS2 the one built around the next control point.
    // NOTE(review): despite the "VS" suffix, both are derived from PositionWS values in this shader.
    stream float3 nextPositionVS1;
    stream float3 nextPositionVS2;

    /// <summary>
    /// Computes an edge control point: (2 * pA + pB - dot(pB - pA, nA) * nA) / 3.
    /// </summary>
    /// <remarks>
    /// When nA has unit length this is the point one third of the way from pA to pB,
    /// projected onto the plane that passes through pA with normal nA.
    /// </remarks>
    /// <param name="pA">Position the control point is attached to.</param>
    /// <param name="pB">Position at the other end of the edge.</param>
    /// <param name="nA">Normal at pA (presumably normalized - confirm with the caller).</param>
    float3 ComputeControlPoint(float3 pA, float3 pB, float3 nA)
    {
        return (2.0 * pA + pB - (dot((pB - pA), nA) * nA)) / 3.0f;
    }

    /// <summary>
    /// Tests one clip-space position against the x, y and z ranges of the clip volume.
    /// </summary>
    /// <param name="clipPos">Position as returned by ComputeShadingPosition.</param>
    /// <returns>Per axis: 1 when the coordinate is inside the range for that axis, 0 otherwise.</returns>
    float3 ComputeClipVolumeMask(float4 clipPos)
    {
        float3 inside;
        inside.x = (-clipPos.w <= clipPos.x && clipPos.x <= clipPos.w) ? 1.0f : 0.0f;
        inside.y = (-clipPos.w <= clipPos.y && clipPos.y <= clipPos.w) ? 1.0f : 0.0f;
        // Depth is tested against [0, w]; x and y are tested against [-w, w].
        inside.z = (0.0f <= clipPos.z && clipPos.z <= clipPos.w) ? 1.0f : 0.0f;
        return inside;
    }

    /// <summary>
    /// Clipping test for a group of 4 positions.
    /// </summary>
    /// <remarks>
    /// Each position is passed through ComputeShadingPosition (with w = 1) and tested per axis.
    /// The group counts as clipped when, for at least one axis, none of the positions is inside
    /// the range of that axis.
    /// </remarks>
    /// <returns>1 when the group is considered clipped, 0 otherwise.</returns>
    float ComputeClippingGroup4(float3 f3Position1, float3 f3Position2, float3 f3Position3, float3 f3Position4)
    {
        // Per axis: how many of the positions are inside the range of that axis.
        float3 planeTest = ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position1.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position2.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position3.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position4.xyz, 1.0f)));

        return !all(planeTest != 0.0f) ? 1.0 : 0.0;
    }

    /// <summary>
    /// Clipping test for a group of 8 positions.
    /// </summary>
    /// <remarks>
    /// Each position is passed through ComputeShadingPosition (with w = 1) and tested per axis.
    /// The group counts as clipped when, for at least one axis, none of the positions is inside
    /// the range of that axis.
    /// </remarks>
    /// <returns>1 when the group is considered clipped, 0 otherwise.</returns>
    float ComputeClippingGroup8(float3 f3Position1, float3 f3Position2, float3 f3Position3, float3 f3Position4, float3 f3Position5, float3 f3Position6, float3 f3Position7, float3 f3Position8)
    {
        // Per axis: how many of the positions are inside the range of that axis.
        float3 planeTest = ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position1.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position2.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position3.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position4.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position5.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position6.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position7.xyz, 1.0f)))
                         + ComputeClipVolumeMask(this.ComputeShadingPosition(float4(f3Position8.xyz, 1.0f)));

        return !all(planeTest != 0.0f) ? 1.0 : 0.0;
    }

    /// <summary>
    /// Clipping test for the whole patch.
    /// </summary>
    /// <remarks>
    /// Tests 8 positions: the three input corners, the center control point, and each of those
    /// four moved along a normal by a fixed displacement size.
    /// Reads constants.f3ViewB111, so it presumably has to run after TessellateHullConstant -
    /// the call order is decided outside this shader.
    /// </remarks>
    /// <returns>1 when the patch is considered clipped, 0 otherwise.</returns>
    stage override float ComputeClipping(InputPatch<Input, InputControlPointCount> input, const OutputPatch<Input2, 3> output, inout Constants constants)
    {
        // For now, Displacement clipping is hardcoded here, need to be able to split that! (maybe not that easy, need array or something)
        const float displacementSize = 75;

        // Corners pushed out along their own normal.
        float3 f3Position5 = input[0].PositionWS.xyz + input[0].normalWS * displacementSize;
        float3 f3Position6 = input[1].PositionWS.xyz + input[1].normalWS * displacementSize;
        float3 f3Position7 = input[2].PositionWS.xyz + input[2].normalWS * displacementSize;

        // Center control point pushed out along the normalized average of the three corner normals.
        float3 normalB111 = normalize((input[0].normalWS + input[1].normalWS + input[2].normalWS) / 3.0f);
        float3 f3Position8 = constants.f3ViewB111 + normalB111 * displacementSize;

        return ComputeClippingGroup8(input[0].PositionWS.xyz, input[1].PositionWS.xyz, input[2].PositionWS.xyz, constants.f3ViewB111,
                                        f3Position5, f3Position6, f3Position7, f3Position8);
        // Variant that ignores the displaced positions:
        //return ComputeClippingGroup4(input[0].PositionWS.xyz, input[1].PositionWS.xyz, input[2].PositionWS.xyz, constants.f3ViewB111);
    }

    /// <summary>
    /// Per control point work: computes the two intermediate control points of the edge
    /// going from control point uCPID to control point NextCPID.
    /// </summary>
    /// <param name="input">Input patch.</param>
    /// <param name="uCPID">Index of the control point being processed.</param>
    /// <param name="NextCPID">Index of the control point at the other end of the edge.</param>
    stage override void TessellateHull(InputPatch<Input, InputControlPointCount> input, uint uCPID, uint NextCPID)
    {
        // Built around the current control point, using its normal.
        streams.nextPositionVS1 = ComputeControlPoint(input[uCPID].PositionWS.xyz, input[NextCPID].PositionWS.xyz, input[uCPID].normalWS);
        // Built around the next control point, using its normal.
        streams.nextPositionVS2 = ComputeControlPoint(input[NextCPID].PositionWS.xyz, input[uCPID].PositionWS.xyz, input[NextCPID].normalWS);
    }

    /// <summary>
    /// Per patch work: computes the center control point (b111) from the three corners and
    /// the six edge control points, and stores it in constants.f3ViewB111.
    /// </summary>
    stage override void TessellateHullConstant(InputPatch<Input, InputControlPointCount> input, const OutputPatch<Input2, 3> output, inout Constants constants)
    {
        // Control points named after their indices in the cubic triangle patch.
        float3 f3B300 = output[0].PositionWS.xyz,
               f3B210 = output[0].nextPositionVS1.xyz,
               f3B120 = output[0].nextPositionVS2.xyz,
               f3B030 = output[1].PositionWS.xyz,
               f3B021 = output[1].nextPositionVS1.xyz,
               f3B012 = output[1].nextPositionVS2.xyz,
               f3B003 = output[2].PositionWS.xyz,
               f3B102 = output[2].nextPositionVS1.xyz,
               f3B201 = output[2].nextPositionVS2.xyz;

        // f3E: average of the six edge control points; f3V: average of the three corners.
        float3 f3E = (f3B210 + f3B120 + f3B021 + f3B012 + f3B102 + f3B201) / 6.0f;
        float3 f3V = (f3B003 + f3B030 + f3B300) / 3.0f;
        // The center is f3E moved away from f3V by half of their difference.
        constants.f3ViewB111 = f3E + ((f3E - f3V) / 2.0f);

        // TODO: Clipping test ?
        //if (ComputeClipping(constants.f3ViewB111, ViewProjection))
        //{
        //    constants.tessFactor[0] = 0.0f;
        //    constants.tessFactor[1] = 0.0f;
        //    constants.tessFactor[2] = 0.0f;
        //}

        //float fB111Clipped = IsClipped(
        //    ApplyProjection(g_f4x4ViewProjection, O.f3ViewB111));
    }

    /// <summary>
    /// Runs the base interpolation, then overwrites streams.PositionWS with the position
    /// evaluated on the cubic patch at the given barycentric coordinates.
    /// </summary>
    /// <param name="input">Patch control points (corner positions and edge control points).</param>
    /// <param name="constants">Patch constants; f3ViewB111 is read as the center control point.</param>
    /// <param name="f3BarycentricCoords">x weights input[0], y weights input[1], z weights input[2].</param>
    stage override void InterpolateBarycentric(const OutputPatch<Input, 3> input, in Constants constants, float3 f3BarycentricCoords)
    {
        base.InterpolateBarycentric(input, constants, f3BarycentricCoords);

        float fU = f3BarycentricCoords.x;
        float fV = f3BarycentricCoords.y;
        float fW = f3BarycentricCoords.z;

        float fUU = fU * fU;
        float fVV = fV * fV;
        float fWW = fW * fW;
        float fUU3 = fUU * 3.0f;
        float fVV3 = fVV * 3.0f;
        float fWW3 = fWW * 3.0f;

        // Corners get cubic weights, each edge control point gets 3 * (near corner weight)^2 * (far corner weight),
        // and the center control point gets 6 * u * v * w. The resulting w component is set to 1.
        streams.PositionWS =
            float4(input[0].PositionWS.xyz * fUU * fU
            + input[1].PositionWS.xyz * fVV * fV
            + input[2].PositionWS.xyz * fWW * fW
            + input[0].nextPositionVS1 * fUU3 * fV
            + input[0].nextPositionVS2 * fVV3 * fU
            + input[1].nextPositionVS1 * fVV3 * fW
            + input[1].nextPositionVS2 * fWW3 * fV
            + input[2].nextPositionVS1 * fWW3 * fU
            + input[2].nextPositionVS2 * fUU3 * fW
            + constants.f3ViewB111 * 6.0f * fW * fU * fV, 1.0f);
    }
};
