// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
/// <summary>
/// Base compute shader.
/// </summary>
/// <remarks>
/// ThreadNumberX: Macro - number of threads on the X axis.
/// ThreadNumberY: Macro - number of threads on the Y axis.
/// ThreadNumberZ: Macro - number of threads on the Z axis.
/// </remarks>
// Default group dimensions: each ThreadNumber* macro that has not already been
// defined before this point falls back to 1, so an unconfigured shader is
// compiled with numthreads(1, 1, 1).
#ifndef ThreadNumberX
# define ThreadNumberX 1
#endif
#ifndef ThreadNumberY
# define ThreadNumberY 1
#endif
#ifndef ThreadNumberZ
# define ThreadNumberZ 1
#endif
shader ComputeShaderBase
{
    // Thread / group identifiers bound to the compute system-value semantics.
    stage stream uint3 GroupId : SV_GroupID;
    stage stream uint3 DispatchThreadId : SV_DispatchThreadID;
    stage stream uint3 GroupThreadId : SV_GroupThreadID;
    stage stream uint  GroupIndex : SV_GroupIndex;

    // Derived values written by CSMain() before Compute() is invoked.
    stage stream uint3 ThreadGroupCount;
    stage stream uint  ThreadCountPerGroup;
    stage stream uint  ThreadGroupIndex;

    // Per-axis group size, mirrored from the ThreadNumberX/Y/Z macros.
    stage stream int ThreadCountX;
    stage stream int ThreadCountY;
    stage stream int ThreadCountZ;

    cbuffer PerDispatch {
        // Number of thread groups along each axis, as passed to the dispatch call
        [Link("ComputeShaderBase.ThreadGroupCountGlobal")]
        stage int3 ThreadGroupCountGlobal;
    };

    // Entry point: fills in the derived streams, then hands over to Compute().
    [numthreads(ThreadNumberX, ThreadNumberY, ThreadNumberZ)]
    void CSMain()
    {
        // Mirror the dispatch-wide group count into the streams first;
        // the flattened group index below is computed from this value.
        streams.ThreadGroupCount = ThreadGroupCountGlobal;

        // Total number of threads in one group
        streams.ThreadCountPerGroup = ThreadNumberX * ThreadNumberY * ThreadNumberZ;

        // Make the individual group dimensions readable from anywhere in the shader
        streams.ThreadCountX = ThreadNumberX;
        streams.ThreadCountY = ThreadNumberY;
        streams.ThreadCountZ = ThreadNumberZ;

        // Flatten the 3D group id into a single index that is unique per group
        // within the dispatch: (z * countY + y) * countX + x
        uint3 gid = streams.GroupId;
        uint3 groupCount = streams.ThreadGroupCount;
        uint rowIndex = gid.z * groupCount.y + gid.y;
        streams.ThreadGroupIndex = rowIndex * groupCount.x + gid.x;

        Compute();
    }

    // Empty hook called once per thread by CSMain(); the base implementation does nothing.
    void Compute()
    {
    }

    // Returns true only for the thread whose id inside its group is (0, 0, 0).
    bool IsFirstThreadOfGroup()
    {
        uint3 localId = streams.GroupThreadId;

        // All components are unsigned, so their bitwise OR is zero exactly when each one is zero
        return (localId.x | localId.y | localId.z) == 0;
    }
};
