// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
/// <summary>
/// Defines the basic methods for tessellation.
/// </summary>
/// <remarks>
/// InputControlPointCount: Macro - Input control points count.
/// OutputControlPointCount: Macro - Output control points count.
/// </remarks>

// Size of the InputPatch consumed by the hull shader stages below.
// Defaults to 3 when the macro has not been defined beforehand.
#ifndef InputControlPointCount
# define InputControlPointCount 3
#endif

// Size of the OutputPatch consumed by the domain shader entry point below.
// Defaults to 3 when the macro has not been defined beforehand.
#ifndef OutputControlPointCount
# define OutputControlPointCount 3
#endif

shader TessellationBase : ShaderBase, TransformationBase, MaterialDomainStream, Camera, Transformation, NormalBase
{
    cbuffer PerMaterial
    {
        // Divisor applied to the screen-space length of a patch edge to obtain its tessellation factor
        // (see HSMain); a larger value therefore yields a smaller factor for the same edge.
        [Link("Tessellation.DesiredTriangleSize")]
        stage float DesiredTriangleSize = 12.0f;
    }

    // Per-patch constants written by HSConstantMain.
    patchstream float tessFactor[3] : SV_TessFactor;
    patchstream float insideTessFactor : SV_InsideTessFactor;

    /// <summary>
    /// Generates the normal through the base implementation, then renormalizes it.
    /// </summary>
    override stage void GenerateNormal_VS()
    {
        base.GenerateNormal_VS();

        // Ensure that the normal is normalized at every step of the tessellation.
        streams.normalWS = normalize(streams.normalWS);
    }

    /// <summary>
    /// Hull shader entry point, invoked once per output control point.
    /// Copies the matching input control point to the output and stores the LOD of the edge
    /// going from this control point to the next one.
    /// </summary>
    [domain("tri")]
    [partitioning("fractional_odd")]
    [outputtopology("triangle_cw")]
    [outputcontrolpoints(3)]
    [patchconstantfunc("HSConstantMain")]
    void HSMain(InputPatch<Input, InputControlPointCount> input, out Output output, uint uCPID : SV_OutputControlPointID)
    {
        // Index of the next control point, wrapping around after index 2 (triangle patch).
        const uint NextCPID = uCPID < 2 ? uCPID + 1 : 0;

        streams = input[uCPID];

        // Extension hook; empty in this shader.
        TessellateHull(input, uCPID, NextCPID);

        // Compute screen space position of current control point and next one
        // TODO: Reuse ShadingPosition?
        // However, not sure if we can do tessellation directly through ShadingPosition interpolation (in which case we wouldn't need to do it in domain shader either)
        float2 screenPosition0 = GetScreenSpacePosition(input[uCPID].PositionWS, ViewSize.x, ViewSize.y);
        float2 screenPosition1 = GetScreenSpacePosition(input[NextCPID].PositionWS, ViewSize.x, ViewSize.y);

        // Screen space tessellation based on desired triangle size:
        // the edge length in screen units divided by DesiredTriangleSize.
        streams.oppositeEdgeLOD = distance(screenPosition0, screenPosition1) / DesiredTriangleSize;

        output = streams;
    }

    /// <summary>
    /// Patch constant function referenced by HSMain.
    /// Fills the three edge factors and the inside factor, then zeroes all of them
    /// when ComputeClipping reports a non-zero value.
    /// </summary>
    void HSConstantMain(InputPatch<Input, InputControlPointCount> input, const OutputPatch<Input2, 3> output, out Constants constants)
    {
        // Control point i stores the LOD of edge (i, i + 1), so the factor at index k is read
        // from control point (k + 1) % 3, i.e. from the edge that does not touch control point k.
        constants.tessFactor[0] = output[1].oppositeEdgeLOD;
        constants.tessFactor[1] = output[2].oppositeEdgeLOD;
        constants.tessFactor[2] = output[0].oppositeEdgeLOD;

        // Inside factor: 0.33 times the sum of the three edge factors (roughly their average).
        constants.insideTessFactor = 0.33f * (constants.tessFactor[0] + constants.tessFactor[1] + constants.tessFactor[2]);

        // Extension hook; empty in this shader.
        TessellateHullConstant(input, output, constants);

        // All factors are reset to zero when the clipping hook returns a non-zero value.
        if (ComputeClipping(input, output, constants))
        {
            constants.tessFactor[0] = 0.0f;
            constants.tessFactor[1] = 0.0f;
            constants.tessFactor[2] = 0.0f;
            constants.insideTessFactor = 0.0f;
        }
    }

    /// <summary>
    /// Domain shader entry point: interpolates the patch at the given barycentric location,
    /// then runs the domain-stage part of the position transformation.
    /// </summary>
    [domain("tri")]
    void DSMain(const OutputPatch<Input, OutputControlPointCount> input, out Output output, in Constants constants, float3 f3BarycentricCoords : SV_DomainLocation)
    {
        InterpolateBarycentric(input, constants, f3BarycentricCoords);

        this.BaseTransformDS();

        output = streams;
    }

    /// <summary>
    /// Vertex-stage transformation: only the pre-transform step runs here;
    /// TransformPosition and PostTransformPosition are called from BaseTransformDS instead.
    /// </summary>
    stage override void BaseTransformVS()
    {
        this.PreTransformPosition();
    }

    /// <summary>
    /// Domain-stage transformation: runs TransformPosition followed by PostTransformPosition.
    /// </summary>
    stage void BaseTransformDS()
    {
        this.TransformPosition();
        this.PostTransformPosition();
    }

    /// <summary>
    /// Runs the base position transformation, then the TessellateDomain hook.
    /// </summary>
    stage override void TransformPosition()
    {
        base.TransformPosition();

        // Apply tessellation map, etc...
        TessellateDomain();
    }

    /// <summary>
    /// Maps a control point position to a 2D position scaled by the screen size.
    /// </summary>
    /// <remarks>
    /// The Y axis is scaled by -fScreenHeight, so Y results are negative for normalized values
    /// above -1. HSMain only measures distances between two results, which this does not affect.
    /// NOTE(review): no guard against f4ProjectedPosition.w being zero or negative - confirm whether
    /// control points behind the camera need special handling.
    /// </remarks>
    float2 GetScreenSpacePosition(
                                    float4 f3Position,              // Position of the patch control point (a float4 despite the "f3" prefix; callers in this shader pass PositionWS)
                                    float fScreenWidth,             // Screen width
                                    float fScreenHeight             // Screen height
                                    )
    {
        // Presumably the clip-space position; ComputeShadingPosition is defined outside this shader.
        float4 f4ProjectedPosition = this.ComputeShadingPosition(f3Position);

        // Perspective divide.
        float2 f2ScreenPosition = f4ProjectedPosition.xy / f4ProjectedPosition.w;

        // Remap from [-1, 1] to [0, 1] and scale by the screen dimensions.
        f2ScreenPosition = ( f2ScreenPosition + 1.0f ) * 0.5f * float2( fScreenWidth, -fScreenHeight );
        return f2ScreenPosition;
    }

    // Extension hooks. Their default implementations do nothing.

    // Called from HSMain for each output control point, before the edge LOD is computed.
    stage void TessellateHull(InputPatch<Input, InputControlPointCount> input, uint uCPID, uint NextCPID) {}

    // Called from HSConstantMain after the tessellation factors have been filled in.
    stage void TessellateHullConstant(InputPatch<Input, InputControlPointCount> input, const OutputPatch<Input2, 3> output, inout Constants constants) {}

    // Called from HSConstantMain; a non-zero result sets every tessellation factor of the patch to zero.
    // The default returns 0 so that the factors computed above are kept.
    stage float ComputeClipping(InputPatch<Input, InputControlPointCount> input, const OutputPatch<Input2, 3> output, inout Constants constants)
    {
        return 0.0f;
    }

    // Called from DSMain before the domain-stage transformation.
    stage void InterpolateBarycentric(const OutputPatch<Input, 3> input, in Constants constants, float3 f3BarycentricCoords) {}

    // Called from TransformPosition after the base transformation.
    stage void TessellateDomain() {}
};
