CGPROGRAM

#include "hair_common.inc"

// Capacity of the two capsule arrays declared below.
#define TRESSFX_MAX_NUM_COLLISION_CAPSULES 8


// Per-capsule end-point data. ResolveCapsuleCollisions reads element i of each
// array as one capsule: xyz is used as an end-sphere center and w as that
// sphere's radius (g_centerAndRadius0 -> first end, g_centerAndRadius1 -> second end).
float4 g_centerAndRadius0[TRESSFX_MAX_NUM_COLLISION_CAPSULES];
float4 g_centerAndRadius1[TRESSFX_MAX_NUM_COLLISION_CAPSULES];
// Number of capsules to test. Stored as a float and cast to an unsigned
// integer where it is used as the loop bound in ResolveCapsuleCollisions.
float g_numCollisionCapsules;

// One collision capsule, described by its two end spheres. CapsuleCollision
// treats the region between the two centers as a body whose radius is
// interpolated between p0.w and p1.w.
struct CollisionCapsule
{
    float4 p0; // xyz = center of the first end sphere, w = radius of that sphere
    float4 p1; // xyz = center of the second end sphere, w = radius of that sphere
};

// Number of length-constraint relaxation passes to run, read from the x
// component of g_SimInts and converted to int.
int GetLengthConstraintIterations()
{
    int iterationCount = (int)g_SimInts.x;
    return iterationCount;
}

// Returns how a pairwise correction is shared between two particles:
//   x = share applied to particle0, y = share applied to particle1.
// Two movable particles split the correction evenly, a single movable
// particle takes all of it, and two immovable particles take none.
float2 ConstraintMultiplier(float4 particle0, float4 particle1)
{
    const bool firstMovable = IsMovable(particle0);
    const bool secondMovable = IsMovable(particle1);

    if (firstMovable && secondMovable)
        return float2(0.5, 0.5);

    if (firstMovable)
        return float2(1, 0);

    if (secondMovable)
        return float2(0, 1);

    return float2(0, 0);
}

// Moves pos0 and pos1 along the line joining them so that their separation
// approaches targetDistance.
//   pos0, pos1      - particle positions (xyz is modified in place; w is untouched here)
//   targetDistance  - desired separation between the two particles
//   stiffness       - scale applied to the correction (1.0 = full correction)
// The correction is split between the particles according to ConstraintMultiplier.
void ApplyDistanceConstraint(inout float4 pos0, inout float4 pos1, float targetDistance, float stiffness = 1.0)
{
    float3 edge = pos1.xyz - pos0.xyz;

    // Lower bound keeps the division below finite when both particles coincide.
    float currentLength = max(length(edge), 1e-7);

    // Fraction of the edge that must be removed (positive) or added (negative).
    float stretchRatio = 1 - targetDistance / currentLength;
    float3 correction = stretchRatio * edge;

    float2 weights = ConstraintMultiplier(pos0, pos1);

    pos0.xyz += weights.x * correction * stiffness;
    pos1.xyz -= weights.y * correction * stiffness;
}



//--------------------------------------------------------------------------------------
//
//  CapsuleCollision
//
//  Tests one particle against one capsule and, on penetration, computes the
//  pushed-out position.
//
//  curPosition  - current particle position; immovable particles are never moved
//  oldPosition  - particle position from the previous step (used only for the
//                 middle-section response)
//  newPosition  - receives the resolved position; set to curPosition.xyz when
//                 there is no collision
//  cc           - capsule to test against
//  friction     - scale applied to the along-axis part of the particle's step
//                 when it hits the middle section
//
//  Returns true when the particle was inside the capsule.
//
//--------------------------------------------------------------------------------------
bool CapsuleCollision(float4 curPosition, float4 oldPosition, inout float3 newPosition, CollisionCapsule cc, float friction = 0.4f)
{
    // Default result: the particle stays where it is.
    newPosition = curPosition.xyz;

    if ( !IsMovable(curPosition) )
        return false;

    const float r0 = cc.p0.w;
    const float r1 = cc.p1.w;

    float3 axis = cc.p1.xyz - cc.p0.xyz;
    float3 fromP0 = curPosition.xyz - cc.p0.xyz;
    float3 toP1 = cc.p1.xyz - curPosition.xyz;

    // Signed, unnormalized projections onto the axis. A negative value means
    // the particle lies beyond the corresponding end of the segment.
    float t0 = dot(fromP0, axis);
    float t1 = dot(toP1, axis);

    // Beyond the p0 end: test against the first end sphere only.
    if (t0 < 0.f)
    {
        bool insideSphere0 = dot(fromP0, fromP0) < r0 * r0;
        if (insideSphere0)
        {
            // Project the particle onto the sphere surface.
            float3 outward = normalize(fromP0);
            newPosition = r0 * outward + cc.p0.xyz;
        }
        return insideSphere0;
    }

    // Beyond the p1 end: test against the second end sphere only.
    if (t1 < 0.f)
    {
        bool insideSphere1 = dot(toP1, toP1) < r1 * r1;
        if (insideSphere1)
        {
            float3 outward = normalize(-toP1);
            newPosition = r1 * outward + cc.p1.xyz;
        }
        return insideSphere1;
    }

    // Between the two ends: test against the middle section.
    // NOTE: t0 + t1 equals dot(axis, axis), so a capsule whose two centers
    // coincide divides by zero here.
    float weightSum = t0 + t1;
    float3 closestOnAxis = (t0 * cc.p1.xyz + t1 * cc.p0.xyz) / weightSum;
    float3 offset = curPosition.xyz - closestOnAxis;

    // Radius interpolated between the two end radii at the projected point.
    float localRadius = (t0 * r1 + t1 * r0) / weightSum;

    if ( dot(offset, offset) < localRadius * localRadius )
    {
        float3 outward = normalize(offset);

        // Split this step's motion into a part along the axis and the rest.
        float3 step = curPosition.xyz - oldPosition.xyz;
        float3 axisDir = normalize(axis);
        float3 stepAlongAxis = dot(step, axisDir) * axisDir;
        float3 stepAcrossAxis = step - stepAlongAxis;

        // Along-axis motion is scaled by friction; the remaining motion is kept
        // and the particle is pushed out to the interpolated radius.
        newPosition = oldPosition.xyz + friction * stepAlongAxis + (stepAcrossAxis + localRadius * outward - offset);
        return true;
    }

    return false;
}


// Resolve hair vs capsule collisions for a single vertex.
//
//   curPosition - vertex position; its xyz is overwritten with the pushed-out
//                 position each time a capsule reports a hit
//   oldPos      - vertex position from the previous step, forwarded to CapsuleCollision
//   friction    - forwarded to CapsuleCollision
//
// Capsules are tested in array order, and each test sees the position left by
// the previous one. Returns true if at least one capsule moved the vertex.
//
// NOTE(review): an earlier comment said to set TRESSFX_COLLISION_CAPSULES to 1
// on both the hlsl and cpp sides to use this; no such switch is visible in this
// file - confirm whether it still applies.
bool ResolveCapsuleCollisions(inout float4 curPosition, float4 oldPos, float friction = 0.4f)
{
    bool bAnyColDetected = false;

    // Never index past the end of the capsule arrays, even if the CPU side
    // uploads a count larger than their capacity.
    const uint numCapsules = min((uint)g_numCollisionCapsules, (uint)TRESSFX_MAX_NUM_COLLISION_CAPSULES);

    for (uint i = 0; i < numCapsules; i++)
    {
        CollisionCapsule cc;
        cc.p0 = g_centerAndRadius0[i]; // xyz = center, w = radius
        cc.p1 = g_centerAndRadius1[i];

        // Initialized because CapsuleCollision takes it as inout.
        float3 newPos = curPosition.xyz;

        if (CapsuleCollision(curPosition, oldPos, newPos, cc, friction))
        {
            curPosition.xyz = newPos;
            bAnyColDetected = true;
        }
    }

    return bAnyColDetected;
}


#pragma compute LengthConstraints

//--------------------------------------------------------------------------------------
//
//  LengthConstraints
//
//  One thread per hair vertex. Each thread:
//    1. copies its vertex position and rest length into sharedPos / sharedLength
//       (presumably groupshared arrays declared in hair_common.inc),
//    2. runs the configured number of edge-length relaxation passes,
//    3. resolves collisions against the capsule list,
//    4. recomputes the vertex tangent,
//    5. limits the per-step displacement by rewriting the previous position,
//    6. writes the position back to g_HairVertexPositions.
//
//  Shared arrays are addressed as vertex * numOfStrandsPerThreadGroup + strand,
//  so the next vertex along a strand is numOfStrandsPerThreadGroup entries away.
//
//--------------------------------------------------------------------------------------
[numthreads(THREAD_GROUP_SIZE,1,1)]
void LengthConstraints(
    uint GIndex : SV_GroupIndex,
    uint3 GId : SV_GroupID,
    uint3 DTid : SV_DispatchThreadID)
{
    uint globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex, numVerticesInTheStrand, indexForSharedMem;
    CalcIndicesInVertexLevelMaster(GIndex, GId.x, globalStrandIndex, localStrandIndex, globalVertexIndex, localVertexIndex, numVerticesInTheStrand, indexForSharedMem);

    //--------
    // Wind
    //--------
    // Not applied in this kernel.

    uint numOfStrandsPerThreadGroup = g_NumOfStrandsPerThreadGroup;

    //------------------------------
    // Copy data into shared memory
    //------------------------------
    sharedPos[indexForSharedMem] = g_HairVertexPositions[globalVertexIndex];
    sharedLength[indexForSharedMem] = g_HairRestLength[globalVertexIndex];
    GroupMemoryBarrierWithGroupSync();

    //------------------------------
    // Length constraints
    //------------------------------
    // Edges are relaxed in two alternating sets so that no two threads touch
    // the same vertex in the same phase:
    //   phase 1: edges (2v, 2v+1)   handled by threads with localVertexIndex < a
    //   phase 2: edges (2v+1, 2v+2) handled by threads with localVertexIndex < b
    uint a = floor(numVerticesInTheStrand/2.0f);
    uint b = floor((numVerticesInTheStrand-1)/2.0f);

    int nLengthContraintIterations = GetLengthConstraintIterations();

    // Shared-memory slot of vertex 2 * localVertexIndex on this strand; it does
    // not change between iterations.
    uint sharedIndex = 2*localVertexIndex * numOfStrandsPerThreadGroup + localStrandIndex;

    for ( int iterationE=0; iterationE < nLengthContraintIterations; iterationE++ )
    {
        if( localVertexIndex < a )
            ApplyDistanceConstraint(sharedPos[sharedIndex], sharedPos[sharedIndex+numOfStrandsPerThreadGroup], sharedLength[sharedIndex].x);

        GroupMemoryBarrierWithGroupSync();

        if( localVertexIndex < b )
            ApplyDistanceConstraint(sharedPos[sharedIndex+numOfStrandsPerThreadGroup], sharedPos[sharedIndex+numOfStrandsPerThreadGroup*2], sharedLength[sharedIndex+numOfStrandsPerThreadGroup].x);

        GroupMemoryBarrierWithGroupSync();
    }

    //------------------------------
    // Capsule collisions
    //------------------------------
    float4 oldPos = g_HairVertexPositionsPrev[globalVertexIndex];

    bool bAnyColDetected = ResolveCapsuleCollisions(sharedPos[indexForSharedMem], oldPos);

    // The tangent below reads a neighbouring vertex, so every thread must have
    // finished its collision response first.
    GroupMemoryBarrierWithGroupSync();

    //-------------------
    // Compute tangent
    //-------------------
    // If this is the last vertex in the strand, we can't get tangent from subtracting from the next vertex, need to use last vertex to current
    uint indexForTangent = (localVertexIndex == numVerticesInTheStrand - 1) ? indexForSharedMem - numOfStrandsPerThreadGroup : indexForSharedMem;
    float3 tangent = sharedPos[indexForTangent + numOfStrandsPerThreadGroup].xyz - sharedPos[indexForTangent].xyz;
    g_HairVertexTangents[globalVertexIndex] = float4(normalize(tangent), g_HairVertexTangents[globalVertexIndex].w);

    //---------------------------------------
    // clamp velocities, rewrite history
    //---------------------------------------
    // The step taken since the previous position acts as the vertex velocity.
    // If it is longer than g_ClampPositionDelta, move the stored previous
    // position so that the step has exactly that length.
    float3 positionDelta = sharedPos[indexForSharedMem].xyz - oldPos.xyz;
    float speedSqr = dot(positionDelta, positionDelta);
    float maxDeltaSqr = g_ClampPositionDelta * g_ClampPositionDelta;
    if (speedSqr > maxDeltaSqr)
    {
        // speedSqr > maxDeltaSqr >= 0 here, so the division is safe. The square
        // root converts the ratio of squared lengths into a ratio of lengths;
        // without it the step would shrink to maxDelta^2 / speed, not maxDelta.
        positionDelta *= sqrt(maxDeltaSqr / speedSqr);
        g_HairVertexPositionsPrev[globalVertexIndex] = float4(sharedPos[indexForSharedMem].xyz - positionDelta, oldPos.w);
    }

    g_HairVertexPositions[globalVertexIndex] = sharedPos[indexForSharedMem];

    // After a collision the previous position is set to the current one, which
    // leaves the vertex with no displacement history going into the next step.
    if (bAnyColDetected)
        g_HairVertexPositionsPrev[globalVertexIndex] = sharedPos[indexForSharedMem];
}


ENDCG