local mathfunction = require "mathfunction"

local gmm = require "torchutility.gaussianmixturemodel"
local ViewCamera = require "torchutility.cameraviewproj"
local Robustifier = require "torchutility.robustifier"
local Substract =  require "torchutility.substract"
local Multiply =  require "torchutility.multiply"
local ToPoseList =  require "torchutility.toposelist"
local SignedSqrt =  require "torchutility.signedsqrt"
local AutoTensor =  require "torchutility.autotensor"
local Rodrigues = require "torchutility.rodrigues"
local SubTensor = require "torchutility.subtensor"
local TransformMat = require "torchutility.transformmat"
local dogleg = require "torchutility.dogleg"

local MixGauss = require "torchutility.mixgauss"
local Add =  require "torchutility.add"
local PoseExp =  require "torchutility.poseexp"
local Length =  require "torchutility.Length"
local Power =  require "torchutility.power"
local zVariations =  require "torchutility.zvariations"
local venuscore = require "venuscore"

local oneeuro = require "math.lowpassfilter.oneeurofilter_tensor"


require "venusdebug"
require "utility"

-- Index table relating 3D skeleton joints to 2D points (original note, translated:
-- "3D bone -> 2D point position index, key = 2D point").
-- It currently holds the identity over 15 entries and is handed to SubTensor in
-- Correct. NOTE(review): presumably entry k names the 3D joint row matched to
-- 2D point k -- confirm against torchutility.subtensor.
local subproj = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};  --identity indexing

-- Module table: the methods defined below are attached to it and it is the
-- value returned at the end of this file.
local skeletonfit = {}

--[===[
function skeletonfit:Initialize(bones, bonepair, viewmat)
  self:SetupBoneLength(bones, bonepair)
  self:SetCamera(viewmat)
end

function skeletonfit:SetupBoneLength(bones, bonepair)
  self.bonepair = bonepair;
  self.reallen = {};
  for _, boneindex in ipairs(bonepair) do
    local bone1 = bones[boneindex[1]];
    local bone2 = bones[boneindex[2]];
    table.insert(self.reallen, (bone1 - bone2):Length());
  end
end
]===]
--- Load skeleton and camera data into the fitter in one call.
-- @param bonelength sequence of bone lengths, forwarded to SetupBoneLength
-- @param bonepair   stored as-is in self.bonepair (read later by Correct)
-- @param viewmat    camera matrix, forwarded to SetCamera
function skeletonfit:Initialize(bonelength, bonepair, viewmat)
  -- Bone lengths first, then intrinsics, then the pair table.
  self:SetupBoneLength(bonelength)
  self:SetCamera(viewmat)

  self.bonepair = bonepair
end
--- Store a private copy of the reference bone lengths in self.reallen.
-- The input sequence is copied element by element, so later changes to the
-- caller's table do not affect the fitter.
-- @param bonelength sequence (1..#bonelength) of bone lengths
function skeletonfit:SetupBoneLength(bonelength)
  local lengths = {}
  self.reallen = lengths

  local count = #bonelength
  for index = 1, count do
    lengths[#lengths + 1] = bonelength[index]
  end
end

--- Read the four camera parameters used by Correct out of a matrix.
-- fx and fy come from entries [1][1] and [2][2]; cx and cy come from entries
-- [3][1] and [3][2]. NOTE(review): presumably focal lengths and principal
-- point of a matrix stored with the offsets in its third row -- confirm the
-- layout with the code that builds viewmat.
-- @param viewmat matrix indexable as viewmat[row][col] with at least 3 rows
function skeletonfit:SetCamera(viewmat)
    local row1 = viewmat[1]
    self.fx = row1[1]

    local row2 = viewmat[2]
    self.fy = row2[2]

    local row3 = viewmat[3]
    self.cx = row3[1]
    self.cy = row3[2]
end

--- Refine a 3D skeleton estimate against 2D observations with the dogleg solver.
--
-- Four residual groups are registered with the solver, each as a pair of
-- callbacks (fx = residual value, dx = its derivative w.r.t. the joint
-- positions):
--   * proj        - robustified difference between vec2list and the camera
--                   projection of the joints shifted by cameraoffset;
--   * bonelength  - squared difference between self.reallen and the current
--                   lengths of the joint pairs in self.bonepair (weight 100);
--   * rootpos     - squared norm of joint row 15 (weight 10000);
--   * zvariations - squared output of zVariations(previous result, position)
--                   (weight 100).
-- The optimised positions are remembered in self.lastpos, run through a
-- one-euro filter (created on the first call, applied from the second call
-- onwards) and returned.
--
-- Reads self.fx/fy/cx/cy, self.reallen and self.bonepair, so Initialize (or
-- the individual setters) must have been called beforehand.
--
-- @param vec3list     initial 3D joint positions; must be accepted by both
--                     AutoTensor and torch.Tensor (presumably N rows of
--                     {x, y, z} -- confirm with caller)
-- @param vec2list     2D target points; one camera offset row is created per
--                     entry
-- @param cameraoffset indexable with [1], [2], [3]; added to every joint
--                     before projection
-- @return sequence of mathfunction.vector3, one per row of the result
function skeletonfit:Correct(vec3list, vec2list, cameraoffset)
  -- Replicate the camera offset once per 2D observation.
  local cameraoffsetlist = {}
  for i=1, #vec2list do
    table.insert(cameraoffsetlist, {cameraoffset[1], cameraoffset[2], cameraoffset[3]});
  end

  -- Variable being optimised: the callbacks below write a candidate with
  -- position:V(x) and the result is read back with position:R().
  local position = AutoTensor(vec3list);

  -- Original author's note (translated): "this probably won't be needed later?"
  local pointtemp = SubTensor(subproj,position);

  local point3doffset = AutoTensor(cameraoffsetlist);
  local point3d = pointtemp+point3doffset;
  local projectto = AutoTensor(vec2list);

  -- Reprojection term: observed 2D - projected 3D, robustified, signed sqrt.
  local Camera = ViewCamera(self.fx,self.fy,self.cx,self.cy,point3d);
  local Sub = Substract(projectto,Camera);
  local Robu = Robustifier(100,Sub);
  -- Named differently from the SignedSqrt module so the constructor is not
  -- shadowed for the rest of the function.
  local projresidual = SignedSqrt(Robu);

  -- Bone length term: (reference length - current length)^2 per bone pair.
  local reallen = AutoTensor(self.reallen);
  local lenope = Length(position, self.bonepair);
  local lSub = Substract(reallen, lenope);
  local leastsquares = Power(lSub, 2);

  -- Temporal term against the previous call's result; on the very first call
  -- the reference is an all-zero tensor of the same size.
  self.lastpos = self.lastpos or torch.zeros(position:R():size());
  local lastpos = AutoTensor(self.lastpos);
  local zvari = zVariations(lastpos, position);
  local lsz = Power(zvari, 2);

  local funcs = {}
  funcs.proj = {}
  funcs.proj.fx =  function(x)
                    position:V(x);
                    local ret = projresidual:R();
                    return ret;
                  end

  funcs.proj.dx =  function(x)
                    position:V(x);
                    local dx = projresidual:D2(position);
                    return dx;
                  end


  local bonelengthweight = 100;
  funcs.bonelength = {}
  funcs.bonelength.fx = function(x)
    position:V(x);
    local ret = bonelengthweight * leastsquares:R();
    return ret;
  end

  funcs.bonelength.dx = function(x)
    position:V(x);
    local final = bonelengthweight * leastsquares:D2(position);
    return final;
  end

  -- Root term: residual = rootweight * |p_root|^2 for a single joint row.
  -- Its derivative w.r.t. the positions is 2 * rootweight * p_root on that
  -- row and zero everywhere else, so fx and dx must share the same row index
  -- and dx must carry the factor 2 (a Jacobian on another row, or without the
  -- factor, does not describe the residual the solver is evaluating).
  local rootweight = 10000;
  local rootindex = 15;
  local rootgradscale = torch.Tensor(vec3list):zero()
  rootgradscale[rootindex][1] = 2 * rootweight;
  rootgradscale[rootindex][2] = 2 * rootweight;
  rootgradscale[rootindex][3] = 2 * rootweight;
  funcs.rootpos = {}
  funcs.rootpos.fx = function(x)
    position:V(x);
    local pos = position:R();
    local root = pos[{rootindex, {}}];
    local ret = torch.Tensor({rootweight * torch.dot(root, root)});
    return ret;
  end

  funcs.rootpos.dx = function(x)
    position:V(x);
    -- Element-wise mask/scale, then flatten to a single Jacobian row.
    local final = torch.cmul(rootgradscale, position:R());
    return final:view(1, -1);
  end

  local zvariationsweight = 100;
  funcs.zvariations = {}
  funcs.zvariations.fx = function(x)
    position:V(x);
    local ret = zvariationsweight * lsz:R();
    return ret;
  end

  funcs.zvariations.dx = function(x)
    position:V(x);
    local final = zvariationsweight * lsz:D2(position);
    return final;
  end

  local config =  {
        maxiter = 100,
        e_3 = 0.01,
      }
  -- The solver's return value is not used; the result is read from `position`
  -- afterwards. NOTE(review): this relies on the last position:V(x) issued by
  -- the callbacks being the accepted iterate -- confirm against dogleg.
  dogleg(funcs, position:R(), config);

  --LOG("show estimate result");
  --LOG(Sub);

  LOG("bone length diff")
  LOG(leastsquares:R());

  LOG("z variations")
  LOG(lsz:R());

  -- Remember the unfiltered result as the temporal reference for the next call.
  local res_pos = position:R();
  self.lastpos = res_pos:clone();
  --LOG("optimize complete,let's see the result");
  --LOG(res_pos);

  -- Smooth over time: the first call only seeds the filter and returns the
  -- raw result; later calls return the filtered positions.
  local ct = venuscore.ITimerSystem:GetTimevalue();
  if not self.oneeuro then
    --self.oneeuro = oneeuro(ct, res_pos, nil, 0.8, 0.4);
    self.oneeuro = oneeuro(ct, res_pos);
  else
    res_pos = self.oneeuro:filter(ct, res_pos);
  end

  -- Convert each result row into a mathfunction.vector3.
  local res_vec = {}
  for i=1, res_pos:size()[1] do
    local t = res_pos[{i, {}}];
    local vec = mathfunction.vector3(t[1], t[2], t[3]);
    table.insert(res_vec, vec);
  end

  return res_vec;
end




return skeletonfit;