local mathfunction = require "mathfunction"

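-- NOTE(review): the requires below are commented out, so Minimize() looks up
-- Substract, Dot, Multiply, AutoTensor, dogleg, Add, Power, Length and Flat
-- as globals; presumably some other script defines them before this module
-- is used.
--[[local Substract =  require "torchutility.substract"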
local Dot =  require "torchutility.multiply"
local Multiply =  require "torchutility.nummultiply"
local ToPoseList =  require "torchutility.toposelist"
local SignedSqrt =  require "torchutility.signedsqrt"
local AutoTensor =  require "torchutility.autotensor"
local SubTensor = require "torchutility.subtensor"
local dogleg = require "torchutility.dogleg"

local Add =  require "torchutility.add"
local PoseExp =  require "torchutility.poseexp"
local Length =  require "torchutility.Length"
local Power =  require "torchutility.power"
local FrameLength =  require "torchutility.vec3length"
local Flat=  require "torchutility.flat"]]--

-- Module table. Every laplacianRetarget:* method in this file is defined on
-- it, and it is returned at the end of the file. There is no constructor, so
-- when methods are called as laplacianRetarget:Method(...) all solver state
-- (poses, weights, frame history) is stored on this table itself.
local laplacianRetarget = {}

--- Store the source pose and the pose to be fitted.
-- Both are arrays of 3-component points indexed by joint (p[idx][1..3]).
-- @param oripose source pose; its joint count sizes the Laplacian and anchor
--        matrices, and its geometry drives the link weights and the
--        Laplacian target used by Minimize
-- @param fitpose pose pushed into Minimize's frame history on each call
function laplacianRetarget:SetupInitPose(oripose, fitpose)
  self.OriPose, self.FitPose = oripose, fitpose
end

--- Set the weight of the frame-to-frame ("frameDelta") term in Minimize.
-- @param weight optional scale; 10 is used when it is nil
function laplacianRetarget:SetupFrameLinkData(weight)
  if weight == nil then
    weight = 10
  end
  self.FrameLinkWeight = weight
end

--- Choose which joints are held near their starting position, and how strongly.
-- SetupAnchorMaskMat reads self.OriPose, so SetupInitPose must run first.
-- @param anchor array of joint indices to anchor
-- @param anchorweight optional scale of the anchor term; 1000 when nil
function laplacianRetarget:SetupStaticAnchor(anchor, anchorweight)
  self.StaticAnchor = anchor
  if anchorweight ~= nil then
    self.AnchorWeight = anchorweight
  else
    self.AnchorWeight = 1000
  end
  -- The mask is derived from StaticAnchor, so build it only after storing it.
  self.AnchorMaskMat = self:SetupAnchorMaskMat()
end

--- Install precomputed Laplacian links directly, bypassing CaculateWeightFromLink.
-- @param links array of {weight, rowIndex, colIndex} triples, the layout
--        SetupLaplacianMat consumes
function laplacianRetarget:SetupLaplacianLinks(links)
  self["LaplacianLinks"] = links
end

-- Normalised inverse distances from pose[centeridx] to each pose[neighbours[j]].
-- The entries sum to 1; the result is an empty table when there are no
-- neighbours. Errors if a neighbour coincides with the centre, since 1/0
-- would otherwise turn every weight for this centre into NaN.
local function inverseDistanceWeights(pose, centeridx, neighbours)
  local weights = {}
  if #neighbours == 0 then
    return weights
  end
  -- The centre point is shared by all neighbours; build it once.
  local cp = pose[centeridx]
  local center = mathfunction.vector3(cp[1], cp[2], cp[3])
  local total = 0
  for j = 1, #neighbours do
    local op = pose[neighbours[j]]
    local other = mathfunction.vector3(op[1], op[2], op[3])
    local length = (other - center):Length()
    assert(length > 0, string.format(
      "CaculateWeightFromLink: joints %s and %s coincide in OriPose",
      tostring(centeridx), tostring(neighbours[j])))
    weights[j] = 1.0 / length
    total = total + weights[j]
  end
  for j = 1, #weights do
    weights[j] = weights[j] / total
  end
  return weights
end

--- Build inverse-distance-weighted Laplacian links from a neighbour list.
-- For each centre joint, neighbour j gets
--   (1/|p_j - p_c|) / sum_k(1/|p_k - p_c|) * LaplacianWeight,
-- measured on self.OriPose. Each centre's weights therefore sum to
-- LaplacianWeight. self.LaplacianLinks is replaced with
-- {weight, centreIdx, neighbourIdx} triples, the layout SetupLaplacianMat reads.
-- @param links array of {centreIdx, {neighbourIdx, ...}} entries
-- @param laplacianweight optional overall scale; 6 when nil
function laplacianRetarget:CaculateWeightFromLink(links, laplacianweight)
  self.LaplacianWeight = 6
  if laplacianweight ~= nil then
    self.LaplacianWeight = laplacianweight
  end
  self.LaplacianLinks = {}

  for i = 1, #links do
    local centeridx, neighbours = links[i][1], links[i][2]
    local weights = inverseDistanceWeights(self.OriPose, centeridx, neighbours)
    for j = 1, #weights do
      table.insert(self.LaplacianLinks,
        { weights[j] * self.LaplacianWeight, centeridx, neighbours[j] })
    end
  end
end

--- Configure the bone-length term used by Minimize.
-- @param bonePair passed to Length(modelPose, bonePair) in Minimize
--        (presumably pairs of joint indices; confirm against Length)
-- @param boneLength target lengths, wrapped in AutoTensor by Minimize
-- @param lengthWeight optional scale of the term; 1000 when nil
function laplacianRetarget:SetupBoneData(bonePair, boneLength, lengthWeight)
  local weight = 1000
  if lengthWeight ~= nil then
    weight = lengthWeight
  end
  self.BonePair = bonePair
  self.BoneLength = boneLength
  self.LengthWeight = weight
end


--- Assemble the graph Laplacian L = D - A from self.LaplacianLinks.
-- A[row][col] holds each link's weight; a later link with the same
-- (row, col) overwrites an earlier one. A is NOT symmetrised (an A + A^T
-- variant was tried and left disabled). D is diagonal, with D[i][i] equal
-- to the sum of row i of A.
-- @return torch.Tensor of size #self.OriPose x #self.OriPose
function laplacianRetarget:SetupLaplacianMat()
  local n = #self.OriPose
  local adjacency = torch.Tensor(n, n):zero()

  local links = self.LaplacianLinks
  for k = 1, #links do
    local link = links[k]
    adjacency[link[2]][link[3]] = link[1]
  end

  local degree = torch.Tensor(n, n):zero()
  for r = 1, n do
    -- Fetch the row view once instead of re-selecting it per element.
    local row = adjacency[r]
    local rowsum = 0
    for c = 1, n do
      rowsum = rowsum + row[c]
    end
    degree[r][r] = rowsum
  end

  return degree - adjacency
end

--- Build a (#self.OriPose x 3) mask.
-- All three coordinates of each joint listed in self.StaticAnchor are set
-- to 1; every other entry is 0.
-- @return torch.Tensor
function laplacianRetarget:SetupAnchorMaskMat()
  local mask = torch.Tensor(#self.OriPose, 3):zero()
  local anchors = self.StaticAnchor
  for k = 1, #anchors do
    -- Writing through the row view updates mask in place.
    local row = mask[anchors[k]]
    for axis = 1, 3 do
      row[axis] = 1
    end
  end
  return mask
end

--- Supply the collision term for Minimize.
-- While CollisionJacobian is non-nil, Minimize adds a collision energy.
-- @param jacobian wrapped in AutoTensor and multiplied with the flattened
--        pose change in Minimize
-- @param delta offset added to that product; Minimize converts it with
--        torch.Tensor(...):view(-1)
-- @param weight optional scale of the term; 20 when nil
function laplacianRetarget:SetupCollisionMat(jacobian, delta, weight)
  if weight == nil then
    self.CollisionWeight = 20
  else
    self.CollisionWeight = weight
  end
  self.CollisionJacobian = jacobian
  self.CollisionDelta = delta
end

--- Run one dogleg solve over the registered energy terms and return the pose.
--
-- History kept across calls:
--   self.lastFramePose = {FitPose from two calls ago, previous FitPose, current FitPose}
--   self.laplacianPose = {previous OriPose, current OriPose}
-- The optimisation variable starts from lastFramePose[2] (one frame back),
-- and the Laplacian target comes from laplacianPose[1].
-- NOTE(review): on the first call laplacianPose[1] is FitPose rather than
-- OriPose, so the first frame's Laplacian target is taken from the fit
-- pose. Confirm this is intended.
--
-- Terms handed to dogleg (each a {fx, dx} pair of closures):
--   laplacian  : Power(L*x - L*ori, 2)
--   anchor     : Power(Multiply(x - x0, mask), 2) * AnchorWeight
--   bonelength : Power(Length(x, BonePair) - BoneLength, 2) * LengthWeight
--   frameDelta : Power((x - f3) - (f1 - x), 2) * FrameLinkWeight
--                (presumably a second-difference smoothness penalty)
--   collision  : Power(J * Flat(x - x0) + CollisionDelta, 2) * CollisionWeight
--                (only when SetupCollisionMat supplied a Jacobian)
-- Here x0 = lastFramePose[2], and f1, f3 = lastFramePose[1], lastFramePose[3].
-- Weights are read from self each time a term is evaluated.
--
-- Requires SetupInitPose, SetupStaticAnchor, SetupBoneData and
-- SetupFrameLinkData, plus CaculateWeightFromLink or SetupLaplacianLinks.
-- @return modelPose:R() after dogleg has run
function laplacianRetarget:Minimize()
  -- Slide the 3-frame window of fit poses.
  if self.lastFramePose == nil then
    self.lastFramePose = { self.FitPose, self.FitPose, self.FitPose }
  end
  self.lastFramePose[1] = self.lastFramePose[2]
  self.lastFramePose[2] = self.lastFramePose[3]
  self.lastFramePose[3] = self.FitPose

  -- Slide the 2-frame window of Laplacian source poses.
  if self.laplacianPose == nil then
    self.laplacianPose = { self.FitPose, self.FitPose }
  end
  self.laplacianPose[1] = self.laplacianPose[2]
  self.laplacianPose[2] = self.OriPose

  local funcs = {}

  local torchori = torch.Tensor(self.laplacianPose[1])

  local laplacianmat = AutoTensor()
  local torchlaplacianmat = self:SetupLaplacianMat()
  laplacianmat:VFromTorch(torchlaplacianmat)

  local targetLaplacianEnergy = AutoTensor()
  targetLaplacianEnergy:VFromTorch(torchlaplacianmat * torchori)

  local modelPose = AutoTensor(self.lastFramePose[2])
  local modelAnchorPose = AutoTensor(self.lastFramePose[2])
  local modelLaplacianEnergy = Dot(laplacianmat, modelPose)

  -- Register funcs[name] = {fx, dx}. `residual(x)` loads x into modelPose
  -- and returns the squared-residual node. weightField names a field of
  -- self that scales both fx and dx; nil leaves the term unweighted.
  local function addTerm(name, residual, weightField)
    if weightField == nil then
      funcs[name] = {
        fx = function(x) return residual(x):R() end,
        dx = function(x) return residual(x):D2(modelPose) end,
      }
    else
      funcs[name] = {
        fx = function(x) return residual(x):R() * self[weightField] end,
        dx = function(x) return residual(x):D2(modelPose) * self[weightField] end,
      }
    end
  end

  addTerm("laplacian", function(x)
    modelPose:V(x)
    return Power((modelLaplacianEnergy - targetLaplacianEnergy), 2)
  end)

  local anchormaskmat = AutoTensor()
  anchormaskmat:VFromTorch(self.AnchorMaskMat)
  addTerm("anchor", function(x)
    modelPose:V(x)
    return Power(Multiply((modelPose - modelAnchorPose), anchormaskmat), 2)
  end, "AnchorWeight")

  local targetLength = AutoTensor(self.BoneLength)
  local sumlength = Length(modelPose, self.BonePair)
  addTerm("bonelength", function(x)
    modelPose:V(x)
    return Power(sumlength - targetLength, 2)
  end, "LengthWeight")

  -- lastFramePose is always populated above, so this term is unconditional.
  local lastframe1 = AutoTensor(self.lastFramePose[1])
  local lastframe2 = AutoTensor(self.lastFramePose[3])
  local frameDelta1 = Substract(lastframe1, modelPose)
  local frameDelta2 = Substract(modelPose, lastframe2)
  addTerm("frameDelta", function(x)
    modelPose:V(x)
    return Power(Substract(frameDelta2, frameDelta1), 2)
  end, "FrameLinkWeight")

  if self.CollisionJacobian ~= nil then
    local tempPose = AutoTensor(self.lastFramePose[2])
    local deltaPose = Substract(modelPose, tempPose)
    local flatdelta = Flat(deltaPose)

    local collisionJ = AutoTensor(self.CollisionJacobian)
    local collisiondelta = Dot(collisionJ, flatdelta)

    local targetdelta = AutoTensor()
    targetdelta:VFromTorch(torch.Tensor(self.CollisionDelta):view(-1))

    addTerm("collision", function(x)
      modelPose:V(x)
      return Power(Add(collisiondelta, targetdelta), 2)
    end, "CollisionWeight")
  end

  local config = {
    maxiter = 100,
    e_3 = 0.01,
  }

  -- funcs always contains at least the laplacian, anchor, bonelength and
  -- frameDelta terms.
  dogleg(funcs, modelPose:R(), config)

  return modelPose:R()
end


-- Export the module table holding all the laplacianRetarget methods.
return laplacianRetarget;