local mathfunction = require "mathfunction"
local apolloengine = require "apolloengine"
local apolloDefine = require "apolloutility.defiend"
local collision = require "behavior.vtuber_behavior.solvertest.capsulecollision"

-- laplaciansolver: moves the arm joints 2, 3, 8 and 9 of a model pose so their
-- placement relative to nearby joints follows a target pose. It keeps the
-- bone lengths recorded from the model pose and penalises overlap between bind
-- spheres and registered capsules (see solve()).
local laplaciansolver = {}

--- Initialise the solver: neighbour tables, step sizes, the loss function and
-- the four bone-length equality constraints.
--
-- The optimisation vector x has 12 entries, three per joint (see getxfrompose):
--   x[1..3] = joint 2, x[4..6] = joint 3 (right hand),
--   x[7..9] = joint 8, x[10..12] = joint 9 (left hand).
-- Fields set here:
--   links        {joint, candidate neighbour joints}, consumed by setuplaplacian
--   maxlappoint  how many nearest candidates setuplaplacian keeps per joint
--   discons      joint pairs whose squared rest lengths setuplinks stores in resetdistance
--   alpha, beta  step sizes of the x update and the lambda update in solve
--   lf, dlf      loss function over x and its gradient
--   constraints, dconstraints  C_k(x) = 0 constraint functions and their gradients
-- NOTE(review): solve() reads self.iterationcount, which is not set here; the
-- caller has to provide it.
function laplaciansolver:init()

  self.links = {  {2,{1,7,23,24,25,26,27,28,29,30}},
                  {3,{1,7,23,24,25,26,27,28,29,30}},
                  {8,{7,1,23,24,25,26,27,28,29,30}},
                  {9,{7,1,23,24,25,26,27,28,29,30}},
                }

  self.maxlappoint = 4

  self.discons = {{1,2},{2,3},{7,8},{8,9}}

  self.constraints = {}
  self.dconstraints = {}
  self.lambda = {}
  self.alpha = 0.1
  self.beta = 0.1
  self.x = {}

  -- Loss: squared distance of x from the reference vector self.t, plus the
  -- squared error between the current left-minus-right hand vector and the
  -- target one (self.handvec, set by setuplaplacian).
  self.lf = function (x)
    local ret = 0
    for i = 1, #x do
      local t = x[i] - self.t[i]
      ret = ret + t*t
    end
    -- Both hands used to be read from x[1..3], which made this term constant.
    -- Use the hand slots so the loss agrees with its gradient in self.dlf.
    local righthand = mathfunction.vector3(x[4], x[5], x[6])
    local lefthand = mathfunction.vector3(x[10], x[11], x[12])
    local handdiff = (lefthand - righthand) - self.handvec
    ret = ret + handdiff:x()*handdiff:x() + handdiff:y()*handdiff:y() + handdiff:z()*handdiff:z()
    return ret
  end

  -- Gradient of self.lf.
  self.dlf = function (x)
    local ret = {}
    for i = 1, #x do
      local dx = 2*x[i] - 2*self.t[i]
      table.insert(ret, dx)
    end
    -- Hand-vector term, diff = (left - right) - handvec:
    -- d/d(right) = -2*diff, d/d(left) = +2*diff.
    ret[4] = ret[4] + 2*x[4] - 2*x[10] + 2*self.handvec:x()
    ret[5] = ret[5] + 2*x[5] - 2*x[11] + 2*self.handvec:y()
    ret[6] = ret[6] + 2*x[6] - 2*x[12] + 2*self.handvec:z()
    ret[10] = ret[10] + 2*x[10] - 2*x[4] - 2*self.handvec:x()
    ret[11] = ret[11] + 2*x[11] - 2*x[5] - 2*self.handvec:y()
    ret[12] = ret[12] + 2*x[12] - 2*x[6] - 2*self.handvec:z()
    return ret
  end

  -- A 12-entry gradient filled with zeros.
  local function zerogradient()
    local ret = {}
    for i = 1, 12 do
      ret[i] = 0
    end
    return ret
  end

  -- Bone-length constraints. Each compares a squared length with a squared rest
  -- length from self.resetdistance: C(x) = |a - b|^2 - rest.

  -- Right upper arm: right shoulder (self.rshoulder) to x[1..3]; rest = resetdistance[1].
  local rarmlength = function(x)
    local ret = 0
    for i = 1, 3 do
      local t = x[i] - self.rshoulder[i]
      ret = ret + t*t
    end
    return ret - self.resetdistance[1]
  end

  local drarmlength = function(x)
    local ret = zerogradient()
    for i = 1, 3 do
      ret[i] = 2*x[i] - 2*self.rshoulder[i]
    end
    return ret
  end

  -- Left upper arm: left shoulder (self.lshoulder) to x[7..9]; rest = resetdistance[3].
  local larmlength = function(x)
    local ret = 0
    for i = 7, 9 do
      local t = x[i] - self.lshoulder[i-6]
      ret = ret + t*t
    end
    return ret - self.resetdistance[3]
  end

  local dlarmlength = function(x)
    local ret = zerogradient()
    for i = 7, 9 do
      ret[i] = 2*x[i] - 2*self.lshoulder[i-6]
    end
    return ret
  end

  -- Right forearm: x[1..3] to x[4..6]; rest = resetdistance[2].
  local rfarmlength = function(x)
    local ret = 0
    for i = 1, 3 do
      local t = x[i] - x[i+3]
      ret = ret + t*t
    end
    return ret - self.resetdistance[2]
  end

  local drfarmlength = function(x)
    local ret = zerogradient()
    for i = 1, 3 do
      local t = 2*x[i] - 2*x[i+3]
      ret[i] = t
      ret[i+3] = -t
    end
    return ret
  end

  -- Left forearm: x[7..9] to x[10..12]; rest = resetdistance[4].
  local lfarmlength = function(x)
    local ret = 0
    for i = 7, 9 do
      local t = x[i] - x[i+3]
      ret = ret + t*t
    end
    return ret - self.resetdistance[4]
  end

  local dlfarmlength = function(x)
    local ret = zerogradient()
    for i = 7, 9 do
      local t = 2*x[i] - 2*x[i+3]
      ret[i] = t
      ret[i+3] = -t
    end
    return ret
  end

  -- Constraint order matters: solve() treats entries 1..4 as bone lengths and
  -- setupcollisioncontraints() keeps exactly these four.
  table.insert(self.constraints, rarmlength)
  table.insert(self.constraints, larmlength)
  table.insert(self.constraints, rfarmlength)
  table.insert(self.constraints, lfarmlength)

  table.insert(self.dconstraints, drarmlength)
  table.insert(self.dconstraints, dlarmlength)
  table.insert(self.dconstraints, drfarmlength)
  table.insert(self.dconstraints, dlfarmlength)

  self.collidec = {}
  self.collided = {}
end

--- Forget every registered capsule; starts a fresh, empty capsule list.
function laplaciansolver:resetcapsule()
  local empty = {}
  self.capsules = empty
end

--- Append a capsule obstacle to self.capsules.
-- Stored as {pointa, pointb, radius}; setupcollisioncontraints passes the
-- three values straight to the collision test.
-- @param pointa  presumably one end point of the capsule axis
-- @param pointb  presumably the other end point of the capsule axis
-- @param radius  capsule radius
function laplaciansolver:addcapsule(pointa,pointb,radius)
  local list = self.capsules
  list[#list + 1] = {pointa, pointb, radius}
end

--- Forget every bind sphere; starts a fresh, empty list.
function laplaciansolver:clearbindsphere()
  local empty = {}
  self.bindsphere = empty
end

--- Append a collision sphere bound to a skeleton joint.
-- Stored as {index, parentidx, radius, offset, worldrot}. Slot 6 is filled
-- later by setupnodecollision with worldrot taken relative to the
-- parent-to-joint direction.
-- @param index     joint index the sphere follows
-- @param parentidx parent joint index used to orient the offset
-- @param radius    sphere radius
-- @param offset    offset from the joint, rotated by the joint orientation
-- @param worldrot  presumably the sphere's rotation at bind time (confirm with caller)
function laplaciansolver:addbindsphere(index,parentidx,radius,offset,worldrot)
  local spheres = self.bindsphere
  spheres[#spheres + 1] = {index, parentidx, radius, offset, worldrot}
end

--- Cache, for every bind sphere, its rotation relative to the parent bone.
-- The direction from the parent joint to the sphere's joint (read from
-- self.modelpose) is turned into a rotation from +Y. The inverse of that
-- rotation is combined with the sphere's stored rotation (slot 5). The result
-- goes into slot 6, which the collision constraints read back.
function laplaciansolver:setupnodecollision()
  local pose = self.modelpose
  for i = 1, #self.bindsphere do
    local sphere = self.bindsphere[i]
    local joint = pose[sphere[1]]
    local parent = pose[sphere[2]]
    -- The z component previously read joint[2] (the y coordinate) by mistake.
    local parentdir = mathfunction.vector3(joint[1] - parent[1],
                                           joint[2] - parent[2],
                                           joint[3] - parent[3])
    local parentrot = mathfunction.Quaternion()
    parentrot:AxisToAxis(mathfunction.vector3(0, 1, 0), parentdir)
    parentrot:InverseSelf()
    sphere[6] = sphere[5] * parentrot
  end
end


-- Collision constraints are expressed as equality constraints to keep the
-- implementation simple. Instead of requiring C(x) > 0, the solver uses
-- G(x) = 0, where G(x) is the offset reported by the collision test while the
-- sphere and capsule overlap, and 0 otherwise.

-- Maps a joint index used by a bind sphere to {slot in x, parent slot}.
-- Slot k occupies x[3k-2 .. 3k] (layout from getxfrompose). Parent slot 0
-- means the parent joint (a shoulder) is not in x, so no bind offset is applied.
-- Joints not listed here are used as-is, with the sphere's own parent index.
local collisionslots = {
  [2] = {1, 0}, -- right elbow
  [3] = {2, 1}, -- right hand, parent = right elbow
  [8] = {3, 0}, -- left elbow
  [9] = {4, 3}, -- left hand, parent = left elbow
}

--- Rebuild the collision constraints.
-- Adds one constraint, with its gradient, for every (capsule, bind sphere)
-- pair. They are appended after the four bone-length constraints from init().
-- Reads self.capsules and self.bindsphere. For spheres with a parent slot,
-- slot 6 of the sphere must already be filled by setupnodecollision.
function laplaciansolver:setupcollisioncontraints()
  -- Entries 1..4 are the bone-length constraints; discard collision
  -- constraints left over from a previous solve.
  for i = 5, #self.constraints do
    self.constraints[i] = nil
    self.dconstraints[i] = nil
  end

  for i = 1, #self.capsules do
    local capsule = self.capsules[i]
    for j = 1, #self.bindsphere do
      local sphere = self.bindsphere[j]

      -- Sphere centre for the vector x, plus the slot of x it is attached to.
      -- (The old code assigned a misspelled global 'praentidx', so the parent
      -- index was never reset and the parent direction degenerated to zero.)
      local function spherecenter(x)
        local slot, parentslot = sphere[1], sphere[2]
        local mapped = collisionslots[slot]
        if mapped then
          slot, parentslot = mapped[1], mapped[2]
        end
        local pos = mathfunction.vector3(x[slot*3-2], x[slot*3-1], x[slot*3])
        if parentslot ~= 0 then
          -- Orient the bind offset by the current parent-to-joint direction.
          local parentdir = mathfunction.vector3(x[slot*3-2] - x[parentslot*3-2],
                                                 x[slot*3-1] - x[parentslot*3-1],
                                                 x[slot*3]   - x[parentslot*3])
          local parentrot = mathfunction.Quaternion()
          parentrot:AxisToAxis(mathfunction.vector3(0, 1, 0), parentdir)
          local worldrot = sphere[6] * parentrot
          pos = pos + sphere[4] * worldrot
        end
        return pos, slot
      end

      local collide = function(x)
        local pos = spherecenter(x)
        local collided, offset = collision:collide(pos, sphere[3], capsule[1], capsule[2], capsule[3])
        if collided == true then
          -- NOTE(review): solve() passes this value to math.abs, so it must be a
          -- number. An earlier variant used offset:Length()^2; confirm what
          -- collision:collide returns.
          return offset
        end
        return 0
      end

      local dcollide = function(x)
        local ret = {}
        for idx = 1, 12 do
          ret[idx] = 0
        end
        local pos, slot = spherecenter(x)
        local collided, _, gradiant = collision:collidewithd(pos, sphere[3], capsule[1], capsule[2], capsule[3])
        if collided == true then
          -- Only the sphere's own joint receives a gradient; the parent is treated as fixed.
          ret[slot*3-2] = gradiant[1]
          ret[slot*3-1] = gradiant[2]
          ret[slot*3]   = gradiant[3]
        end
        return ret
      end

      table.insert(self.constraints, collide)
      table.insert(self.dconstraints, dcollide)
    end
  end
end



--- Record the squared rest length of every joint pair in self.discons.
-- Positions are read from self.modelpose. Results are stored in
-- self.resetdistance in the same order as self.discons; the bone-length
-- constraints built in init() read them by position.
function laplaciansolver:setuplinks()
  local pose = self.modelpose
  local lengths = {}
  self.resetdistance = lengths
  for k, pair in ipairs(self.discons) do
    local a, b = pose[pair[1]], pose[pair[2]]
    local va = mathfunction.vector3(a[1], a[2], a[3])
    local vb = mathfunction.vector3(b[1], b[2], b[3])
    local len = (va - vb):Length()
    lengths[k] = len * len
  end
end

--- Place each optimised joint of the model pose by Laplacian reconstruction
-- from the target pose, and store the target hand-to-hand vector.
-- For every {center, candidates} entry in self.links:
--   * each candidate is weighted by 1/distance to the center in self.targetpose
--     (a coincident point gets weight 9999999), and the self.maxlappoint
--     heaviest, i.e. nearest, are kept;
--   * targetvec = sum_k w_k*(center - neighbour_k) / total   (target pose)
--   * modelpos  = sum_k w_k*neighbour_k / total              (model pose)
--   * self.modelpose[center] is overwritten with modelpos + targetvec.
-- Finally self.handvec = targetpose[9] - targetpose[3].
function laplaciansolver:setuplaplacian()
  local target, model = self.targetpose, self.modelpose

  local function posof(pose, idx)
    local p = pose[idx]
    return mathfunction.vector3(p[1], p[2], p[3])
  end

  for i = 1, #self.links do
    local centeridx = self.links[i][1]
    local candidates = self.links[i][2]
    local center = posof(target, centeridx)

    local weighted = {}
    for j = 1, #candidates do
      local otheridx = candidates[j]
      local dis = (posof(target, otheridx) - center):Length()
      if dis == 0 then
        dis = 9999999
      else
        dis = 1/dis
      end
      table.insert(weighted, {dis, otheridx})
    end
    table.sort(weighted, function(v1, v2) return v1[1] > v2[1] end)

    -- Never read past the available candidates; this used to index nil entries
    -- when a link had fewer than maxlappoint candidates.
    local count = math.min(self.maxlappoint, #weighted)
    if count > 0 then
      local total = 0
      for k = 1, count do
        total = total + weighted[k][1]
      end

      local targetlaplacianvec = mathfunction.vector3(0, 0, 0)
      local modellaplacianpos = mathfunction.vector3(0, 0, 0)
      for k = 1, count do
        local otheridx = weighted[k][2]
        targetlaplacianvec = targetlaplacianvec +
          (center - posof(target, otheridx)) * (weighted[k][1]) / total
        modellaplacianpos = modellaplacianpos +
          posof(model, otheridx) * (weighted[k][1]) / total
      end

      local dest = model[centeridx]
      dest[1] = modellaplacianpos:x() + targetlaplacianvec:x()
      dest[2] = modellaplacianpos:y() + targetlaplacianvec:y()
      dest[3] = modellaplacianpos:z() + targetlaplacianvec:z()
    end
  end

  -- Left hand (joint 9) minus right hand (joint 3) in the target pose.
  self.handvec = posof(target, 9) - posof(target, 3)
end

--- Reset the Lagrange multipliers: exactly one zero per current constraint.
-- The table is rebuilt rather than overwritten in place. Otherwise multipliers
-- left from an earlier solve with more collision constraints would survive,
-- and solve() would index constraint values that no longer exist.
function laplaciansolver:setuplambda()
  local lambda = {}
  for i = 1, #self.constraints do
    lambda[i] = 0
  end
  self.lambda = lambda
end

--- Solve for the arm joint positions and write them back into the model pose.
-- Steps:
--   1. record rest bone lengths;
--   2. cache bind-sphere rotations;
--   3. place joints 2, 3, 8, 9 by Laplacian reconstruction from the target pose;
--   4. pack them into x and build the collision constraints;
--   5. run self.iterationcount rounds of
--        x      <- x - alpha * (dLoss(x) + sum_j lambda_j * dC_j(x))
--        lambda <- lambda + beta * C(x)
--      with C and dC evaluated at the start of each round.
-- @param targetpos target pose, indexed targetpos[joint][1..3]
-- @param modelpos  model pose, indexed modelpos[joint][1..3]; modified in place
-- @return modelpos, with joints 2, 3, 8 and 9 taken from the optimised x
-- NOTE(review): self.iterationcount is not set anywhere in this file; the caller must set it.
function laplaciansolver:solve(targetpos,modelpos)
  self.targetpose = targetpos
  self.modelpose = modelpos

  self:setuplinks()
  self:setupnodecollision()
  self:setuplaplacian()
  local currentx = self:getxfrompose()

  self:setupcollisioncontraints()
  self:setuplambda()

  local nconstraints = #self.constraints

  for step = 1, self.iterationcount do

    local dlf = self.dlf(currentx)

    -- gradients of the constraints at the current x
    local dc = {}
    for j = 1, #self.dconstraints do
      dc[j] = self.dconstraints[j](currentx)
    end
    -- constraint values at the current x
    local c = {}
    for j = 1, nconstraints do
      c[j] = self.constraints[j](currentx)
    end

    -- x = x - alpha * (gradient of loss + sum_j lambda_j * gradient of C_j)
    for i = 1, #currentx do
      local dcitotal = 0
      for j = 1, #self.dconstraints do
        dcitotal = dcitotal + self.lambda[j]*dc[j][i]
      end
      currentx[i] = currentx[i] - self.alpha*(dlf[i] + dcitotal)
    end

    -- lambda = lambda + beta * C. This is bounded by the current constraint
    -- count, not #self.lambda, so stale multipliers can never index a
    -- missing c[i].
    local covergeconstraint = 0
    for i = 1, nconstraints do
      self.lambda[i] = self.lambda[i] + self.beta*(c[i])
      -- entries after the first four are collision constraints
      if i > 4 then
        covergeconstraint = covergeconstraint + math.abs(c[i])
      end
    end
    LOG(" CONVERGEC IN STEP"..step.." ,is "..covergeconstraint);
    -- No early exit: the previous convergence test was disabled, so all
    -- iterationcount steps always run.
  end
  return self:getposefromx(currentx)
end

--- Pack the optimised joints of self.modelpose into a flat vector and snapshot fixed inputs.
-- Layout, three values per joint: joint 2 -> x[1..3], joint 3 -> x[4..6],
-- joint 8 -> x[7..9], joint 9 -> x[10..12].
-- self.t receives an independent copy of the same 12 values; the loss measures
-- the distance from it. Joint 7 is copied into self.lshoulder and joint 1 into
-- self.rshoulder.
-- @return the new 12-entry vector x
function laplaciansolver:getxfrompose()
  local pose = self.modelpose
  local joints = {2, 3, 8, 9}
  local x, reference = {}, {}
  for slot, joint in ipairs(joints) do
    local base = (slot - 1) * 3
    local src = pose[joint]
    for axis = 1, 3 do
      x[base + axis] = src[axis]
      reference[base + axis] = src[axis]
    end
  end
  self.t = reference
  local left, right = pose[7], pose[1]
  self.lshoulder = {left[1], left[2], left[3]}
  self.rshoulder = {right[1], right[2], right[3]}
  return x
end

--- Unpack the 12-entry vector x back into self.modelpose, in place.
-- Inverse of the getxfrompose layout: x[1..3] -> joint 2, x[4..6] -> joint 3,
-- x[7..9] -> joint 8, x[10..12] -> joint 9.
-- @param x optimisation vector
-- @return self.modelpose (the same table, now updated)
function laplaciansolver:getposefromx(x)
  local pose = self.modelpose
  local joints = {2, 3, 8, 9}
  for slot = 1, #joints do
    local dest = pose[joints[slot]]
    local base = 3 * (slot - 1)
    for axis = 1, 3 do
      dest[axis] = x[base + axis]
    end
  end
  return pose
end

--- Placeholder hook: currently does nothing and returns nothing.
function laplaciansolver:applypredelta() end

--- Recursively copy a value.
-- Tables are rebuilt entry by entry, with both keys and values copied
-- recursively. Any other value is returned unchanged. Metatables are not copied,
-- and there is no visited set, so a cyclic table recurses until the stack overflows.
-- @param orig value to copy
-- @return the copy
function laplaciansolver:tbl_copy(orig)
  if type(orig) ~= "table" then
    return orig
  end
  local result = {}
  local key, value = next(orig)
  while key ~= nil do
    result[self:tbl_copy(key)] = self:tbl_copy(value)
    key, value = next(orig, key)
  end
  return result
end

-- Export the module table; its functions are methods, called with ':'.
return laplaciansolver;