/************************************************************************************

Depthkit Unity SDK License v1
Copyright 2016-2024 Simile Inc dba Scatter. All Rights reserved.  

Licensed under the Simile Inc dba Scatter ("Scatter")
Software Development Kit License Agreement (the "License"); 
you may not use this SDK except in compliance with the License, 
which is provided at the time of installation or download, 
or which otherwise accompanies this software in either electronic or hard copy form.  

You may obtain a copy of the License at http://www.depthkit.tv/license-agreement-v1

Unless required by applicable law or agreed to in writing, 
the SDK distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. 

************************************************************************************/

#pragma kernel GaussianBlur3x3x3 TILE_BORDER=1
#pragma kernel GaussianBlur5x5x5 TILE_BORDER=2
#pragma kernel GaussianBlur7x7x7 TILE_BORDER=3
#pragma kernel GaussianBlur9x9x9 TILE_BORDER=4

#include "Packages/nyc.scatter.depthkit.core/Runtime/Resources/Shaders/Includes/Utils.cginc"
#include "Packages/nyc.scatter.depthkit.studio/Runtime/Resources/Shaders/DataSource/StudioMeshSourceCommon.cginc"

// Read-only input: GaussianBlur fetches one float per voxel from this buffer.
StructuredBuffer<float> _PingData;
// Output: GaussianBlur writes one filtered float per in-bounds voxel to this buffer.
RWStructuredBuffer<float> _PongData;

// Extent of the voxel grid per axis. Stored as float3; cast to uint3 wherever it is used for indexing.
float3 _DataSize;
// Direction of the filter taps. Cast to int3 and multiplied by the tap offset.
// NOTE(review): presumably a unit axis such as (1,0,0) selected per dispatch -- confirm against the caller.
float3 _Axis;
// Factor applied to the squared tap offset inside exp().
// NOTE(review): presumably negative (e.g. -1 / (2 * sigma^2)) so weights fall off with distance -- confirm against the caller.
float _GaussianExponential;
// Scale applied to every tap weight.
float _GaussianNormalization;

static const uint BLOCK_SIZE = 8; // 8x8x8 dispatch tile
static const uint BLOCK_THREAD_COUNT = BLOCK_SIZE * BLOCK_SIZE * BLOCK_SIZE; //512 dispatch threads
// Number of taps on each side of the center voxel; the filter uses 2 * KERNEL_WIDTH + 1 taps in total.
static const int KERNEL_WIDTH = TILE_BORDER;
// Edge length of the cached tile: the thread block plus TILE_BORDER extra voxels on both sides.
static const uint TILE_SIZE = BLOCK_SIZE + TILE_BORDER * 2; 
// Total voxels cached per thread group (1000 for TILE_BORDER=1 up to 4096 for TILE_BORDER=4).
static const uint TILE_VOXEL_COUNT = TILE_SIZE * TILE_SIZE * TILE_SIZE; 

// Per-group cache of input values for the tile, filled cooperatively by GaussianBlur before filtering.
groupshared float g_samples[TILE_VOXEL_COUNT];

// Shared body of all blur kernels: filters one BLOCK_SIZE^3 block of voxels along _Axis.
//
// Stage 1: the whole thread group copies a TILE_SIZE^3 tile of _PingData (the block plus
//          TILE_BORDER voxels on every side) into g_samples, clamping reads at the grid edge.
// Stage 2: each thread sums 2 * KERNEL_WIDTH + 1 cached values along _Axis, weighted by
//          _GaussianNormalization * exp(_GaussianExponential * offset^2), and writes the
//          result to _PongData.
//
// GroupId       - thread group id (SV_GroupID), selects which block of the grid is processed.
// GroupThreadId - thread id inside the group (SV_GroupThreadID), selects the voxel in the block.
// GroupIndex    - flattened thread id inside the group (SV_GroupIndex), used to spread the tile load.
void GaussianBlur(uint3 GroupId, uint3 GroupThreadId, uint GroupIndex)
{
    const uint3 tileDims = uint3(TILE_SIZE, TILE_SIZE, TILE_SIZE);
    const uint3 dataDims = (uint3) _DataSize;

    // Grid coordinate of the first cached voxel; can be negative for groups on the low edge.
    const int3 tileOrigin = int3(GroupId * uint3(BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)) - TILE_BORDER;

    // The tile holds more voxels than the group has threads, so every thread loads several
    // voxels, BLOCK_THREAD_COUNT apart. The last batch is only partially populated.
    const uint batchCount = (TILE_VOXEL_COUNT + BLOCK_THREAD_COUNT - 1) / BLOCK_THREAD_COUNT;

    [unroll]
    for (uint batch = 0; batch < batchCount; ++batch)
    {
        const uint slot = GroupIndex + batch * BLOCK_THREAD_COUNT;
        if (slot < TILE_VOXEL_COUNT)
        {
            const int3 sourceVoxel = tileOrigin + (int3) toCoord3D(slot, tileDims);
            // Clamped lookup keeps border reads that fall outside the grid valid.
            g_samples[slot] = _PingData[toIndex3DClamp(sourceVoxel, dataDims)];
        }
    }

    // Every thread must see the complete tile before any thread starts filtering.
    GroupMemoryBarrierWithGroupSync();

    // Position of this thread's voxel inside the cached tile and inside the full grid.
    const int3 localCoord = int3(GroupThreadId) + TILE_BORDER;
    const int3 gridCoord = tileOrigin + localCoord;

    // Threads that landed outside the grid have nothing to write.
    const bool outsideGrid = any(gridCoord < 0) || any(gridCoord >= _DataSize);
    if (outsideGrid)
    {
        return;
    }

    const int3 axisStep = (int3) _Axis;
    float accum = 0.0;

    [unroll]
    for (int tap = -KERNEL_WIDTH; tap <= KERNEL_WIDTH; ++tap)
    {
        const float weight = _GaussianNormalization * exp(_GaussianExponential * float(tap * tap));
        const int3 tapCoord = localCoord + axisStep * tap;
        accum += g_samples[toIndex3D(tapCoord, tileDims)] * weight;
    }

    // When the filtered value exceeds Invalid_Sdf_compare (i.e. invalid voxels took part in
    // the sum), pass the unfiltered center value through instead of the blurred one.
    float result = accum;
    if (accum > Invalid_Sdf_compare)
    {
        result = g_samples[toIndex3D(localCoord, tileDims)];
    }

    _PongData[toIndex3D(gridCoord, dataDims)] = result;
}

// Kernel entry point compiled with TILE_BORDER=1 (see its #pragma kernel line), so the
// filter uses 3 taps along _Axis: the center voxel plus one neighbor on each side.
// Runs BLOCK_SIZE^3 threads per group and forwards the system-value ids to GaussianBlur.
// NOTE(review): each dispatch filters along a single axis; the "3x3x3" name suggests the
// host dispatches it once per axis -- confirm against the calling code.
[numthreads(BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)]
void GaussianBlur3x3x3(uint3 GroupId : SV_GroupID, uint3 GroupThreadId : SV_GroupThreadID, uint GroupIndex : SV_GroupIndex)
{    
    GaussianBlur(GroupId, GroupThreadId, GroupIndex);
}

// Kernel entry point compiled with TILE_BORDER=2 (see its #pragma kernel line), so the
// filter uses 5 taps along _Axis: the center voxel plus two neighbors on each side.
// Runs BLOCK_SIZE^3 threads per group and forwards the system-value ids to GaussianBlur.
// NOTE(review): each dispatch filters along a single axis; the "5x5x5" name suggests the
// host dispatches it once per axis -- confirm against the calling code.
[numthreads(BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)]
void GaussianBlur5x5x5(uint3 GroupId : SV_GroupID, uint3 GroupThreadId : SV_GroupThreadID, uint GroupIndex : SV_GroupIndex)
{
    GaussianBlur(GroupId, GroupThreadId, GroupIndex);
}

// Kernel entry point compiled with TILE_BORDER=3 (see its #pragma kernel line), so the
// filter uses 7 taps along _Axis: the center voxel plus three neighbors on each side.
// Runs BLOCK_SIZE^3 threads per group and forwards the system-value ids to GaussianBlur.
// NOTE(review): each dispatch filters along a single axis; the "7x7x7" name suggests the
// host dispatches it once per axis -- confirm against the calling code.
[numthreads(BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)]
void GaussianBlur7x7x7(uint3 GroupId : SV_GroupID, uint3 GroupThreadId : SV_GroupThreadID, uint GroupIndex : SV_GroupIndex)
{
    GaussianBlur(GroupId, GroupThreadId, GroupIndex);
}

// Kernel entry point compiled with TILE_BORDER=4 (see its #pragma kernel line), so the
// filter uses 9 taps along _Axis: the center voxel plus four neighbors on each side.
// Runs BLOCK_SIZE^3 threads per group and forwards the system-value ids to GaussianBlur.
// NOTE(review): each dispatch filters along a single axis; the "9x9x9" name suggests the
// host dispatches it once per axis -- confirm against the calling code.
[numthreads(BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE)]
void GaussianBlur9x9x9(uint3 GroupId : SV_GroupID, uint3 GroupThreadId : SV_GroupThreadID, uint GroupIndex : SV_GroupIndex)
{
    GaussianBlur(GroupId, GroupThreadId, GroupIndex);
}