// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.
/// <summary>
/// Performs skinning on the position.
/// </summary>
/// <remarks>
/// SkinningMaxBones: Macro - number of bone matrices in BlendMatrixArray (defaults to 4 when not defined).
/// </remarks>
// Fallback value used only when SkinningMaxBones has not already been defined.
#ifndef SkinningMaxBones
# define SkinningMaxBones 4
#endif

shader TransformationSkinning : TransformationBase, PositionStream4, Transformation
{
    cbuffer PerDraw
    {
        // One blend matrix per slot; the slot count comes from the SkinningMaxBones macro.
        // TODO switch to float4x3 in a way compatible with ES 2.0
        stage float4x4 BlendMatrixArray[SkinningMaxBones];
    }

    // Per-vertex weights and the indices into BlendMatrixArray they apply to.
    stage stream float4 BlendWeights : BLENDWEIGHT;
    stage stream uint4 BlendIndices : BLENDINDICES;

    // Weighted combination of the four selected blend matrices, written by PreTransformPosition.
    stage stream float4x4 skinningBlendMatrix;

    // Returns the entry of BlendMatrixArray stored at the given index.
    float4x4 GetBlendMatrix(int index)
    {
        float4x4 boneMatrix = BlendMatrixArray[index];
        return boneMatrix;
    }

    // Runs the base pre-transform, then blends the four indexed matrices by their
    // weights, transforms the vertex position with the result and stores the
    // w-normalized value in PositionWS.
    override stage void PreTransformPosition()
    {
        base.PreTransformPosition();

        uint4 boneIndices = streams.BlendIndices;
        float4 boneWeights = streams.BlendWeights;

        // Accumulate in the same left-to-right order: component 0 first, component 3 last.
        float4x4 blended = GetBlendMatrix(boneIndices.x) * boneWeights.x;
        blended += GetBlendMatrix(boneIndices.y) * boneWeights.y;
        blended += GetBlendMatrix(boneIndices.z) * boneWeights.z;
        blended += GetBlendMatrix(boneIndices.w) * boneWeights.w;
        streams.skinningBlendMatrix = blended;

        // Position is treated as a point (w = 1) and multiplied as a row vector.
        float4 skinnedPos = mul(float4(streams.Position.xyz, 1.0f), blended);

        // Divide every component, including w itself, by w before storing.
        streams.PositionWS = skinnedPos / skinnedPos.w;
    }
};
