CGPROGRAM
#define NUM_THREADS 8
#include "fluid_simulation.inc"

// Vector along which the extra velocity term is added in SampleBilinearDTV
// (presumably the "up"/buoyancy direction -- confirm against the host code).
float3 _UpVel;
// _AmbientT: threshold compared against the x channel of the advected float2 field;
//            the extra velocity term is only applied above it.
// _Sigma:    scale applied to _DeltaTime * (field.x - _AmbientT).
// _Kappa:    scale applied to field.y, subtracted from the term above.
float _AmbientT, _Sigma, _Kappa;

// Time step used for the back-trace in Advect and for the extra velocity term.
float _DeltaTime;
// Cells closer than this to _Pos get _Amount blended into the float2 field.
float _Radius;

// Multiplier applied to the interpolated velocity.
float _DissipateV;

// Per-channel multiplier applied to the interpolated float2 field.
float2 _Dissipate;
// Value the float2 field is blended towards inside _Radius of _Pos.
float2 _Amount;
// Centre of the blend region; compared with the output of GridCord2WorldPos
// (presumably a world-space position -- confirm).
float3 _Pos;

// Per-cell obstacle data; Advect zeroes the outputs of cells whose .x exceeds 0.1.
RWStructuredBuffer<float4> _Obstacles;

// Output / input for the float2 field
// (presumably x = temperature, y = density -- confirm).
RWStructuredBuffer<float2> _Write2f;
RWStructuredBuffer<float2> _Read2f;

// Output / input for the velocity field; only .xyz is read, .w is written as 0.
RWStructuredBuffer<float4> _Write3f;
RWStructuredBuffer<float4> _Read3f;

// Samples the float2 field (_Read2f) and the velocity field (_Read3f) at the
// fractional grid position `uv` using trilinear interpolation, applies
// dissipation and the _UpVel velocity term, and writes both results for cell
// `idx`. Finally blends _Amount into the float2 field for cells that lie
// within _Radius of _Pos.
//
//   uv   - sample position in grid-cell units; clamped to the grid here.
//   size - xyz: grid dimensions, w: multiplier applied to z in the linear
//          cell index (x has stride 1, y has stride size.x).
//   idx  - linear index of the cell being written.
//   id   - integer grid coordinate of the cell being written.
void SampleBilinearDTV(float3 uv, float4 size, int idx, int3 id)
{
	// A back-traced position may fall outside the grid. Without this clamp the
	// truncated base index can be negative or past the last cell, which makes
	// the linear index leave the buffer or wrap into a neighbouring row/slice,
	// and frac() of a negative coordinate yields wrong weights.
	uv = clamp(uv, float3(0.0f, 0.0f, 0.0f), size.xyz - 1.0f);

	// Lower corner of the enclosing cell (truncation equals floor since uv >= 0)
	// and the upper corner, kept inside the grid.
	int3 lo = (int3)uv;
	int3 hi = (int3)min(size.xyz - 1.0f, uv + 1.0f);

	// Interpolation weights towards the upper corner on each axis.
	float3 t = frac(uv);

	int strideY = (int)size.x;
	int strideZ = (int)size.w;

	int rowLo = lo.y * strideY;
	int rowHi = hi.y * strideY;
	int slabLo = lo.z * strideZ;
	int slabHi = hi.z * strideZ;

	// Linear indices of the eight surrounding cells, named c<x><y><z>
	// where 0 = lower and 1 = upper corner on that axis.
	int c000 = lo.x + rowLo + slabLo;
	int c100 = hi.x + rowLo + slabLo;
	int c001 = lo.x + rowLo + slabHi;
	int c101 = hi.x + rowLo + slabHi;
	int c010 = lo.x + rowHi + slabLo;
	int c110 = hi.x + rowHi + slabLo;
	int c011 = lo.x + rowHi + slabHi;
	int c111 = hi.x + rowHi + slabHi;

	// --- float2 field: interpolate along x, then z, then y ---
	float2 d00 = lerp(_Read2f[c000], _Read2f[c100], t.x);
	float2 d01 = lerp(_Read2f[c001], _Read2f[c101], t.x);
	float2 d10 = lerp(_Read2f[c010], _Read2f[c110], t.x);
	float2 d11 = lerp(_Read2f[c011], _Read2f[c111], t.x);

	float2 dLoY = lerp(d00, d01, t.z);
	float2 dHiY = lerp(d10, d11, t.z);

	// Dissipate and keep the result non-negative.
	float2 resd = max(float2(0, 0), lerp(dLoY, dHiY, t.y) * _Dissipate);

	// --- velocity field: same interpolation order ---
	float3 v00 = lerp(_Read3f[c000].xyz, _Read3f[c100].xyz, t.x);
	float3 v01 = lerp(_Read3f[c001].xyz, _Read3f[c101].xyz, t.x);
	float3 v10 = lerp(_Read3f[c010].xyz, _Read3f[c110].xyz, t.x);
	float3 v11 = lerp(_Read3f[c011].xyz, _Read3f[c111].xyz, t.x);

	float3 vLoY = lerp(v00, v01, t.z);
	float3 vHiY = lerp(v10, v11, t.z);

	float3 resv = lerp(vLoY, vHiY, t.y) * _DissipateV;

	// Extra velocity along _UpVel, only where the first channel exceeds the
	// ambient threshold (presumably a buoyancy force -- confirm).
	if (resd.x > _AmbientT)
	{
		resv += (_DeltaTime * (resd.x - _AmbientT) * _Sigma - resd.y * _Kappa) * _UpVel;
	}

	_Write3f[idx] = float4(resv, 0.0);

	// Blend _Amount into the float2 field near _Pos. The blend factor grows
	// linearly with (_Radius - d) * 0.5, is capped at 1, and is 0 at or
	// beyond _Radius.
	float3 wpos = GridCord2WorldPos(float3(id));
	float d = distance(_Pos, wpos);
	float impulse = saturate((_Radius - d) * 0.5);

	_Write2f[idx] = max(float2(0, 0), lerp(resd, _Amount, float2(impulse, impulse)));
}

#pragma compute Advect

// Advection kernel: one thread per grid cell.
// Cells flagged in _Obstacles have both output fields zeroed. Every other cell
// traces its position backwards along its own velocity and lets
// SampleBilinearDTV sample the fields there and write the results.
[numthreads(NUM_THREADS,NUM_THREADS,NUM_THREADS)]
void Advect(int3 id : SV_DispatchThreadID)
{
	// Threads dispatched beyond the grid (dimensions that are not a multiple of
	// NUM_THREADS) would compute an index that aliases another cell, so they
	// must not write anything.
	if (any(id >= (int3)_Size.xyz))
	{
		return;
	}

	// Linear cell index: x has stride 1, y has stride _Size.x, z has stride
	// _Size.w. Integer arithmetic keeps the index exact for large grids, where
	// a float dot product would start rounding above 2^24.
	int idx = id.x + id.y * (int)_Size.x + id.z * (int)_Size.w;

	if (_Obstacles[idx].x > 0.1f)
	{
		_Write2f[idx] = float2(0.0f, 0.0f);
		_Write3f[idx] = float4(0.0f, 0.0f, 0.0f, 0.0);
		return;
	}

	// Back-trace: step from this cell against its velocity, in grid-cell units.
	float3 uv = float3(id) - _DeltaTime * _Forward * _Read3f[idx].xyz;

	SampleBilinearDTV(uv, _Size, idx, id);
}

ENDCG

