
local TransNode = require "apolloutility.apollonode.trasnnode"
local apolloengine = require "apolloengine"
local mathfunction = require "mathfunction"
local util = require "behavior.avatar_behavior.util"
local apolloDefine = require "apolloutility.defiend"
-- SkeletonNode extends TransNode; its methods map pose-estimation joint data
-- (indexed 1..14 through bonenamemap) onto the bones of a rig.
local SkeletonNode = TransNode:extend();
-- Degrees-to-radians conversion factor (pi / 180).
-- NOTE(review): not referenced anywhere in the visible file; candidate for removal.
local radiusscale = math.pi / 180;


--- Constructor: delegates all initialization to the TransNode base class.
function SkeletonNode:new()
  local base = SkeletonNode.super;
  base.new(self);
end

--- Register the rig's bones by name.
-- Stores a shallow copy, so later changes to the caller's table do not
-- affect this skeleton.
-- @tparam table table_bones map from bone name to engine node
function SkeletonNode:SetBonesRef(table_bones)
  local copy = {};
  self.BoneNodes = copy;
  for boneName, boneNode in pairs(table_bones) do
    copy[boneName] = boneNode;
  end
end

--- Configure how pose-estimation joints map onto the rig's bones.
-- Requires SetBonesRef to have been called, because the root bone is looked up
-- immediately to cache its world position in self.rootpos.
-- @param rootname name of the root bone
-- @tparam table spinebone ordered list of spine bone names
-- @tparam table bonenames pose-joint index -> bone name
-- @tparam table children pose-joint index -> child joint index (0 marks a leaf)
-- @tparam table bodypair list of {a, b} index pairs into self.bodybones
function SkeletonNode:SetPoseEstimateBoneNames(rootname,spinebone,bonenames,children,bodypair)
  self.rootbone = rootname;
  self:_SetRootJoint(rootname);

  self.spinebone = spinebone;
  self.bonenamemap = bonenames;

  -- Torso joints are pose indices 1, 4, 7 and 10; GetBodyPose addresses them
  -- through the positions listed in bodypair.
  local map = self.bonenamemap;
  self.bodybones = { map[1], map[4], map[7], map[10] }

  self.children = children;
  self.bodypair = bodypair;
end

--- Snapshot the world positions of the 14 pose-estimation joints.
-- @treturn table array indexed 1..14; each entry is {x, y, z}, or nil when that
--   joint index has no mapped bone name (see GetTPos)
function SkeletonNode:GetCurrentPose()
  local pose = {}
  local jointCount = 14
  local idx = 1
  while idx <= jointCount do
    pose[idx] = self:GetTPos(idx)
    idx = idx + 1
  end
  -- Positions are returned as-is. A previously disabled step re-centred them
  -- on the midpoint of joints 7 and 10.
  return pose
end


--- Capture the skeleton's current state as the reference (bind) pose.
-- Populates, per joint 1..14: Trot (world rotation), Tdir (unit direction to
-- the child joint) and Tpos ({x, y, z} world position).
-- Also populates TBoneLength (distance to the child joint, 0 for leaves),
-- BodyResetPose (local positions of torso joints 1/4/7/10), TBody, TBodyRot,
-- RootInitRot, BodyStartRotMat (transposed torso frame) and rootpos.
function SkeletonNode:BindPose()
  local rots, dirs, positions = {}, {}, {};
  self.Tdir = dirs;
  self.Trot = rots;
  self.Tpos = positions;
  self.TBody = {};
  local lengths = {};
  self.TBoneLength = lengths;

  for joint = 1, 14 do
    rots[joint] = self:GetTRot(joint);
    dirs[joint] = self:GetTDir(joint);
    positions[joint] = self:GetTPos(joint)
  end

  -- Rest length of each bone segment (joint -> child); leaf joints get 0.
  -- This loop needs every Tpos filled first, because a child index may exceed its parent's.
  local function asVector(p)
    return mathfunction.vector3(p[1], p[2], p[3]);
  end
  for joint = 1, 14 do
    local child = self.children[joint];
    if child == 0 then
      lengths[joint] = 0;
    else
      lengths[joint] = (asVector(positions[joint]) - asVector(positions[child])):Length();
    end
  end

  -- Torso joint local positions, restored later by ResetBoneLength.
  local resetPose = {};
  self.BodyResetPose = resetPose;
  for slot, joint in ipairs({1, 4, 7, 10}) do
    resetPose[slot] = self:GetJoint(self.bonenamemap[joint]).trans:GetLocalPosition();
  end

  self.TBody = self:GetBodyPose();
  self.TBodyRot = self:GetBodyRot();
  self.RootInitRot = self.trans:GetWorldRotation();

  -- Reference torso frame; UpdateSkeleton multiplies it with the live frame.
  local startMat = util:ConstractRotationFromVector(self.TBody[2], self.TBody[3]);
  startMat:Transpose();
  self.BodyStartRotMat = startMat;

  LOG(self.TBody[2]);
  LOG(self.TBody[3]);
  LOG("AT INIT GO");

  self.rootpos = self:GetJoint(self.rootbone).trans:GetWorldPosition();
end


--- World rotations of the root bone, pose joint 13 and every spine bone.
-- @treturn table map bone name -> world rotation
function SkeletonNode:GetBodyRot()
  local spine = self.spinebone;
  local names = { self.rootbone, self.bonenamemap[13] };
  for k = 1, #spine do
    names[k + 2] = spine[k];
  end

  local rotations = {};
  for k = 1, #names do
    local boneName = names[k];
    rotations[boneName] = self:GetJoint(boneName).trans:GetWorldRotation();
  end
  return rotations;
end

--- Unit direction vectors between torso joints in the current world pose.
-- For each pair in self.bodypair, slot i holds normalize(world(second) - world(first)).
-- The pair values index self.bodybones.
-- Slots 3 and 4 are then merged into their equally weighted, normalized
-- average in slot 3, and slot 4 is cleared.
-- @treturn table array of vector3 directions (slots 1..3 once merged)
function SkeletonNode:GetBodyPose()
  local bonepose = {}
  for i = 1, #self.bodypair do
    local pair = self.bodypair[i]
    local firstName = self.bodybones[pair[1]]
    local secondName = self.bodybones[pair[2]]

    local firstWorldPos = self:GetJoint(firstName).trans:GetWorldPosition()
    local secondWorldPos = self:GetJoint(secondName).trans:GetWorldPosition()
    local dir = secondWorldPos - firstWorldPos
    dir:NormalizeSelf()
    bonepose[i] = dir
  end

  -- Average the two side directions with EQUAL weight (note the parentheses).
  -- UpdateSkeleton builds its live torso frame the same way, so the bind-pose
  -- frame and the live frame stay comparable. Weighting one side more would
  -- add a spurious roll to the reference frame.
  if bonepose[3] ~= nil and bonepose[4] ~= nil then
    local merged = (bonepose[3] + bonepose[4]) / 2
    merged:NormalizeSelf()
    bonepose[3] = merged
  end

  bonepose[4] = nil
  return bonepose
end

--- World position of a pose joint as a plain {x, y, z} array.
-- @tparam number boneidx pose-joint index into bonenamemap
-- @treturn table|nil {x, y, z}, or nil when the index has no mapped bone name
function SkeletonNode:GetTPos(boneidx)
  local boneName = self.bonenamemap[boneidx];
  if boneName ~= nil then
    local p = self:GetJoint(boneName).trans:GetWorldPosition();
    return { p:x(), p:y(), p:z() };
  end
  return nil
end


--- World rotation of a pose joint's bone.
-- @tparam number boneidx pose-joint index
-- @return the world rotation, or nil for a leaf joint (children[boneidx] == 0)
--   or an unmapped index
function SkeletonNode:GetTRot(boneidx)
  local isLeaf = (self.children[boneidx] == 0);
  if isLeaf then
    return nil
  end

  local boneName = self.bonenamemap[boneidx];
  if boneName == nil then
    return nil
  end

  LOG("GetTRot:"..boneName)
  local rotation = self:GetJoint(boneName).trans:GetWorldRotation();
  return rotation;
end

--- Unit world-space direction from a pose joint to its child joint.
-- @tparam number boneidx pose-joint index
-- @return normalized vector3, or nil for a leaf joint or when either joint is unmapped
function SkeletonNode:GetTDir(boneidx)
  local childIdx = self.children[boneidx];
  if childIdx == 0 then
    return nil
  end

  local fromName = self.bonenamemap[boneidx];
  local toName = self.bonenamemap[childIdx];
  if fromName == nil or toName == nil then
    return nil
  end

  local fromPos = self:GetJoint(fromName).trans:GetWorldPosition();
  local toPos = self:GetJoint(toName).trans:GetWorldPosition();
  local direction = toPos - fromPos;
  direction:NormalizeSelf();
  return direction;
end

--- Wrap a registered bone in a TransNode-style handle.
-- @param name bone name as registered through SetBonesRef
-- @treturn table|nil a handle with `.node` (the engine node) and `.trans` (its
--   transform component), or nil when no bone with that name is registered
function SkeletonNode:GetJoint(name)
  local boneNode = self.BoneNodes[name];
  if boneNode == nil then
    return nil
  end

  local handle = TransNode:cast({});
  handle.node = boneNode;
  handle.trans = boneNode:GetComponent(apolloengine.Node.CT_TRANSFORM);
  return handle;
end

--- Transform components of every pose joint, keyed by bone name.
-- Returns the transform components themselves, not snapshots of their positions.
-- @treturn table map bone name -> transform component
-- @return self.RootRot, the root rotation last stored by UpdateSkeleton (may be nil)
function SkeletonNode:GetTargets()
  local targets = {}
  local names = self.bonenamemap
  for idx = 1, #self.children do
    local boneName = names[idx]
    targets[boneName] = self:GetJoint(boneName).trans
  end
  return targets, self.RootRot
end

--- Undo UpdateBoneLength.
-- Restores the torso joints' (pose indices 1, 4, 7, 10) local positions that
-- BindPose captured in self.BodyResetPose, and resets the world scale of every
-- pose joint to (1, 1, 1).
-- NOTE(review): requires BindPose to have run first (reads self.BodyResetPose).
function SkeletonNode:ResetBoneLength()
  local bodyidx = {1,4,7,10};
  for i=1,#bodyidx do
    local joint = self:GetJoint(self.bonenamemap[bodyidx[i]]);
    joint.trans:SetLocalPosition(self.BodyResetPose[i]);
  end

  for i=1,#self.children do
    local joint = self:GetJoint(self.bonenamemap[i]);
    joint.trans:SetWorldScale(mathfunction.vector3(1,1,1));
  end
end

--- Stretch the skeleton toward measured 3D keypoints.
-- 1. Moves the torso joints (pose indices 1, 4, 7, 10) to their measured world positions.
-- 2. Scales each non-leaf joint so its bone matches the measured distance to its
--    child, relative to the bind-pose length stored in self.TBoneLength by BindPose.
-- @tparam table pos3d per-joint {x, y, z} world positions, indexed like bonenamemap
-- @tparam[opt] table score per-joint confidences; nil treats every joint as confident
-- @tparam number cutoff a joint is used only when its score is strictly greater than this
function SkeletonNode:UpdateBoneLength(pos3d,score,cutoff)
  local function confident(idx)
    return score == nil or score[idx] > cutoff;
  end
  local function toVector(p)
    return mathfunction.vector3(p[1], p[2], p[3]);
  end

  local bodyidx = {1,4,7,10};
  for i=1,#bodyidx do
    local index = bodyidx[i];
    if confident(index) then
      local joint = self:GetJoint(self.bonenamemap[index]);
      joint.trans:SetWorldPosition(toVector(pos3d[index]));
    end
  end

  for i=1,#self.children do
    local chidx = self.children[i];
    local restLength = self.TBoneLength[i];
    -- Skip leaves and low-confidence segments. Also skip missing or zero
    -- bind-pose lengths, whose ratio would give an inf/NaN world scale.
    if chidx ~= 0 and confident(i) and confident(chidx)
        and restLength ~= nil and restLength > 0 then
      local length = (toVector(pos3d[chidx]) - toVector(pos3d[i])):Length();
      local scale = length/restLength;
      if i==13 then
        -- NOTE(review): joint 13's bone is scaled to one third; the reason is not visible here.
        scale = scale/3;
      end
      local trans = self:GetJoint(self.bonenamemap[i]).trans;
      -- Joints 1-6 are the arm chains (see CorrectArmsTwist) and are stretched
      -- along local x; the rest along y. Presumably this follows each bone's
      -- long axis in the rig — confirm.
      if i <= 6 then
        trans:SetWorldScale( mathfunction.vector3(scale,scale+0.2,scale+0.2) );
      else
        trans:SetWorldScale( mathfunction.vector3(scale+0.2,scale,scale+0.2) );
      end
    end
  end
end

--- Set the local rotation of every spine bone, then of every pose-mapped bone,
-- to a default-constructed Quaternion.
function SkeletonNode:ResetSkeletonRotation()
  local function resetLocal(boneName)
    self:GetJoint(boneName).trans:SetLocalRotation(mathfunction.Quaternion());
  end

  local spine, mapped = self.spinebone, self.bonenamemap;
  for k = 1, #spine do
    resetLocal(spine[k]);
  end
  for k = 1, #mapped do
    resetLocal(mapped[k]);
  end
end

--- Pose the skeleton from 3D keypoints.
-- When the four torso joints (1, 4, 7, 10) are confident, this:
--   * estimates the torso rotation relative to the bind pose;
--   * rotates the root bone, and moves it to cameraT when cameraT is given;
--   * spreads a torso correction along the spine bones.
-- Then, every confident joint -> child segment is rotated from its bind-pose
-- direction toward the measured direction.
-- @tparam table pos3d per-joint {x, y, z} positions, indexed like bonenamemap
-- @param cameraT optional world position for the root bone
-- @tparam[opt] table score per-joint confidences; nil treats every joint as confident
-- @tparam number cutoff a joint is used only when its score is strictly greater than this
-- @treturn table world positions of each pose joint after the update, keyed by bone name
-- @return the root rotation (default-constructed when the torso was not confident)
function SkeletonNode:UpdateSkeleton(pos3d, cameraT,score,cutoff)
  local function toVector(p)
    return mathfunction.vector3(p[1], p[2], p[3]);
  end
  local function confident(idx)
    return score == nil or score[idx] > cutoff;
  end

  -- Defaults come from a default-constructed Matrix33 (presumably identity);
  -- they are used as-is when the torso is unreliable.
  local rotationmat = mathfunction.Matrix33();
  local rootrot = rotationmat:ToQuaternion();

  if confident(1) and confident(4) and confident(7) and confident(10) then
    -- Torso direction pairs. These were previously accidental globals.
    -- NOTE(review): assumed to match the self.bodypair layout that BindPose used.
    local torsoPairs = {{1,4},{7,10},{1,7},{4,10}};
    local torsoDirs = {};
    for i = 1, #torsoPairs do
      local from = toVector(pos3d[torsoPairs[i][1]]);
      local to = toVector(pos3d[torsoPairs[i][2]]);
      local dir = to - from;
      dir:NormalizeSelf();
      torsoDirs[i] = dir;
    end
    torsoDirs[3] = (torsoDirs[3] + torsoDirs[4]) / 2;
    torsoDirs[3]:NormalizeSelf();

    local currentBodyRotMat = util:ConstractRotationFromVector(torsoDirs[2], torsoDirs[3]);
    rotationmat = self.BodyStartRotMat * currentBodyRotMat;
    rootrot = rotationmat:ToQuaternion();
    self.RootRot = rootrot;

    -- Residual correction for the first torso direction, distributed along the spine below.
    local spineTwist = mathfunction.Quaternion();
    spineTwist:AxisToAxis(self.TBody[1] * rotationmat, torsoDirs[1]);

    local rootTrans = self:GetJoint(self.rootbone).trans;
    rootTrans:SetWorldRotation(self.TBodyRot[self.rootbone] * rootrot);
    if cameraT ~= nil then
      rootTrans:SetWorldPosition(cameraT);
    end

    -- Spine bone i receives fraction i / #spine of the correction.
    local spineCount = #self.spinebone;
    for i = 1, spineCount do
      local spine_rot = mathfunction.Mathutility:Slerp(mathfunction.Quaternion(), spineTwist, i / spineCount);
      local spineTrans = self:GetJoint(self.spinebone[i]).trans;
      spineTrans:SetWorldRotation(self.TBodyRot[self.spinebone[i]] * rootrot * spine_rot);
    end
  end

  local target = {}
  for i = 1, #self.children do
    local trans = self:GetJoint(self.bonenamemap[i]).trans;
    local chidx = self.children[i];

    if chidx ~= 0 and confident(i) and confident(chidx) then
      local dir = toVector(pos3d[chidx]) - toVector(pos3d[i]);
      dir:NormalizeSelf();
      local rotation = mathfunction.Quaternion();
      rotation:AxisToAxis(self.Tdir[i] * rotationmat, dir);
      trans:SetWorldRotation(self.Trot[i] * rootrot * rotation);
    end
    -- Read after this joint's own update. Positions of later joints reflect
    -- rotations already applied to earlier joints.
    target[self.bonenamemap[i]] = trans:GetWorldPosition();
  end
  return target, rootrot;
end

--[[
Twist-correct an upper-arm / forearm pair toward a bend axis derived from the two segments.

Parameters:
rotUpper:      world rotation of the upper arm
rotFore:       world rotation of the forearm
upperarm:      upper-arm direction vector (normalized IN PLACE)
forearm:       forearm direction vector (normalized IN PLACE)
bendAxisUpper: the upper arm's initial outward-facing axis, such that bendAxisUpper * rotUpper = upperarm x forearm
bendAxisFore:  the forearm's initial outward-facing axis, such that bendAxisFore * rotFore = upperarm x forearm
isLeft:        (bool) true for the left arm; the bend axis is negated for it

Returns the corrected upper-arm and forearm world rotations. When the two
segments are nearly parallel (|cos| not below 0.966, i.e. within roughly 15
degrees), the bend plane is ill-defined and the inputs are returned unchanged.
]]--
function SkeletonNode:EstimateArmRot(rotUpper,rotFore,upperarm,forearm,bendAxisUpper,bendAxisFore,isLeft)
  upperarm:NormalizeSelf();
  forearm:NormalizeSelf();

  local straightness = math.abs(upperarm:Dot(forearm));
  if not (straightness < 0.966) then
    return rotUpper, rotFore;
  end

  local bendNormal = upperarm:Cross(forearm);
  if isLeft then
    bendNormal = mathfunction.vector3(0.0,0.0,0.0) - bendNormal;
  end

  -- Rotation taking fromAxis onto toAxis, applied only halfway: rotating the
  -- full amount breaks the skinning, so the correction is deliberately damped.
  local function halfAlign(fromAxis, toAxis)
    local q = mathfunction.Quaternion();
    q:AxisToAxis(fromAxis, toAxis);
    return mathfunction.Mathutility:Slerp(mathfunction.Quaternion(), q, 0.5);
  end

  local upperCorrection = halfAlign(bendAxisUpper, bendNormal * rotUpper:Inverse());
  local correctedUpper = upperCorrection * rotUpper;

  local foreCorrection = halfAlign(bendAxisFore, bendNormal * rotFore:Inverse());
  local correctedFore = foreCorrection * rotFore;

  return correctedUpper, correctedFore;
end

--- Post-process both arm chains so each elbow bends about an axis derived from the bind pose.
-- On the first call (while self.leftElbowAxis0 is nil), each arm's elbow axis
-- is computed from the bind-pose joint positions in self.Tpos. It is then
-- rotated by the inverse bind rotations in self.Trot (presumably expressing it
-- in the upper-arm frame, ...ElbowAxis0, and the forearm frame, ...ElbowAxis1)
-- and cached on self.
-- Every call then reads the current upper-arm/forearm rotations and positions
-- for the left arm (pose joints 4/5/6) and the right arm (pose joints 1/2/3).
-- It writes back the world rotations returned by EstimateArmRot.
-- NOTE(review): requires BindPose to have run (reads self.Tpos and self.Trot).
-- NOTE(review): if a bind-pose arm is perfectly straight, the cross product is
-- the zero vector and normalizing it is presumably degenerate. Confirm the
-- rig's bind pose has bent elbows.
function SkeletonNode:CorrectArmsTwist()
  
  -- One-time initialization of the cached elbow axes.
  if self.leftElbowAxis0 == nil then
  local upper,fore,hand;
  upper = mathfunction.vector3(self.Tpos[4][1],self.Tpos[4][2],self.Tpos[4][3]);
  fore = mathfunction.vector3(self.Tpos[5][1],self.Tpos[5][2],self.Tpos[5][3]);
  hand = mathfunction.vector3(self.Tpos[6][1],self.Tpos[6][2],self.Tpos[6][3]);
  local upperarmL = fore - upper;
  local forearmL = hand - fore;
  -- Left arm uses forearm x upperarm (the reverse of the right arm below);
  -- EstimateArmRot also negates its live bend axis when isLeft is true.
  self.leftElbowAxis = forearmL :Cross(upperarmL);
  self.leftElbowAxis1 = self.leftElbowAxis * self.Trot[5] : Inverse();
  self.leftElbowAxis0 = self.leftElbowAxis * self.Trot[4] : Inverse();
  self.leftElbowAxis0 : NormalizeSelf()
  self.leftElbowAxis1 : NormalizeSelf()
  --LOG("leftElbowAxis0 :"..self.leftElbowAxis0:x()..","..self.leftElbowAxis0:y()..","..self.leftElbowAxis0:z());
  --LOG("leftElbowAxis1 :"..self.leftElbowAxis1:x()..","..self.leftElbowAxis1:y()..","..self.leftElbowAxis1:z());
   
  upper = mathfunction.vector3(self.Tpos[1][1],self.Tpos[1][2],self.Tpos[1][3]);
  fore = mathfunction.vector3(self.Tpos[2][1],self.Tpos[2][2],self.Tpos[2][3]);
  hand = mathfunction.vector3(self.Tpos[3][1],self.Tpos[3][2],self.Tpos[3][3]);
  local upperarmR =  fore - upper;
  local forearmR = hand - fore;
  self.rightElbowAxis = upperarmR :Cross(forearmR);
  self.rightElbowAxis1 = self.rightElbowAxis * self.Trot[2] : Inverse();
  self.rightElbowAxis0 = self.rightElbowAxis * self.Trot[1] : Inverse();
  self.rightElbowAxis0 : NormalizeSelf()
  self.rightElbowAxis1 : NormalizeSelf()
  --LOG("rightElbowAxis0 :"..self.rightElbowAxis0:x()..","..self.rightElbowAxis0:y()..","..self.rightElbowAxis0:z());
  --LOG("rightElbowAxis1 :"..self.rightElbowAxis1:x()..","..self.rightElbowAxis1:y()..","..self.rightElbowAxis1:z());
  end

  --left UpperArm/ForeArm/Hand
  local transUpperL = self:GetJoint(self.bonenamemap[4]).trans;
  local rotUpperL = transUpperL : GetWorldRotation();
  local targetUpperL = transUpperL : GetWorldPosition();
  
  local transForeL = self:GetJoint(self.bonenamemap[5]).trans;
  -- Forearm local rotation composed with the upper arm's world rotation;
  -- presumably its world rotation, assuming the forearm is a direct child of the upper arm (confirm).
  local rotForeL = transForeL:GetLocalRotation() * rotUpperL;
  local targetForeL = transForeL : GetWorldPosition();

  local transHandL = self:GetJoint(self.bonenamemap[6]).trans;
  local targetHandL = transHandL : GetWorldPosition();
  
  local upperarmL = targetForeL - targetUpperL;
  local forearmL = targetHandL - targetForeL;
  
  local resrotUpperL,resrotForeL = self:EstimateArmRot(rotUpperL,rotForeL,upperarmL,forearmL,
    self.leftElbowAxis0,self.leftElbowAxis1,true);
  -- Upper arm is written before the forearm; keep this order.
  transUpperL : SetWorldRotation( resrotUpperL  );
  transForeL : SetWorldRotation( resrotForeL );

  --right UpperArm/ForeArm/Hand
  local transUpperR = self:GetJoint(self.bonenamemap[1]).trans;
  local rotUpperR = transUpperR : GetWorldRotation();
  local targetUpperR = transUpperR : GetWorldPosition();
  
  local transForeR = self:GetJoint(self.bonenamemap[2]).trans;
  local rotForeR = transForeR:GetLocalRotation() * rotUpperR;
  local targetForeR = transForeR : GetWorldPosition();

  local transHandR = self:GetJoint(self.bonenamemap[3]).trans;
  local targetHandR = transHandR : GetWorldPosition();
  
  local upperarmR = targetForeR - targetUpperR;
  local forearmR = targetHandR - targetForeR;
  
  local resrotUpperR,resrotForeR = self:EstimateArmRot(rotUpperR,rotForeR,upperarmR,forearmR,
    self.rightElbowAxis0,self.rightElbowAxis1,false);
  transUpperR : SetWorldRotation( resrotUpperR  );
  transForeR : SetWorldRotation( resrotForeR ); 

end

--- Lazily fetch and cache the kinematics (IK) component of the root bone.
-- A missing component is not cached, so "IK is nil" is logged on every call
-- until one is present.
-- @return the kinematics component, or nil
function SkeletonNode:_GetIK()
  if self.ik ~= nil then
    return self.ik;
  end

  local ik = self.BoneNodes[self.rootbone]:GetComponent(apolloengine.Node.CT_KINEMATICS);
  self.ik = ik;
  if ik == nil then
    LOG("IK is nil")
  end
  return ik;
end

--- Rotate the root bone by `rootrot` on top of its bind-pose world rotation,
-- and optionally reposition it.
-- @param rootrot rotation applied after self.TBodyRot[self.rootbone]
-- @param cameraT optional root position; it is divided by 1000 before use
--   (NOTE(review): presumably a unit conversion — confirm with the caller)
-- @param cameraR unused
function SkeletonNode:_UpdateRootRot(rootrot,cameraT,cameraR)
  local rootTrans = self:GetJoint(self.rootbone).trans;
  rootTrans:SetWorldRotation(self.TBodyRot[self.rootbone] * rootrot);
  if cameraT == nil then
    return
  end
  rootTrans:SetWorldPosition((cameraT) / 1000);
end

--- Update the root, flag per-joint confidence, and run the engine IK solver.
-- Does nothing when the root bone has no kinematics component.
-- @param rootRot root rotation relative to the bind pose (see _UpdateRootRot)
-- @param cameraT optional root position (divided by 1000 in _UpdateRootRot)
-- @param cameraR forwarded to _UpdateRootRot
-- @tparam[opt] table scores per-joint confidences indexed like bonenamemap. A
--   joint is flagged confident when its score is strictly greater than 1.0.
--   When nil, every joint is flagged confident, matching how UpdateSkeleton and
--   UpdateBoneLength treat a nil score table (previously this raised an error).
function SkeletonNode:UpdateIK(rootRot,cameraT,cameraR,scores)
  local ik = self:_GetIK();
  if ik == nil then
    return
  end
  self:_UpdateRootRot(rootRot,cameraT,cameraR)

  local CONFIDENCE = 1.0;
  for i=1,#self.bonenamemap do
    local bone = self.BoneNodes[self.bonenamemap[i]]:GetComponent(apolloengine.Node.CT_TRANSFORM);
    local confident = (scores == nil) or (scores[i] > CONFIDENCE);
    ik:SetConfidence(bone:GetStaticID(), confident);
  end
  -- NOTE(review): 0.2 and 5 are solver parameters whose meaning is defined by the engine's UpdateKinematics.
  ik:UpdateKinematics(0.2,5);
end

--- Register, for every pose-mapped bone, the transform the IK solver should track.
-- Does nothing when the root bone has no kinematics component.
-- @tparam table targetNodes map bone name -> engine node; must contain every bonenamemap entry
function SkeletonNode:SetTarget(targetNodes)
  local ik = self:_GetIK();
  if ik == nil then
    return
  end

  local transformType = apolloengine.Node.CT_TRANSFORM;
  for i = 1, #self.bonenamemap do
    local boneName = self.bonenamemap[i];
    local boneTrans = self.BoneNodes[boneName]:GetComponent(transformType);
    local targetTrans = targetNodes[boneName]:GetComponent(transformType);
    ik:SetTarget(boneTrans:GetStaticID(), targetTrans);
  end
end

--- Cache the current world position of the named (root) bone in self.rootpos.
-- @param name bone name registered via SetBonesRef
function SkeletonNode:_SetRootJoint(name)
  self.rootpos = self:GetJoint(name).trans:GetWorldPosition();
end

-- Module export: `require` of this file yields the SkeletonNode class table.
return SkeletonNode;