// Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)
// Distributed under the MIT license. See the LICENSE.md file in the project root for more information.

namespace Stride.Rendering.Compositing
{
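	/// <summary>
    /// An MSAA texture resolver shader: resolves a multisampled input texture by filtering its sub-samples.
    /// <c>MSAASamples</c> is the number of sub-samples per pixel, <c>ResolveFilterType</c> selects one of the
    /// <c>FilterTypes_*</c> filters and <c>ResolveFilterDiameter</c> is the filter diameter in pixels.
    /// </summary>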
    internal shader MSAAResolverShader<int MSAASamples, int ResolveFilterType, float ResolveFilterDiameter> : ImageEffectShader, Math
    {
        // Reference: https://github.com/TheRealMJP/MSAAFilter

        // Scale (xy) and bias (zw) applied by ClipPosToUvPos() to turn the incoming position into a texel coordinate.
        stage float4 SvPosUnpack;
        // Upper bound used to clamp sample coordinates in Shading().
        // NOTE(review): presumably the input texture size minus one - confirm against the C# side.
        stage float2 TextureSizeLess1;
        // Multisampled source texture. The sample count of the texture type is picked at compile time from the
        // INPUT_MSAA_SAMPLES macro (1, 2, 4 or 8); when the macro is not defined the sample count is left unspecified.
        // Any other macro value aborts compilation.
#ifndef INPUT_MSAA_SAMPLES
		stage Texture2DMS<float4> InputTexture;
#elif INPUT_MSAA_SAMPLES == 1
		stage Texture2DMS<float4, 1> InputTexture;
#elif INPUT_MSAA_SAMPLES == 2
		stage Texture2DMS<float4, 2> InputTexture;
#elif INPUT_MSAA_SAMPLES == 4
		stage Texture2DMS<float4, 4> InputTexture;
#elif INPUT_MSAA_SAMPLES == 8
		stage Texture2DMS<float4, 8> InputTexture;
#else
        #error "Unsupported amount of MSAA texture samples."
#endif
        
        // Supported filter types (note: this must match C# source).
        // Filter() compares the ResolveFilterType generic parameter against these values;
        // any value not listed here makes Filter() return a constant weight of 1.
        static const int FilterTypes_Box = 1;
        static const int FilterTypes_Triangle = 2;
        static const int FilterTypes_Gaussian = 3;
        static const int FilterTypes_BlackmanHarris = 4;
        static const int FilterTypes_Smoothstep = 5;
        static const int FilterTypes_BSpline = 6;
        static const int FilterTypes_CatmullRom = 7;
        static const int FilterTypes_Mitchell = 8;
        static const int FilterTypes_Sinc = 9;

        // These are the sub-sample locations for the 1x, 2x, 4x, and 8x standard multisample patterns.
        // See the MSDN documentation for the D3D11_STANDARD_MULTISAMPLE_QUALITY_LEVELS enumeration.
        // Each table is indexed by the sub-sample index; Shading() adds the offset to the integer pixel
        // offset of a neighbouring pixel to get that sub-sample's distance from the pixel being resolved.
        static const float2 SubSampleOffsets8[8] = {
                float2( 0.0625f, -0.1875f),
                float2(-0.0625f,  0.1875f),
                float2( 0.3125f,  0.0625f),
                float2(-0.1875f, -0.3125f),
                float2(-0.3125f,  0.3125f),
                float2(-0.4375f, -0.0625f),
                float2( 0.1875f,  0.4375f),
                float2( 0.4375f, -0.4375f),
        };
        static const float2 SubSampleOffsets4[4] = {
                float2(-0.125f, -0.375f),
                float2( 0.375f, -0.125f),
                float2(-0.375f,  0.125f),
                float2( 0.125f,  0.375f),
        };
        static const float2 SubSampleOffsets2[2] = {
                float2( 0.25f,  0.25f),
                float2(-0.25f, -0.25f),
        };
        // Single-sample case: the only sample has no offset.
        static const float2 SubSampleOffsets1[1] = {
                float2(0.0f, 0.0f),
        };

		// All filtering functions assume that 'x' is normalized to [0, 1], where 1 == filter radius

		// Box filter: every sample inside the filter radius (x <= 1) gets full weight, anything beyond gets none.
		float FilterBox(in float x)
		{
			return (x <= 1.0f) ? 1.0f : 0.0f;
		}

		// Triangle (tent) filter: weight falls off linearly from 1 at x == 0 to 0 at x == 1 and is clamped to [0, 1].
		static float FilterTriangle(in float x)
		{
			float falloff = 1.0f - x;
			return saturate(falloff);
		}

		// Gaussian filter with a fixed standard deviation of 0.5 (in normalized filter units).
		// Returns the normal-distribution density at 'x'; the result is not rescaled to peak at 1.
		static float FilterGaussian(in float x)
		{
			static const float sigma = 0.5f;
			static const float variance = sigma * sigma;
			static const float normalization = 1.0f / sqrt(2.0f * 3.14159f * variance);

			float exponent = -(x * x) / (2 * variance);
			return (normalization * exp(exponent));
		}
        
        // Piecewise cubic filter parameterized by B and C (Mitchell-Netravali form).
        // 'x' is expected in the cubic domain [0, 2]: one polynomial covers x < 1, a second covers
        // 1 <= x <= 2, and the weight is 0 beyond that.
        float FilterCubic(in float x, in float B, in float C)
		{
			float x2 = x * x;
			float x3 = x * x * x;

			// Inner lobe
			if(x < 1)
			{
				float inner = (12 - 9 * B - 6 * C) * x3 + (-18 + 12 * B + 6 * C) * x2 + (6 - 2 * B);
				return inner / 6.0f;
			}

			// Outer lobe
			if(x <= 2)
			{
				float outer = (-B - 6 * C) * x3 + (6 * B + 30 * C) * x2 + (-12 * B - 48 * C) * x + (8 * B + 24 * C);
				return outer / 6.0f;
			}

			// Outside of the filter support
			return 0.0f;
		}

		// Sinc filter: sin(pi * t) / (pi * t) with t = x * ResolveFilterDiameter.
		// Returns 1 when t is below 0.001 to avoid dividing by (nearly) zero.
		float FilterSinc(in float x)
		{
			float scaled = x * ResolveFilterDiameter;

			if(scaled < 0.001f)
				return 1.0f;

			float phase = scaled * PI;
			return sin(phase) / phase;
		}

		// Blackman-Harris window built from a four-term cosine sum.
		// The input is flipped (1 - x) before evaluation and the result is clamped to [0, 1].
		float FilterBlackmanHarris(in float x)
		{
			// Cosine-sum coefficients
			static const float a0 = 0.35875f;
			static const float a1 = 0.48829f;
			static const float a2 = 0.14128f;
			static const float a3 = 0.01168f;

			float t = 1.0f - x;
			float window = a0 - a1 * cos(PI * t) + a2 * cos(2 * PI * t) - a3 * cos(3 * PI * t);
			return saturate(window);
		}

		// Smoothstep filter: an inverted smoothstep, going from 1 at x == 0 down to 0 at x == 1.
		float FilterSmoothstep(in float x)
		{
			float rise = smoothstep(0.0f, 1.0f, x);
			return 1.0f - rise;
		}

		// Evaluates the filter selected by the ResolveFilterType generic parameter at the
		// normalized distance 'x'. Unknown filter types yield a constant weight of 1.
		float Filter(in float x)
		{
			// Cubic filters naturally work in a [-2, 2] domain. For the resolve case we
			// want to rescale the filter so that it works in [-1, 1] instead
			float cubicX = x * 2.0f;

			if(ResolveFilterType == FilterTypes_Box)
				return FilterBox(x);

			if(ResolveFilterType == FilterTypes_Triangle)
				return FilterTriangle(x);

			if(ResolveFilterType == FilterTypes_Gaussian)
				return FilterGaussian(x);

			if(ResolveFilterType == FilterTypes_BlackmanHarris)
				return FilterBlackmanHarris(x);

			if(ResolveFilterType == FilterTypes_Smoothstep)
				return FilterSmoothstep(x);

			// The three cubic variants only differ by their (B, C) parameters
			if(ResolveFilterType == FilterTypes_BSpline)
				return FilterCubic(cubicX, 1.0, 0.0f);

			if(ResolveFilterType == FilterTypes_CatmullRom)
				return FilterCubic(cubicX, 0, 0.5f);

			if(ResolveFilterType == FilterTypes_Mitchell)
				return FilterCubic(cubicX, 1 / 3.0f, 1 / 3.0f);

			if(ResolveFilterType == FilterTypes_Sinc)
				return FilterSinc(x);

			// Unknown filter type: weight every sample equally
			return 1.0f;
		}

		// Converts a position to an integer texel coordinate by applying the SvPosUnpack
		// scale (xy) and bias (zw) and truncating the result.
		// Intended mapping: 1:-1 to 0:TextureSize
		int2 ClipPosToUvPos(float2 clipPos)
		{
			float2 texelPos = clipPos * SvPosUnpack.xy + SvPosUnpack.zw;
			return (int2)texelPos;
		}

        // Resolves the multisampled input texture at the current pixel.
        // With a single sample the texel is loaded directly. Otherwise every sub-sample of the pixels
        // within SampleRadius is weighted by the selected filter (evaluated separately on the normalized
        // x and y distances) and by an inverse-luminance factor; the weighted average is returned.
        override stage float4 Shading()
        {
            int2 pixelPos = ClipPosToUvPos(streams.Position.xy);

            // Special case for single sample resolving
            if(MSAASamples == 1)
                return InputTexture.Load(pixelPos, 0);

            static const float filterRadius = ResolveFilterDiameter / 2.0f;
            static const int SampleRadius = (int)(filterRadius + 0.499f);

            float4 weightedSum = 0.0f;
            float weightSum = 0.0f;

            for(int dy = -SampleRadius; dy <= SampleRadius; dy++)
            {
                for(int dx = -SampleRadius; dx <= SampleRadius; dx++)
                {
                    float2 pixelOffset = float2(dx, dy);

                    // Neighbours outside of the texture reuse the closest edge texel
                    float2 samplePos = clamp(pixelPos + pixelOffset, 0, TextureSizeLess1);

                    [unroll]
                    for(uint subSampleIdx = 0; subSampleIdx < MSAASamples; subSampleIdx++)
                    {
                        // Pick the offset table matching the sample count
                        float2 subSampleOffset;
                        if(MSAASamples == 8)
                            subSampleOffset = SubSampleOffsets8[subSampleIdx];
                        else if(MSAASamples == 4)
                            subSampleOffset = SubSampleOffsets4[subSampleIdx];
                        else if(MSAASamples == 2)
                            subSampleOffset = SubSampleOffsets2[subSampleIdx];
                        else
                            subSampleOffset = SubSampleOffsets1[subSampleIdx];

                        // Per-axis distance to the resolved pixel, normalized so that 1 == filter radius
                        float2 sampleDist = abs(pixelOffset + subSampleOffset) / filterRadius;

                        // Skip the sub-samples lying outside of the filter footprint
                        if(all(sampleDist <= 1.0f))
                        {
                            float4 sampleValue = InputTexture.Load(samplePos, subSampleIdx);
                            float sampleLum = LuminanceUtils.Luma(sampleValue.rgb);

                            float weight = Filter(sampleDist.x) * Filter(sampleDist.y);

                            // Use inverse luminance filtering (better quality on highlights)
                            weight *= 1.0f / (1.0f + sampleLum);

                            weightedSum += sampleValue * weight;
                            weightSum += weight;
                        }
                    }
                }
            }

            // Normalize, guarding against a (near) zero total weight
            return weightedSum / max(weightSum, 0.00001f);
        }
    };
}
