CGPROGRAM

#include "hair_common.inc"


#pragma compute LocalShapeConstraints


// LocalShapeConstraints
// Local shape (bending/twisting) constraint. For every consecutive vertex pair
// (i, i+1) of a strand, the bind-pose segment (i -> i+1) is rotated by the rotation
// that maps the bind-pose direction of the previous segment (i-1 -> i) onto its
// current direction. Vertices i and i+1 are then pulled toward that target shape.
//
// NOTE(review): CalcIndicesInStrandLevelMaster appears to map one thread to one
// strand (GIndex / GId.x -> strand indices); confirm against hair_common.inc.
// Reads : g_HairVertexPositions, g_InitialHairPositions, _Transform
// Writes: g_HairVertexPositions (local vertices 1 .. n-1 of the strand)
[numthreads(THREAD_GROUP_SIZE,1,1)]
void LocalShapeConstraints(
    uint GIndex : SV_GroupIndex,
    uint3 GId : SV_GroupID,
    uint3 DTid : SV_DispatchThreadID)
{
    uint globalStrandIndex, numVerticesInTheStrand, globalRootVertexIndex;
    CalcIndicesInStrandLevelMaster(GIndex, GId.x, globalStrandIndex, numVerticesInTheStrand, globalRootVertexIndex);

    // A bending constraint needs a (i-1, i, i+1) triple, so shorter strands have nothing
    // to do. This guard also prevents the unsigned "n - 1" loop bound from wrapping
    // around when a strand reports zero vertices.
    if (numVerticesInTheStrand < 3)
        return;

    // Clamp below 1.0 (full stiffness makes things unstable sometimes). Then halve it:
    // the correction below is applied with opposite signs to both vertices of the segment.
    float stiffnessForLocalShapeMatching = 0.5f * min(GetLocalStiffness(), 0.95f);

    //--------------------------------------------
    // Local shape constraint for bending/twisting
    //--------------------------------------------

    // World-space bind-pose positions of vertices (i-1) and i. They are carried across
    // iterations so each initial position is transformed once instead of three times.
    float3 bindPos_minus_one = ApplyWorldTransformToVertex(g_InitialHairPositions[globalRootVertexIndex], _Transform).xyz;
    float3 bindPos = ApplyWorldTransformToVertex(g_InitialHairPositions[globalRootVertexIndex + 1], _Transform).xyz;

    // i = 1 .. n-2; each iteration updates vertices i and i+1.
    for (uint localVertexIndex = 1; localVertexIndex + 1 < numVerticesInTheStrand; localVertexIndex++)
    {
        uint globalVertexIndex = globalRootVertexIndex + localVertexIndex;

        float3 bindPos_plus_one = ApplyWorldTransformToVertex(g_InitialHairPositions[globalVertexIndex + 1], _Transform).xyz;

        // Re-read from the buffer: vertex i was already moved by the previous iteration.
        float4 pos_minus_one = g_HairVertexPositions[globalVertexIndex - 1];
        float4 pos = g_HairVertexPositions[globalVertexIndex];
        float4 pos_plus_one = g_HairVertexPositions[globalVertexIndex + 1];

        // Rotation between the bind-pose and current directions of segment (i-1 -> i)
        // (presumably "from first argument to second" per the helper's name).
        float3 lastVec = pos.xyz - pos_minus_one.xyz;
        float3 lastVecBindPose = bindPos - bindPos_minus_one;
        float4 rotGlobal = QuatFromTwoUnitVectors(normalize(lastVecBindPose), normalize(lastVec));

        // Target for vertex i+1: the bind-pose segment (i -> i+1), rotated the same way,
        // hung off the current position of vertex i.
        float3 vecBindPose = bindPos_plus_one - bindPos;
        float3 orgPos_i_plus_1_InGlobalFrame = MultQuaternionAndVector(rotGlobal, vecBindPose) + pos.xyz;
        float3 del = stiffnessForLocalShapeMatching * (orgPos_i_plus_1_InGlobalFrame - pos_plus_one.xyz);

        if (IsMovable(pos))
            pos.xyz -= del;

        if (IsMovable(pos_plus_one))
            pos_plus_one.xyz += del;

        g_HairVertexPositions[globalVertexIndex] = pos;
        g_HairVertexPositions[globalVertexIndex + 1] = pos_plus_one;

        // Slide the bind-pose window one vertex down the strand.
        bindPos_minus_one = bindPos;
        bindPos = bindPos_plus_one;
    }
}


ENDCG