#pragma once

#include "common.inc"
#include "hairShadow.inc"


// Alpha threshold constant. NOTE(review): not referenced by any code visible
// here; presumably consumed by a shader that includes this file — confirm its
// meaning there.
#define SHORTCUT_MIN_ALPHA 0.02
// Approximation of pi; used below to convert degrees to radians
// (KajiyaShading's default cone angle is 10 * PI / 180).
#define PI            3.1415926

// Material inputs consumed by HairShading.
struct HairShadeParams
{
    float3 Color;   // hair albedo; tints the diffuse term and the secondary specular lobe
    float Ka;       // NOTE(review): not read by HairShading; presumably an ambient coefficient — confirm
    float Kd;       // diffuse scale
    float Ks1;      // primary specular scale
    float Ex1;      // primary specular exponent
    float Ks2;      // secondary specular scale
    float Ex2;      // secondary specular exponent
};

// Per-fragment description of the single active light, as returned by GetLightParams.
struct LightParams
{
    float3 lightColor;      // LIGHT_COLOR.xyz (the light's brightness is folded in here); zero when there is no light
	float3 lightDir;        // point/spot: unit vector from the shaded point toward LIGHT_POSITION;
	                        // directional: normalize(-LIGHT_GIVEN_DIRECTION); zero when there is no light
	float lightIntensity;   // attenuation factor: distance falloff (and cone falloff for spot lights),
	                        // 1 for directional lights, 0 when there is no light
};

// Hair material parameters. Fields are grouped into float4-sized (16-byte)
// rows, as the group comments note — presumably to match a constant-buffer
// layout on the application side; confirm before reordering or resizing.
// Fields marked "not read here" are not used by any function visible in this
// file.
struct NvHair_Material 
{
	// 3 float4
	float4			rootColor;      // tint multiplied into the root color texture sample
	float4			tipColor;       // tint multiplied into the tip color texture sample
	float4			specularColor;  // specular tint (.rgb) passed to NvHair_ComputeHairShading

	// 4 floats (= 1 float4)
	float			diffuseBlend;             // lerp weight: 0 = tangent-based strand diffuse, 1 = normal-based diffuse
	float			diffuseScale;             // not read here
	float			diffuseHairNormalWeight;  // not read here
	float			_diffuseUnused_; // for alignment and future use

	// 4 floats (= 1 float4)
	float			specularPrimaryScale;     // weight of the primary specular lobe
	float			specularPrimaryPower;     // exponent of the primary specular lobe
	float			specularPrimaryBreakup;   // not read here
	float			specularNoiseScale;       // not read here

	// 4 floats (= 1 float4)
	float			specularSecondaryScale;   // weight of the secondary specular lobe
	float			specularSecondaryPower;   // exponent of the secondary specular lobe
	float			specularSecondaryOffset;  // added to dot(T, H) before evaluating the secondary lobe
	float			_specularUnused_; // for alignment and future use

	// 4 floats (= 1 float4)
	float			rootTipColorWeight;   // root/tip bias used by NvHair_GetRootTipRatio
	float			rootTipColorFalloff;  // root/tip ramp falloff used by NvHair_GetRootTipRatio
	float			shadowSigma;          // not read here
	float			strandBlendScale;     // not read here

	// 4 floats (= 1 float4)
	float			glintStrength;     // not read here (the shading in this file is glint-free)
	float			glintCount;        // not read here
	float			glintExponent;     // not read here
	float			rootAlphaFalloff;  // not read here
};

// Per-fragment hair attributes. The NvHair_ComputeHairShading overload that
// takes this struct reads V, N and T from it.
struct NvHair_ShaderAttributes
{
	float3	P;			// world coord position
	float3	T;			// world space tangent vector
	float3	N;			// world space normal vector at the root
	float4	texcoords; // texture coordinates on hair root 
						// .xy: texcoord on the hair root
						// .z: texcoord along the hair
						// .w: texcoord along the hair quad
	float3	V;			// world space view vector
	float	hairID;		// unique hair identifier
  float4 TexCoord_Fraction; // NOTE(review): purpose not evident here; possibly what NvHair_SampleHairColorTex receives as `tcfp` — confirm
};

//--------------------------------------------------------------------------------------
// Unproject a point from normalized device coordinates back to world space.
//   vNDC : point in NDC (w is taken as 1).
// Returns the world-space position after the perspective divide by w.
float3 NDCToWorld(float3 vNDC)
{
    const float4 worldHomogeneous = mul(float4(vNDC, 1.f), CAMERA_VIEWPROJ_INV);
    const float3 worldScaled = worldHomogeneous.xyz;
    return worldScaled / worldHomogeneous.w;
}
//--------------------------------------------------------------------------------------
// Convert a screen-space position to normalized device coordinates.
//   vScreenPos.xy : pixel position. 0.5 is subtracted, the result is divided
//                   by CAMERA_RESOLUTION and remapped from [0, 1] to [-1, 1].
//                   y is flipped: the top of the screen maps to NDC +1.
//   vScreenPos.z  : depth in [0, 1], remapped to [-1, 1].
// NOTE(review): if xy are pixel-center positions (x + 0.5), the -0.5 shift
// yields the pixel's corner rather than its center — confirm this is intended.
float3 ScreenPosToNDC(float3 vScreenPos)
{
    // Normalize to roughly [0, 1] across the render target.
    float2 uv = (vScreenPos.xy - 0.5f) / CAMERA_RESOLUTION;

    // Remap to [-1, 1]; screen y grows downward while NDC y grows upward.
    float2 ndcXY;
    ndcXY.x = uv.x * 2 - 1;
    ndcXY.y = 1 - uv.y * 2;

    float ndcZ = vScreenPos.z * 2 - 1.;
    return float3(ndcXY, ndcZ);
}

//--------------------------------------------------------------------------------------
// Estimate how much of the current pixel a hair strand covers.
//   p0, p1   : NDC positions of the two points bounding the strand
//              (NOTE(review): assumed to be the strand's two edges — confirm with caller)
//   pixelLoc : NDC position of the pixel being shaded
// All three are converted to pixel units first, and hairWidth = |p0 - p1|.
// The pixel counts as "outside" when it is at least hairWidth away from
// either point.
// Returns:
//   outside: 0.75 at distance 0 from the nearer point, falling to 0.25 at >= 1 pixel
//   inside : 0.5 at distance 0, rising to 0.5 + 0.5 / max(hairWidth / 5, 1) at >= 1 pixel
float ComputeCoverage(float2 p0, float2 p1, float2 pixelLoc)
{
    // NDC [-1, 1] -> pixels [0, CAMERA_RESOLUTION]. This must be a plain
    // assignment: the remapped value replaces the NDC value. (No y flip is
    // needed because only distances between these points are used.)
    p0 = (p0 + 1) * 0.5 * CAMERA_RESOLUTION;
    p1 = (p1 + 1) * 0.5 * CAMERA_RESOLUTION;
    pixelLoc = (pixelLoc + 1) * 0.5 * CAMERA_RESOLUTION;

    float p0dist = length(p0 - pixelLoc);
    float p1dist = length(p1 - pixelLoc);
    float hairWidth = length(p0 - p1);

    // 1.f if the pixel is outside the hair, 0.f if inside.
    float outside = any(float2(step(hairWidth, p0dist), step(hairWidth, p1dist)));
    // Named so it does not shadow the HLSL sign() intrinsic.
    float insideSign = outside > 0.f ? -1.f : 1.f;

    // Signed distance (in pixels, clamped to 1) to the nearer point:
    // positive inside the hair, negative outside.
    float relDist = insideSign * saturate(min(p0dist, p1dist));

    float width = max(hairWidth / 5, 1.0);
    return outside > 0.f ? (relDist + 1.5f) * 0.5 : 0.5 + 0.5 * relDist / width;
}

//--------------------------------------------------------------------------------------
// Kajiya-Kay hair lighting terms with cone-shifted specular lobes.
//
// Hair has a tiled-conical shape along its length, sort of like the following:
//
// \    /
//  \  /
// \    /
//  \  /
//
// so the specular reflection is shifted by the cone angle: the primary lobe
// toward the root (2 * coneAngleRadians) and the secondary lobe toward the
// tip (3 * coneAngleRadians).
//   vEyeDir          : unit direction toward the eye (as HairShading passes it)
//   vLightDir        : unit direction toward the light
//   vTangentDir      : unit hair tangent
//   coneAngleRadians : cone angle in radians, typically 5 to 10 degrees
// Returns float3(diffuse, primary specular base, secondary specular base),
// each finite and >= 0; callers raise the specular terms to their exponents.
float3 KajiyaShading(float3 vEyeDir, float3 vLightDir, float3 vTangentDir, float coneAngleRadians = 10 * PI / 180)
{
    // dot() of unit vectors can land just outside [-1, 1] through rounding.
    // Clamp so the sqrt(1 - c*c) terms never see a negative argument, which
    // would yield NaN and poison the whole pixel.
    float cosTL = clamp(dot(vTangentDir, vLightDir), -1.0f, 1.0f);
    float sinTL = sqrt(saturate(1 - cosTL * cosTL));
    // Kajiya diffuse term: sin(T, L), non-negative by construction.
    float diffuse = sinTL;

    // Reflected direction: the tangent angle becomes (pi - angle(T, L)).
    float cosTRL = -cosTL;
    float sinTRL = sinTL;
    float cosTE = clamp(dot(vTangentDir, vEyeDir), -1.0f, 1.0f);
    float sinTE = sqrt(saturate(1 - cosTE * cosTE));

	// Primary highlight: reflected direction shifted toward the root by
	// 2 * coneAngleRadians (angle-sum identity for cosine).
    float cosTRL_root = cosTRL * cos(2 * coneAngleRadians) - sinTRL * sin(2 * coneAngleRadians);
    float sinTRL_root = sqrt(saturate(1 - cosTRL_root * cosTRL_root));
    // cos of the angle between the shifted reflection and the eye direction.
    float specular_root = max(0, cosTRL_root * cosTE + sinTRL_root * sinTE);

	// Secondary highlight: reflected direction shifted toward the tip by
	// 3 * coneAngleRadians.
    float cosTRL_tip = cosTRL * cos(-3 * coneAngleRadians) - sinTRL * sin(-3 * coneAngleRadians);
    float sinTRL_tip = sqrt(saturate(1 - cosTRL_tip * cosTRL_tip));
    float specular_tip = max(0, cosTRL_tip * cosTE + sinTRL_tip * sinTE);

	return float3(diffuse, specular_root, specular_tip);
}


// GetLightParams(WorldPosition): evaluates the single light type selected at
// compile time by the PointLight / DirLight / SpotLight / NoLight macros.
// Returns the light color, the unit direction toward the light, and an
// attenuation factor in lightIntensity. The light's brightness itself is
// carried in LIGHT_COLOR.
#if PointLight
LightParams GetLightParams(float3 WorldPosition)
{
	LightParams pointlight;
	pointlight.lightColor = LIGHT_COLOR.xyz;

	float3 toLight = LIGHT_POSITION.xyz - WorldPosition.xyz;
	float dis = length(toLight);
	// Distance normalized to the light's range, clamped to [0, 1].
	float disAtten = clamp(dis * LIGHT_RANGE_INV, 0.0, 1.0);
	// Reaches 0 at the range limit; LIGHT_ATTENUATION.xyz are the constant,
	// linear and quadratic coefficients of the denominator.
	float attenuation = (1.0 - disAtten) /  ( LIGHT_ATTENUATION.x + disAtten * LIGHT_ATTENUATION.y + disAtten * disAtten * LIGHT_ATTENUATION.z );

	pointlight.lightDir = normalize(toLight);
	pointlight.lightIntensity = attenuation;

	return pointlight;
}
#elif DirLight
LightParams GetLightParams(float3 WorldPosition)
{
	LightParams Dirlight;
	Dirlight.lightColor = LIGHT_COLOR.xyz;
	Dirlight.lightDir = normalize(-LIGHT_GIVEN_DIRECTION.xyz);
	// No distance falloff for a directional light.
	Dirlight.lightIntensity = 1.0;

	return Dirlight;
}
#elif SpotLight
LightParams GetLightParams(float3 WorldPosition)
{
	LightParams spotlight;

	float3 toLight = LIGHT_POSITION.xyz - WorldPosition.xyz;
	float dis = length(toLight);
	// Same distance falloff as the point light.
	float disAtten = clamp(dis * LIGHT_RANGE_INV, 0.0, 1.0);
	float attenuation = (1.0 - disAtten) /  ( LIGHT_ATTENUATION.x + disAtten * LIGHT_ATTENUATION.y + disAtten * disAtten * LIGHT_ATTENUATION.z );

	float3 lightDir = normalize(toLight);
	// Normalize the spot axis, as the DirLight path does, so the dot product
	// below is a true cosine even if LIGHT_GIVEN_DIRECTION is not unit length.
	float3 spotAxis = normalize(-LIGHT_GIVEN_DIRECTION.xyz);
	// Angular falloff. NOTE(review): LIGHT_INNER_DIFF_INV.x looks like a cone
	// cosine threshold and .y an inverse falloff width — confirm with the engine.
	float attenAngle = clamp( 1.0 - ( LIGHT_INNER_DIFF_INV.x - dot(lightDir, spotAxis) ) * LIGHT_INNER_DIFF_INV.y, 0.0, 1.0 );
	attenuation *= attenAngle;

	spotlight.lightColor = LIGHT_COLOR.xyz;
	spotlight.lightDir = lightDir;
	spotlight.lightIntensity = attenuation;

	return spotlight;
}
#else
// NoLight, or no light macro defined: contributes nothing.
LightParams GetLightParams(float3 WorldPosition)
{
	LightParams nolight;
	nolight.lightColor = float3(0.0, 0.0, 0.0);
	nolight.lightDir = float3(0.0, 0.0, 0.0);
	nolight.lightIntensity = 0.0;

	return nolight;
}
#endif



// Kajiya-Kay style hair shading for the light returned by GetLightParams.
//   WorldPos : world-space position of the shaded point
//   vNDC     : not used by this function
//   Tangent  : hair tangent (normalized here)
//   params   : Kd scales the diffuse term (tinted by Color); Ks1/Ex1 shape the
//              untinted primary lobe; Ks2/Ex2 shape the secondary lobe (tinted
//              by Color). Ka is not read.
// Returns black when the light's intensity is not positive; otherwise the
// attenuated sum of the three terms, clamped to be non-negative.
float3 HairShading(float3 WorldPos, float3 vNDC, float3 Tangent, HairShadeParams params)
{
	LightParams light = GetLightParams(WorldPos.xyz);

	// Absent or fully attenuated light: nothing to add.
	if (!(light.lightIntensity > 0.f))
	{
		return float3(0.0, 0.0, 0.0);
	}

	float3 toEye = normalize(CAMERA_WORLDPOSITION - WorldPos.xyz);
	float3 strandDir = normalize(Tangent.xyz);

	// x = diffuse, y = root-shifted lobe base, z = tip-shifted lobe base.
	float3 kk = KajiyaShading(toEye, light.lightDir, strandDir);

	float3 diffuseTerm   = params.Kd * kk.x * light.lightColor * params.Color;
	float3 primaryTerm   = params.Ks1 * pow(kk.y, params.Ex1) * light.lightColor;
	float3 secondaryTerm = params.Ks2 * pow(kk.z, params.Ex2) * light.lightColor * params.Color;

	float3 lit = (diffuseTerm + primaryTerm + secondaryTerm) * light.lightIntensity;
	return max(lit, float3(0, 0, 0));
}


// Returns the 8x8 Bayer ordered-dither threshold for a screen position, as a
// value in (0, 1]: (table entry + 1) / 64.
//   ssPos : screen-space pixel position; only its value modulo 8 is used.
//           Assumes ssPos is non-negative (e.g. SV_Position); negative values
//           would produce negative table indices.
// NOTE(review): the table is indexed [x][y], i.e. transposed relative to the
// usual [row = y][col = x] layout; the transpose is still a valid ordered-dither
// pattern.
float FetchBayerDithering(float2 ssPos)
{
	// static const: compile-time table data rather than a local array that is
	// rebuilt on every call.
	static const int dither[8][8] = {
		{ 0, 32,  8, 40,  2, 34, 10, 42},
		{48, 16, 56, 24, 50, 18, 58, 26},
		{12, 44,  4, 36, 14, 46,  6, 38},
		{60, 28, 52, 20, 62, 30, 54, 22},
		{ 3, 35, 11, 43,  1, 33,  9, 41},
		{51, 19, 59, 27, 49, 17, 57, 25},
		{15, 47,  7, 39, 13, 45,  5, 37},
		{63, 31, 55, 23, 61, 29, 53, 21} };

	int x = (int)fmod(ssPos.x, 8.);
	int y = (int)fmod(ssPos.y, 8.);

	return float( (dither[x][y] + 1) / 64.0 );
}

// Builds an NvHair_Material from the given root/tip colors and specular lobe
// settings. Every other field receives a fixed value; the defaults were taken
// from media\HumanSamples\Female\Eve\gm_main.apx.
NvHair_Material NvHair_SetMaterial(float4 rootColor, float4 tipColor, float specularPrimaryScale, float specularPrimaryPower, float specularSecondaryScale, float specularSecondaryPower)
{
  NvHair_Material m;

  // Caller-supplied values.
  m.rootColor = rootColor;
  m.tipColor = tipColor;
  m.specularPrimaryScale = specularPrimaryScale;
  m.specularPrimaryPower = specularPrimaryPower;
  m.specularSecondaryScale = specularSecondaryScale;
  m.specularSecondaryPower = specularSecondaryPower;

  // Fixed non-zero defaults.
  m.specularColor = float4(1.0f, 1.0f, 1.0f, 1.0f);
  m.diffuseBlend = 0.7f;
  m.diffuseScale = 1.0f;
  m.diffuseHairNormalWeight = 0.55f;
  m.specularPrimaryBreakup = 0.25f;
  m.specularSecondaryOffset = 0.14f;
  m.rootTipColorWeight = 0.73f;
  m.rootTipColorFalloff = 0.62f;
  m.rootAlphaFalloff = 0.04f;

  // Fields set to zero.
  m._diffuseUnused_ = 0.0f;
  m._specularUnused_ = 0.0f;
  m.specularNoiseScale = 0.0f;
  m.shadowSigma = 0.0f;
  m.strandBlendScale = 0.0f;
  m.glintStrength = 0.0f;
  m.glintCount = 0.0f;
  m.glintExponent = 0.0f;

  return m;
}

// Computes the root->tip color blend factor (0 = root color, 1 = tip color)
// for a position along the strand.
//   s   : position along the hair (NOTE(review): assumed to lie in [0, 1] — confirm)
//   mat : supplies rootTipColorWeight and rootTipColorFalloff
// Step 1: piecewise-linear bias. A weight below 0.5 scales s toward 0 (more
//         root color). A weight of 0.5 or more pivots about 1 (more tip color).
// Step 2: steepen the ramp about 0.5 by 1 / (falloff + 0.001), then clamp to [0, 1].
float NvHair_GetRootTipRatio(const float s, NvHair_Material mat)
{
	const float weight = mat.rootTipColorWeight;

	const float biased = (weight < 0.5f)
		? (2.0f * weight) * s
		: (2.0f * (1.0f - weight)) * (s - 1.0f) + 1.0f;

	// The small epsilon keeps the division finite when the falloff is 0.
	const float contrast = 1.0f / (mat.rootTipColorFalloff + 0.001f);
	return saturate(0.5f + contrast * (biased - 0.5f));
}

// Returns the hair color for this fragment. The root and tip textures are both
// sampled at tcfp.xy and tinted by mat.rootColor / mat.tipColor. Fragments
// with tcfp.z > 1 - tcfp.w blend toward the tip color by pow(tcfp.z, index);
// all others use the root color.
//   tcfp  : .xy texcoord used for both textures, .z position along the hair,
//           .w size of the blended region (NOTE(review): inferred from the
//           1 - .w threshold — confirm)
//   index : exponent shaping the root->tip blend
float3 NvHair_SampleHairColorTex(
	NvHair_Material			mat,
	SamplerState			rootTexSampler,
	SamplerState			tipTexSampler,
	Texture2D				rootColorTex,
	Texture2D				tipColorTex,
	float4					tcfp,
	float					index)
{
	const float2 uv = tcfp.xy;

	// Sample both textures up front, outside any branch.
	const float3 rootTinted = rootColorTex.Sample(rootTexSampler, uv).rgb * mat.rootColor.rgb;
	const float3 tipTinted = tipColorTex.Sample(tipTexSampler, uv).rgb * mat.tipColor.rgb;

	// At or below this point along the strand the pure root color is kept.
	const float blendStart = 1.f - tcfp.w;
	if (tcfp.z > blendStart)
	{
		return lerp(rootTinted, tipTinted, pow(tcfp.z, index));
	}
	return rootTinted;
}

// NVHair-style lighting (no glint) for the single light from GetLightParams.
// Diffuse : lerp between the strand term sin(T, L) and the normal term
//           max(0, N.L), weighted by diffuseBlend.
// Specular: two lobes of sin(T, H)^shininess; the second lobe uses dot(T, H)
//           shifted by secondaryOffset.
// Parameters:
//   vPositionWS       world-space position of the shaded point (used to evaluate the light)
//   V                 view vector (NOTE(review): assumed unit length, pointing toward the eye)
//   N                 surface normal for the normal-based diffuse term
//   T                 hair tangent (assumed unit length; dot products are clamped to [-1, 1])
//   diffuseColor      diffuse albedo
//   specularColor     specular tint
//   diffuseBlend      0 = pure strand diffuse, 1 = pure normal-based diffuse
//   primaryScale / primaryShininess      weight and exponent of the primary lobe
//   secondaryScale / secondaryShininess  weight and exponent of the secondary lobe
//   secondaryOffset   added to dot(T, H) before evaluating the secondary lobe
// Returns the lit color scaled by the light's attenuation (lightIntensity).
float3 NvHair_ComputeHairShading(

  float3 vPositionWS,
	float3 V, // view vector
	float3 N, // surface normal
	float3 T, // hair tangent

	float3 diffuseColor, // diffuse albedo
	float3 specularColor, // specularity

	float diffuseBlend,
	float primaryScale,
	float primaryShininess,
	float secondaryScale,
	float secondaryShininess,
	float secondaryOffset
  )
{
  LightParams lightParams = GetLightParams(vPositionWS);

  // diffuse hair shading
  float TdotL = clamp(dot(T, lightParams.lightDir), -1.0f, 1.0f);
  float diffuseSkin = max(0, dot(N, lightParams.lightDir));
  float diffuseHair = sqrt(1.0f - TdotL*TdotL);
  float diffuseSum = lerp(diffuseHair, diffuseSkin, diffuseBlend);

  // primary specular
  float3 H = normalize(V + lightParams.lightDir);
  float TdotH = clamp(dot(T, H), -1.0f, 1.0f);
  float specPrimary = sqrt(1.0f - TdotH*TdotH);
  specPrimary = pow(max(0, specPrimary), primaryShininess);

  // secondary specular: same lobe shape with dot(T, H) shifted by secondaryOffset
  float TdotHShifted = clamp(TdotH + secondaryOffset, -1.0, 1.0);
  float specSecondary = sqrt(1 - TdotHShifted*TdotHShifted);
  specSecondary = pow(max(0, specSecondary), secondaryShininess);

  // specular sum
  float specularSum = primaryScale * specPrimary + secondaryScale * specSecondary;
  // NOTE(review): ambient is hard-coded to zero.
  float3 ambient = float3(0.0, 0.0, 0.0);

  // Distance and spot-cone attenuation live in lightIntensity. Without this
  // factor, point and spot lights would light hair at full strength at any
  // distance and outside the cone. HairShading applies the same factor.
  float3 lightRadiance = lightParams.lightColor * lightParams.lightIntensity;

	float3 output = (ambient + diffuseSum * lightRadiance) * diffuseColor + specularSum * (lightRadiance * specularColor);

	return output;
}

// Convenience overload: takes the view, normal and tangent vectors from attr
// and the diffuse/specular settings from mat. hairColor is used as the
// diffuse albedo.
float3 NvHair_ComputeHairShading(float3 vPositionWS, NvHair_ShaderAttributes attr, NvHair_Material mat, float3 hairColor)
{
	const float3 specularTint = mat.specularColor.rgb;

	const float3 shaded = NvHair_ComputeHairShading(
		vPositionWS,
		attr.V, attr.N, attr.T,
		hairColor, specularTint,
		mat.diffuseBlend,
		mat.specularPrimaryScale, mat.specularPrimaryPower,
		mat.specularSecondaryScale, mat.specularSecondaryPower,
		mat.specularSecondaryOffset);

	return shaded;
}
