#pragma once

#include "common.inc"


// Length threshold used by safe_normalize(): vectors shorter than this are
// replaced by the zero vector instead of being normalized.
#define FLOAT_EPSILON 1e-7


// Per-vertex hair data. The expansion functions in this file read element
// (vertexId / 2) and pass .xyz to WorldToClipPos as the vertex position.
StructuredBuffer<float4> g_HairVertexPositions;
// Per-vertex strand tangent in .xyz, indexed the same way as g_HairVertexPositions.
StructuredBuffer<float4> g_HairVertexTangents;
// One UV per strand, indexed by (vertex index / NumVerticesPerStrand); used to sample the base albedo texture.
StructuredBuffer<float2> g_HairStrandUV;


// Material and geometry parameters consumed by the hair expansion functions in this file.
struct HairStrandParams
{
    // .rgb is multiplied into the albedo sampled at the strand UV (root color).
    float4        MatBaseColor;
    // .rgb is the color blended in towards the end of the strand.
    float4        MatTipColor;

    // Fraction of the strand that counts as "tip": blending towards MatTipColor
    // only happens where fractionOfStrand > 1 - TipPercentage.
    float         TipPercentage;
    // Thickness multiplier reached at the last vertex when EnableThinTip is set
    // (thickness is lerp(1, FiberRatio, fractionOfStrand)).
    float         FiberRatio;
    // Distance each quad edge is offset from the strand centre along the "right" vector.
    float         FiberRadius;

    // Number of consecutive buffer entries that make up one strand.
    int           NumVerticesPerStrand;
    // Values > 0 enable thinning of the strand towards its tip.
    int           EnableThinTip;
};

// Output of the hair vertex expansion functions in this file.
// ExpendHairShadowVertex zero-initializes the struct and fills in Position only.
struct HairVertex
{
    // Clip-space position of the expanded quad corner (result of WorldToClipPos plus a screen-space offset).
    float4 Position;
    // .xyz strand tangent, .w strand UV x (0 for odd vertexId, 1 for even vertexId).
    float4 Tangent;
    // Negative-side (p0) and positive-side (p1) quad edge positions, obtained by
    // multiplying the w-divided clip position by CAMERA_VIEWPROJ_INV and dividing by w again.
    float4 p0;
    float4 p1;
    // .rgb strand color from GetStrandColor, .w fractionOfStrand.
    float4 StrandColor;
    // Light-space position of the quad corner when ShadowOn is set, otherwise (0, 0, 0, 1).
    float4 ShadowCoord;
    float4 TexCoord_Fraction;
    // .xy TexCoord (per-strand UV)
    // .z fractionOfStrand
    // .w listed as TipPercentage, but ExpendHairVertex writes the constant 1.0
    //    NOTE(review): confirm which one the pixel shader expects.
};
//--------------------------------------------------------------------------------------
// Normalizes a 2D vector, returning (0, 0) when its length is below
// FLOAT_EPSILON (or is NaN) so callers never divide by a vanishing length.
float2 safe_normalize(float2 vec)
{
    float2 unit = float2(0, 0);
    float magnitude = length(vec);
    if (magnitude >= FLOAT_EPSILON)
    {
        unit = vec * rcp(magnitude);
    }
    return unit;
}

//--------------------------------------------------------------------------------------
// Normalizes a 3D vector, returning (0, 0, 0) when its length is below
// FLOAT_EPSILON (or is NaN) so callers never divide by a vanishing length.
float3 safe_normalize(float3 vec)
{
    float3 unit = float3(0, 0, 0);
    float magnitude = length(vec);
    if (magnitude >= FLOAT_EPSILON)
    {
        unit = vec * rcp(magnitude);
    }
    return unit;
}
//-------------------------------------------------------------------------------------------------
// Returns the color of a strand at a given position along its length.
//
//   index                     - hair vertex index; divided by NumVerticesPerStrand to find the strand UV.
//   fractionOfStrand          - position along the strand, 0 at the first vertex, 1 at the last.
//   params                    - material parameters (base/tip color, tip percentage, vertices per strand).
//   BaseAlbedoTexture/sampler - albedo texture sampled (mip 0) at the strand UV.
//
// The root color is the sampled albedo multiplied by MatBaseColor. Over the last
// TipPercentage of the strand the color is blended from the root color to MatTipColor.
float3 GetStrandColor(int index, float fractionOfStrand, HairStrandParams params, Texture2D BaseAlbedoTexture, SamplerState BaseAlbedoTexture_sampler)
{
    float2 texCoord = g_HairStrandUV[index / params.NumVerticesPerStrand].xy;

    float3 rootColor = BaseAlbedoTexture.SampleLevel(BaseAlbedoTexture_sampler, texCoord, 0).rgb;
    rootColor *= params.MatBaseColor.rgb;

    float3 tipColor = params.MatTipColor.rgb;

    float rootRange = 1.f - params.TipPercentage;

    float3 color = rootColor;
    if (fractionOfStrand > rootRange)
    {
        // Remap [rootRange, 1] to [0, 1] so the blend starts at the root color
        // exactly where the tip region begins. Using fractionOfStrand directly as
        // the blend factor made the color jump at the boundary.
        float tipBlend = saturate((fractionOfStrand - rootRange) / max(params.TipPercentage, FLOAT_EPSILON));
        color = lerp(rootColor, tipColor, tipBlend);
    }
    return color;
}


// Expands one hair line vertex into one corner of a camera-facing quad.
//
//   vertexId                  - expanded vertex id; vertexId / 2 selects the source hair vertex and
//                               the low bit selects the side (odd = negative, even = positive).
//   params                    - strand/material parameters.
//   BaseAlbedoTexture/sampler - forwarded to GetStrandColor.
//
// Returns a fully populated HairVertex. ShadowCoord is only computed when ShadowOn
// is set; otherwise it is (0, 0, 0, 1).
HairVertex ExpendHairVertex(uint vertexId, HairStrandParams params, Texture2D BaseAlbedoTexture, SamplerState BaseAlbedoTexture_sampler)
{
    // Access the current line segment.
    // vertexId is actually the indexed vertex id when indexed triangles are used.
    uint index = vertexId / 2;
    bool isNegativeSide = (vertexId & 0x01) != 0;
    float fDirIndex = isNegativeSide ? -1.0 : 1.0;

    float3 wPos = g_HairVertexPositions[index].xyz;
    float3 wT = g_HairVertexTangents[index].xyz;

    // Clamp to at least one vertex per strand so the integer divide/modulo below are well defined.
    uint vertsPerStrand = (uint)max(params.NumVerticesPerStrand, 1);
    float2 tc = g_HairStrandUV[index / vertsPerStrand].xy;

    // Position of this vertex along its strand, in [0, 1].
    // Integer modulo is used rather than fmod(): fmod converts both operands to float,
    // which cannot represent indices above 2^24 exactly.
    uint indexInStrand = index % vertsPerStrand;
    // A single-vertex strand has no segments; keep the divisor at 1 to avoid 0/0.
    uint segmentCount = vertsPerStrand > 1u ? vertsPerStrand - 1u : 1u;
    float fractionOfStrand = (float)indexInStrand / (float)segmentCount;

    // Get hair strand thickness (optionally thinning towards the tip).
    float ratio = params.EnableThinTip > 0 ? lerp(1.0, params.FiberRatio, fractionOfStrand) : 1.0;

    // Calculate right and projected right vectors.
    float3 right = safe_normalize(cross(wT.xyz, safe_normalize(wPos.xyz - CAMERA_WORLDPOSITION)));
    float2 proj_right = safe_normalize(WorldToClipPos(float4(right, 0)).xy);

    // We always expand for faster hair AA; we may want to gauge making this adjustable.
    float expandPixels = 1.0;

    // World-space edges of the strand quad on the negative and positive side.
    float3 halfWidth = right * ratio * params.FiberRadius;
    float3 edgeWorldNeg = wPos.xyz - halfWidth;
    float3 edgeWorldPos = wPos.xyz + halfWidth;

    // The same edges in clip space.
    float4 edgeClipNeg = WorldToClipPos(float4(edgeWorldNeg, 1.0));
    float4 edgeClipPos = WorldToClipPos(float4(edgeWorldPos, 1.0));
    float4 edgeClip = isNegativeSide ? edgeClipNeg : edgeClipPos;

    // Tangent.w (otherwise unused) and StrandColor.w carry the strand UV.
    float2 strandUV;
    strandUV.x = isNegativeSide ? 0.f : 1.f;
    strandUV.y = fractionOfStrand;

    HairVertex Output;
    // Push the corner outwards in screen space; scaling by w keeps the offset constant after the perspective divide.
    Output.Position = edgeClip + fDirIndex * float4(proj_right * expandPixels / CAMERA_RESOLUTION.y, 0.0f, 0.0f) * edgeClip.w;
    Output.Tangent = float4(wT.xyz, strandUV.x);

    // Interpolation in screen space has intolerable error, so the edge positions are
    // passed on after transforming them back with the inverse view-projection and
    // converted to screen space in the fragment stage.
    // NOTE(review): if CAMERA_VIEWPROJ_INV is the exact inverse of WorldToClipPos this
    // round-trip just reproduces edgeWorldNeg/edgeWorldPos - confirm before simplifying.
    float4 edgeBackNeg = mul(edgeClipNeg.xyzw / edgeClipNeg.w, CAMERA_VIEWPROJ_INV);
    float4 edgeBackPos = mul(edgeClipPos.xyzw / edgeClipPos.w, CAMERA_VIEWPROJ_INV);
    edgeBackNeg /= edgeBackNeg.w;
    edgeBackPos /= edgeBackPos.w;

    Output.p0 = edgeBackNeg;
    Output.p1 = edgeBackPos;

    Output.StrandColor = float4(GetStrandColor(index, fractionOfStrand, params, BaseAlbedoTexture, BaseAlbedoTexture_sampler), strandUV.y);
    Output.TexCoord_Fraction = float4(tc, fractionOfStrand, 1.0f);

#if ShadowOn
    // Same expansion, projected with the light's view/projection matrices.
    float2 proj_right_shadow = safe_normalize(mul(mul(float4(right, 0), LIGHT_CAMERA_VIEW), LIGHT_CAMERA_PROJECTION).xy);
    float expandPixels_shadow = 1.;

    float4 edgeShadow = float4(isNegativeSide ? edgeWorldNeg : edgeWorldPos, 1.0);
    edgeShadow = mul(mul(edgeShadow, LIGHT_CAMERA_VIEW), LIGHT_CAMERA_PROJECTION);

    Output.ShadowCoord = edgeShadow + fDirIndex * float4(proj_right_shadow * expandPixels_shadow * LIGHT_PARAM.z, 0.0f, 0.0f) * edgeShadow.w;
#else
    Output.ShadowCoord = float4(0., 0., 0., 1.);
#endif
    return Output;
}


// Position-only variant of ExpendHairVertex.
//
//   vertexId - expanded vertex id; vertexId / 2 selects the source hair vertex and
//              the low bit selects the side (odd = negative, even = positive).
//   params   - strand parameters (vertices per strand, fiber radius/ratio, thin-tip flag).
//
// Returns a HairVertex whose Position is the expanded clip-space corner; every
// other member is zero. The projection uses CAMERA_WORLDPOSITION, WorldToClipPos
// and CAMERA_RESOLUTION, exactly like ExpendHairVertex.
HairVertex ExpendHairShadowVertex(uint vertexId, HairStrandParams params)
{
    // Access the current line segment.
    // vertexId is actually the indexed vertex id when indexed triangles are used.
    uint index = vertexId / 2;
    float fDirIndex = (vertexId & 0x01) ? -1.0 : 1.0;

    float3 wPos = g_HairVertexPositions[index].xyz;
    float3 wT = g_HairVertexTangents[index].xyz;

    // Clamp to at least one vertex per strand so the modulo below is well defined.
    uint vertsPerStrand = (uint)max(params.NumVerticesPerStrand, 1);

    // Position of this vertex along its strand, in [0, 1].
    // Integer modulo is used rather than fmod(): fmod converts both operands to float,
    // which cannot represent indices above 2^24 exactly.
    uint indexInStrand = index % vertsPerStrand;
    // A single-vertex strand has no segments; keep the divisor at 1 to avoid 0/0.
    uint segmentCount = vertsPerStrand > 1u ? vertsPerStrand - 1u : 1u;
    float fractionOfStrand = (float)indexInStrand / (float)segmentCount;

    // Get hair strand thickness (optionally thinning towards the tip).
    float ratio = params.EnableThinTip > 0 ? lerp(1.0, params.FiberRatio, fractionOfStrand) : 1.0;

    // Calculate right and projected right vectors.
    float3 right = safe_normalize(cross(wT.xyz, safe_normalize(wPos.xyz - CAMERA_WORLDPOSITION)));
    float2 proj_right = safe_normalize(WorldToClipPos(float4(right, 0)).xy);

    // We always expand for faster hair AA; we may want to gauge making this adjustable.
    float expandPixels = 1.;

    // Only the edge on this vertex's side is needed, since Position is the sole output.
    float4 edgeClip = WorldToClipPos(float4(wPos.xyz + fDirIndex * right * ratio * params.FiberRadius, 1.0));

    HairVertex Output = (HairVertex)0;
    // Push the corner outwards in screen space; scaling by w keeps the offset constant after the perspective divide.
    Output.Position = edgeClip + fDirIndex * float4(proj_right * expandPixels / CAMERA_RESOLUTION.y, 0.0f, 0.0f) * edgeClip.w;

    return Output;
}