﻿ /*
 * Copyright (c) .NET Foundation and Contributors (https://dotnetfoundation.org/ & https://stride3d.net) and Silicon Studio Corp. (https://www.siliconstudio.co.jp)

 * Copyright (C) 2012 Jorge Jimenez (jorge@iryoku.com)
 * Copyright (C) 2012 Diego Gutierrez (diegog@unizar.es)
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 *    1. Redistributions of source code must retain the above copyright notice,
 *       this list of conditions and the following disclaimer.
 *
 *    2. Redistributions in binary form must reproduce the following disclaimer
 *       in the documentation and/or other materials provided with the 
 *       distribution:
 *
 *       "Uses Separable SSS. Copyright (C) 2012 by Jorge Jimenez and Diego
 *        Gutierrez."
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS 
 * IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL COPYRIGHT HOLDERS OR CONTRIBUTORS 
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 *
 * The views and conclusions contained in the software and documentation are 
 * those of the authors and should not be interpreted as representing official
 * policies, either expressed or implied, of the copyright holders.
 */
 
/**
 *                  _______      _______      _______      _______
 *                 /       |    /       |    /       |    /       |
 *                |   (----    |   (----    |   (----    |   (----
 *                 \   \        \   \        \   \        \   \
 *              ----)   |    ----)   |    ----)   |    ----)   |
 *             |_______/    |_______/    |_______/    |_______/
 *
 *        S E P A R A B L E   S U B S U R F A C E   S C A T T E R I N G
 *
 *                           http://www.iryoku.com/
 *
 * Hi, thanks for your interest in Separable SSS!
 *
 * It's a simple shader composed of two components:
 *
 * 1) A transmittance function, 'SSSSTransmittance', which allows to calculate
 *    light transmission in thin slabs, useful for ears and nostrils. It should
 *    be applied during the main rendering pass as follows:
 *
 *        float3 t = albedo.rgb * lights[i].color * attenuation * spot;
 *        color.rgb += t * SSSSTransmittance(...)
 *
 *    (See 'Main.fx' for more details).
 *
 * 2) A simple two-pass reflectance post-processing shader, 'SSSSBlur*', which
 *    softens the skin appearance. It should be applied as a regular
 *    post-processing effect like bloom (the usual framebuffer ping-ponging):
 *
 *    a) The first pass (horizontal) must be invoked by taking the final color
 *       framebuffer as input, and storing the results into a temporal
 *       framebuffer.
 *    b) The second pass (vertical) must be invoked by taking the temporal
 *       framebuffer as input, and storing the results into the original final
 *       color framebuffer.
 *
 *    Note that this SSS filter should be applied *before* tonemapping.
 *
 * Before including SeparableSSS.h you'll have to setup the target. The
 * following targets are available:
 *         SMAA_HLSL_3
 *         SMAA_HLSL_4
 *         SMAA_GLSL_3
 *
 * For more information of what's under the hood, you can check the following
 * URLs (but take into account that the shader has evolved a little bit since
 * these publications):
 *
 * 1) Reflectance: http://www.iryoku.com/sssss/
 * 2) Transmittance: http://www.iryoku.com/translucency/
 *
 * If you've got any doubts, just contact us!
 */

namespace Stride.Rendering.SubsurfaceScattering
{
    /// <summary>
    /// The Separable Subsurface Scattering shader based on https://github.com/iryoku/separable-sss.
    /// </summary>
    shader SubsurfaceScatteringBlurShader<
        bool BlurHorizontally,
        bool KernelSizeJittering,
        bool OrthographicProjection,    // If orthographic projection is used, this is true. If perspective projections are used, this is false. 
        int MaxMaterialCount,
        int KernelLength,
        int RenderMode
        > : ImageEffectShader, Camera, Math
    {
        // Generated values:
        // Scale applied to the scattering width to obtain the sample step that is added to the
        // texture coordinate. CalculateProjectionSize() returns it unchanged for orthographic
        // projections and divides it by the view-space depth for perspective projections.
        stage float2 ProjectionSizeOnUnitPlaneInClipSpace;
        // One scattering width per material. Indexed with the material index read in Shading()
        // and passed to SSSSBlurPS() as "sssWidth".
        stage float ScatteringWidths[MaxMaterialCount]; // TODO: Use Buffer<float> instead?
        // Offsets the dither table lookups and is added to the blur rotation angle in Shading().
        stage float IterationNumber;
        // Only read by IntersectsProjectedSphere().
        stage float4x4 ViewProjectionMatrix;    // This is used for debugging only.
        
        cbuffer PerDraw
        {
            // Filter kernel layout is as follows:
            //   - Weights in the RGB channels.
            //   - Offsets in the A channel.
            // Elements are addressed as "materialIndex * KernelLength + sampleIndex"
            // (see GetKernelElement).
            stage Buffer<float4> KernelBuffer;
        }

        // These values correspond to the ones defined in SubsurfaceScatteringBlur.cs
        // They are compared against the "RenderMode" generic in Shading() to select a debug
        // visualization. Any other value runs the regular blur.
        #define SHOW_SCATTERING_OBJECTS 1
        #define SHOW_MATERIAL_INDEX 2
        #define SHOW_SCATTERING_WIDTH 3

        //------------------------------------------------------------------------------
        // Configurable Defines
        
        // Light diffusion should occur on the surface of the object, not in a screen 
        // oriented plane. Setting SSSS_FOLLOW_SURFACE to 1 will ensure that diffusion
        // is more accurately calculated, at the expense of more memory accesses.
        #ifndef SSSS_FOLLOW_SURFACE
        #define SSSS_FOLLOW_SURFACE 0
        #endif
        
        // This define allows to specify a different source for the SSS strength
        // (instead of using the alpha channel of the color framebuffer). This is useful
        // when the alpha channel of the main color buffer is used for something else.
        // NOTE: The default expands to "(colorM.a)", so it relies on a local variable
        // named "colorM" being in scope where it is used (see SSSSBlurPS).
        #ifndef SSSS_STRENGTH_SOURCE
        #define SSSS_STRENGTH_SOURCE (colorM.a)
        #endif

        //------------------------------------------------------------------------------
        // Porting Functions
        // The macros below reference "LinearSampler" and "PointSampler", which are not declared
        // in this shader (presumably provided by an inherited shader - TODO confirm). The
        // commented-out declarations are kept for reference only.
        //SamplerState LinearSampler { Filter = MIN_MAG_LINEAR_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
        //SamplerState PointSampler { Filter = MIN_MAG_MIP_POINT; AddressU = Clamp; AddressV = Clamp; };
        #define SSSSTexture2D Texture2D
        #define SSSSSampleLevelZero(tex, coord) tex.SampleLevel(LinearSampler, coord, 0)
        #define SSSSSampleLevelZeroPoint(tex, coord) tex.SampleLevel(PointSampler, coord, 0)
        #define SSSSSample(tex, coord) SSSSSampleLevelZero(tex, coord)
        #define SSSSSamplePoint(tex, coord) SSSSSampleLevelZeroPoint(tex, coord)
        #define SSSSSampleLevelZeroOffset(tex, coord, offset) tex.SampleLevel(LinearSampler, coord, 0, offset)
        #define SSSSSampleOffset(tex, coord, offset) SSSSSampleLevelZeroOffset(tex, coord, offset)
        #define SSSSLerp(a, b, t) lerp(a, b, t)
        #define SSSSSaturate(a) saturate(a)
        #define SSSSMad(a, b, c) mad(a, b, c)
        #define SSSSMul(v, m) mul(v, m)
        #define SSSS_FLATTEN [flatten]
        #define SSSS_BRANCH [branch]
        #define SSSS_UNROLL [unroll]

        //------------------------------------------------------------------------------
        // Separable SSS Reflectance Pixel Shader

        float4 SSSSBlurPS(
                // Texture coordinates of the full-screen quad.
                float2 texcoord,

                // SRGB or HDR color input buffer: the final color frame, resolved when
                // multisampling is used. The desired SSS strength should be stored in
                // the alpha channel (1 for full strength, 0 for disabling SSS). If that
                // is not possible, you can customize the source of this value using
                // SSSS_STRENGTH_SOURCE.
                //
                // When using non-SRGB buffers, convert to linear before processing and
                // back to gamma space before storing the pixels (see Chapter 24 of
                // GPU Gems 3 for more info).
                //
                // IMPORTANT: WORKING IN A NON-LINEAR SPACE WILL TOTALLY RUIN SSS!
                SSSSTexture2D colorTex,

                // Depth buffer of the scene, resolved when multisampling is used (the
                // resolve should be a simple average to avoid artifacts in the
                // silhouette of objects). Every sampled value is passed through
                // CalculateViewSpaceDepth() before it is used.
                SSSSTexture2D depthTex,

                // Global level of subsurface scattering, i.e. the width of the filter,
                // specified in world space units.
                float sssWidth,

                // Direction of the blur:
                //   - First pass:   float2(1.0, 0.0)
                //   - Second pass:  float2(0.0, 1.0)
                float2 dir,

                // Originally requested stencil initialization in the first pass. It is
                // not read anymore, because the stencil code below is disabled.
                bool initStencil,

                // Stride: The material index used to access the correct scattering kernel.
                uint materialIndex)
        {
            // Center pixel. NOTE: This local has to be called "colorM", because the default
            // SSSS_STRENGTH_SOURCE macro expands to "(colorM.a)".
            float4 colorM = SSSSSamplePoint(colorTex, texcoord);

            // The stencil initialization is disabled because we already discard using the material index.
            /*
            if (initStencil) // (Checked in compile time, it's optimized away)
                if (SSSS_STRENGTH_SOURCE == 0.0) discard;
            */

            // Depth of the center pixel, converted by CalculateViewSpaceDepth():
            float depthM = CalculateViewSpaceDepth(SSSSSamplePoint(depthTex, texcoord).r);

            // Step between neighbouring taps. CalculateProjectionSize() provides the sssWidth scale
            // (1.0 for a unit plane sitting on the projection window). It is more accurate than the
            // original approach, because it calculates the correct radius for non-square viewports.
            // The original division by 3 (kernels range from -3 to 3) is not applied here, because it
            // is baked into the kernel on the CPU instead.
            float2 finalStep = sssWidth * CalculateProjectionSize(depthM) * dir;
            finalStep *= SSSS_STRENGTH_SOURCE; // Modulate it using the alpha channel.

            if (KernelSizeJittering)
            {
                // Scaling the step by a dithered factor in the range [0.5, 1.0] reduces banding
                // artifacts by introducing a bit of noise. This makes the falloff less mathematically
                // correct, since it messes with the sample offsets, but the difference is barely noticeable.
                // The 8x8 table gives a more regular result (shows a bit of a grid pattern) than the
                // noisier GetRandomNumber()/FastRandom() alternatives that were tried here.
                float jitter = GetRandomNumber8x8(streams.ShadingPosition.xy + 2 + int(IterationNumber) * 10);
                finalStep *= 0.5 + jitter * 0.5;
            }

            // Start with the weighted center sample:
            float4 colorBlurred = colorM;
            colorBlurred.rgb *= GetKernelElement(materialIndex, 0).rgb;

            // Add the remaining taps of the kernel:
            SSSS_UNROLL
            for (int tap = 1; tap < KernelLength; tap++)
            {
                // Weights are in RGB, the offset along the blur direction is in A:
                float4 kernelElement = GetKernelElement(materialIndex, tap);

                float2 tapCoord = texcoord + kernelElement.a * finalStep;
                float3 tapColor = SSSSSample(colorTex, tapCoord).rgb;

            #if SSSS_FOLLOW_SURFACE == 1
                // If the difference in depth is huge, we lerp the tap color back to "colorM".
                // Original version: saturate(300.0 * distanceToProjectionWindow * sssWidth * abs(depthM - depth))
                float tapDepth = CalculateViewSpaceDepth(SSSSSample(depthTex, tapCoord).r);
                float depthBlend = SSSSSaturate(abs(depthM - tapDepth) / sssWidth * 0.5); // TODO: Use a quadratic falloff or something?
                tapColor = SSSSLerp(tapColor, colorM.rgb, depthBlend);
            #endif

                colorBlurred.rgb += kernelElement.rgb * tapColor;
            }

            // The alpha channel is passed through unchanged from the center pixel.
            return colorBlurred;
        }

        //------------------------------------------------------------------------------

        /// <summary>
        /// Composes a debug image out of five values: one per screen quadrant plus one shown
        /// in a centered rectangle that is drawn on top of the quadrants.
        /// The quadrant is selected by comparing "streams.TexCoord" against 0.5 (x &lt; 0.5 selects the
        /// "Left" values, y &lt; 0.5 selects the "Top" values). The centered rectangle covers all texture
        /// coordinates strictly between "centerImageMargin" and "1.0 - centerImageMargin" on both axes.
        /// All five values are float3, so a scalar argument is broadcast to all channels.
        /// The returned alpha is always 1.0.
        /// </summary>
        float4 CalculateDebugView(float3 valueCenter, float3 valueTopLeft, float3 valueBottomLeft, float3 valueTopRight, float3 valueBottomRight, float centerImageMargin)
        {
            const float2 uv = streams.TexCoord;

            // The center image has priority over the quadrants:
            if((uv.x > centerImageMargin) && (uv.x < 1.0 - centerImageMargin) &&
               (uv.y > centerImageMargin) && (uv.y < 1.0 - centerImageMargin))
            {
                return float4(valueCenter, 1.0);
            }

            float3 quadrantValue;

            if(uv.x < 0.5)
            {
                if(uv.y < 0.5)
                {
                    quadrantValue = valueTopLeft;
                }
                else
                {
                    quadrantValue = valueBottomLeft;
                }
            }
            else
            {
                if(uv.y < 0.5)
                {
                    quadrantValue = valueTopRight;
                }
                else
                {
                    quadrantValue = valueBottomRight;
                }
            }

            return float4(quadrantValue, 1.0);
        }

        /// <summary>
        /// Maps an ID to a repeatable debug color. Each channel is an integer expression of the ID
        /// taken modulo 256 and divided by 255. The alpha channel is always 1.0.
        /// </summary>
        float4 GenerateColorFromID(int id)
        {
            const int red = (23 + id * 109) % 256;
            const int green = (67 + id * 67) % 256;
            const int blue = (109 + id * 23) % 256;

            return float4(red, green, blue, 255.0) / 255.0;
        }

        /// <summary>
        /// Loads one element of a material's filter kernel from "KernelBuffer".
        /// Every material owns "KernelLength" consecutive elements (weights in RGB, offset in A).
        /// </summary>
        float4 GetKernelElement(uint materialIndex, int sampleIndex)    // TODO: Make both signed or unsigned?
        {
            // First element of the kernel that belongs to this material:
            const uint kernelStart = materialIndex * KernelLength;

            return KernelBuffer.Load(kernelStart + sampleIndex);
        }
        /*
        // Based on this article: https://briansharpe.wordpress.com/2011/11/15/a-fast-and-simple-32bit-floating-point-hash-function/
        float Hash(float2 p, float2 offset, float domainSize)	// "p" is assumed to be an integer coordinate.
        {
	        const float inverseLargeFloat = 1.0 / 951.135664;

	        //p = p - floor(p / domain) * domain;	// Truncate the domain
	        p = p % domainSize;	// Truncate the domain (same as the above line).
	        p += offset;	// Offset to the interesting part of the noise.
	        p *= p;	// Square the vector.
	
	        return(frac(p.x * p.y * inverseLargeFloat));
        }
        */
        /// <summary>
        /// Returns a value from a fixed 4x4 dither table that repeats every 4 pixels.
        /// The table holds every multiple of 1/16 from 0.0625 to 1 exactly once.
        /// </summary>
        float GetRandomNumber4x4(int2 coordinate)
        {
            float ditherValues[16] = 
            {
                0.3125, 0.625,  0.875,  0.25,
                0.1875, 0.4375, 0.0625, 0.75,
                1,      0.375,  0.6875, 0.9375,
                0.5,    0.5625, 0.8125, 0.125
            };

            // NOTE(review): The remainder is only a valid index for non-negative coordinates.
            // Presumably callers only pass pixel positions - confirm.
            int2 cell = coordinate % 4;
            int tableIndex = cell.x * 4 + cell.y;   // "x" selects the row of the table as written above.

            return ditherValues[tableIndex];
        }

        /// <summary>
        /// Returns a value from a fixed 8x8 dither table that repeats every 8 pixels.
        /// All 64 table entries lie between 0 and 1.
        /// </summary>
        float GetRandomNumber8x8(int2 coordinate)
        {
            float ditherValues[64] = 
            {
                0.907692307692306, 0.153846153846154, 0.523076923076923, 0.769230769230768,
                0.215384615384615, 0.338461538461538, 0.030769230769230, 0.107692307692308,
                0.123076923076923, 0.492307692307692, 0.676923076923076, 0.861538461538460,
                0.692307692307692, 0.230769230769231, 0.892307692307691, 0.984615384615383,
                0.584615384615384, 0.461538461538462, 0.476923076923077, 0.015384615384615,
                0.815384615384614, 0.569230769230769, 0.092307692307692, 0.553846153846154,
                0.707692307692307, 0.307692307692308, 0.046153846153846, 0.830769230769230,
                0.384615384615385, 0.953846153846152, 0.261538461538462, 0.538461538461538,
                0.923076923076922, 0.369230769230769, 0.738461538461538, 0.753846153846153,
                0.200000000000000, 0.076923076923076, 0.415384615384615, 0.969230769230768,
                0.846153846153845, 0.169230769230769, 0.061538461538461, 0.876923076923076,
                0.600000000000000, 0.799999999999999, 0.784615384615384, 0.246153846153846,
                0.323076923076923, 0.430769230769231, 0.938461538461537, 0.138461538461538,
                0.446153846153846, 0.353846153846154, 0.292307692307692, 0.400000000000000,
                0.184615384615385, 0.723076923076922, 0.507692307692308, 0.615384615384615,
                0.276923076923077, 0.646153846153846, 0.630769230769230, 0.661538461538461
            };

            // NOTE(review): The remainder is only a valid index for non-negative coordinates.
            // Presumably callers only pass pixel positions - confirm.
            int2 cell = coordinate % 8;
            int tableIndex = cell.x * 8 + cell.y;   // Eight consecutive entries per "x" value.

            return ditherValues[tableIndex];
        }
        /*
        float GetRandomNumber(int2 coordinate)
        {
            return(frac(sin(dot(coordinate, float2(12.9898, 78.2332))) * 43758.5453));
        }
        */
        /// <summary>
        /// Returns the scale that converts a scattering width into a sample step at the given depth.
        /// Orthographic: "ProjectionSizeOnUnitPlaneInClipSpace" unchanged (size stays the same regardless of distance).
        /// Perspective: "ProjectionSizeOnUnitPlaneInClipSpace" divided by the view-space depth.
        /// </summary>
        float2 CalculateProjectionSize(float viewSpaceDepth)
        {
            float2 projectionSize = ProjectionSizeOnUnitPlaneInClipSpace;

            if(!OrthographicProjection)
            {
                projectionSize /= viewSpaceDepth;
            }

            return projectionSize;
        }
        
        /// <summary>
        /// Converts a depth buffer value into the depth used by the blur.
        /// Orthographic: linear interpolation between "NearClipPlane" and "FarClipPlane".
        /// Perspective: "ZProjection.y / (nonlinearDepth - ZProjection.x)".
        /// </summary>
        float CalculateViewSpaceDepth(float nonlinearDepth) // TODO: Does this really convert to view space Z or just linearize the depth?
        {
            float viewSpaceDepth;

            if(OrthographicProjection)
            {
                // TODO: PERFORMANCE: Precompute "FarClipPlane - NearClipPlane" because it can't be inlined?
                // TODO: PERFORMANCE: And do we need the addition with "NearClipPlane"? I think it's redundant in this case because the orthographic projection matrix's near plane is at 0.0.
                const float depthRange = FarClipPlane - NearClipPlane;
                viewSpaceDepth = NearClipPlane + nonlinearDepth * depthRange;
            }
            else
            {
                viewSpaceDepth = ZProjection.y / (nonlinearDepth - ZProjection.x);
            }

            return viewSpaceDepth;
        }

        /// <summary>
        /// Debugging helper for the sampling radius calculation.
        /// Projects the sphere center with "ViewProjectionMatrix" and returns true if the given
        /// clip space coordinate, measured in units of the projection size at the sphere's depth,
        /// is closer to the projected center than the sphere radius.
        /// </summary>
        bool IntersectsProjectedSphere(float3 sphereWorldSpacePosition, float sphereRadiusWorldSpace, float2 clipSpaceCoordinate)
        {
            float4 clipPosition = mul(float4(sphereWorldSpacePosition, 1.0), ViewProjectionMatrix);
            clipPosition.y = -clipPosition.y; // TODO: Why is this necessary?

            // Perspective divide:
            float3 projectedCenter = clipPosition.xyz / clipPosition.w;

            // Size of one world space unit at the depth of the sphere center:
            float centerDepth = CalculateViewSpaceDepth(projectedCenter.z);
            float2 unitSize = CalculateProjectionSize(centerDepth);

            float2 centerToPixel = clipSpaceCoordinate.xy - projectedCenter.xy;
            float distanceInWorldUnits = length(centerToPixel / unitSize);

            return distanceInWorldUnits < sphereRadiusWorldSpace;
        }

        /// <summary>
        /// Runs one blur pass for the current pixel, or outputs one of the debug visualizations
        /// selected by "RenderMode". Pixels with material index 0 are not scattering materials.
        /// In the regular path they return the unmodified scene color.
        /// Inputs: Texture0 = scene color, Texture1 = depth, Texture2 = material index.
        /// </summary>
        stage override float4 Shading()
        {
            // Version for material ID stored as a float: adding 0.5 before the cast rounds to the nearest integer.
            const uint materialIndex = uint(Texture2.Load(int3(streams.ShadingPosition.xy, 0.0)).r + 0.5);

            // Material index 0 means it's not a scattering material and we can discard it.
            const bool isScatteringMaterial = (materialIndex != 0);

            // Debug visualizations ("RenderMode" is a generic, so only one of these can be active):
            if(RenderMode == SHOW_SCATTERING_OBJECTS)
            {
                // White for scattering materials, black for everything else.
                if(isScatteringMaterial)
                {
                    return 1.0;
                }

                return 0.0;
            }

            if(RenderMode == SHOW_MATERIAL_INDEX)
            {
                if(!isScatteringMaterial)
                {
                    return 0.0;
                }

                // Generate a "random" color using the material index:
                return GenerateColorFromID(materialIndex);
            }

            if(RenderMode == SHOW_SCATTERING_WIDTH)
            {
                // Only the fractional part of the scattering width is shown.
                return fmod(ScatteringWidths[materialIndex], 1.0);
            }

            // Regular blur path:
            const float4 sceneColor = Texture0.Sample(LinearSampler, streams.TexCoord);

            if(!isScatteringMaterial)
            {
                return sceneColor;  // Return the scene color because we are doing texture ping-ponging.
            }

            // Per-pixel dither value used to rotate the blur direction.
            // The 4x4 table is more regular (shows a bit of a grid pattern); FastRandom() would be more noisy.
            float ditherValue = GetRandomNumber4x4(streams.ShadingPosition.xy + int(IterationNumber));

            /*
            float2 scale = CalculateProjectionSize(viewSpaceDepth);
            if(scale.x < 1.0)// TODO: PERFORMANCE: Turn off rotation for small kernels (mentioned in the paper)?
            {
                ditherValue = 0.0;
            }
            */

            // TODO: PERFORMANCE: Rotate only by 90 degrees?
            // TODO: Scale the iteration offset somehow? Maybe from 0 to PI / 2 (depending on the number of passes).
            float randomAngle = ditherValue * PI + IterationNumber;

            if(!BlurHorizontally)
            {
                // The second pass blurs perpendicular to the first one.
                randomAngle += PI / 2.0;
            }

            float4 blurredSSS = SSSSBlurPS(streams.TexCoord,
                                           Texture0,
                                           Texture1,   // TODO: PERFORMANCE: Preconvert to view space/linearize?
                                           ScatteringWidths[materialIndex],
                                           float2(cos(randomAngle), sin(randomAngle)),
                                           false,
                                           materialIndex);
            return blurredSSS;

            // Code for debugging the sampling radius calculation:
            /*
            float2 clipSpaceCoordinate = streams.TexCoord.xy * 2.0 - 1.0;
            if(IntersectsProjectedSphere(float3(0.0, 0.0, 0.0), 1.0, clipSpaceCoordinate))
            {
                if(OrthographicProjection)
                {
                    return sceneColor * float4(0.1, 1.0, 0.1, 1.0);
                }

                return sceneColor * float4(1.0, 0.1, 0.1, 1.0);
            }
            */
            
            // Code for debugging the SSSS in general, by visualizing the different buffers:
            /*
            if(BlurHorizontally)    // If this is the 1st pass:
            {
                return blurredSSS;
            }
            else    // If this is the 2nd pass:
            {
                const float nonlinearDepth = Texture1.Sample(Sampler, streams.TexCoord).r;
                const float viewSpaceDepth = CalculateViewSpaceDepth(nonlinearDepth);

                const float centerImageMargin = 0.25;

                return CalculateDebugView(blurredSSS,
                                          viewSpaceDepth,
                                          sceneColor.rgb, //MaxMaterialCount / 500.0,
                                          float(materialIndex) / 255.0,
                                          ScatteringWidths[materialIndex] * 10.0,
                                          centerImageMargin);
            }
            */
        }
    };
}
