local Venuscore = require "venuscore"
local apolloengine = require "apolloengine"
local optimization = require "optimization"
local mathfunction = require "mathfunction"
local util = require "behavior.avatar_behavior.util"
local oneeurofilter = require "math.lowpassfilter.oneeurofilter"


-- ModelPose behavior: reads 2D/3D body-pose results from a recognition
-- component, scores them, and converts them into per-joint transforms.
local ModelPose = Venuscore.VenusBehavior:extend("ModelPose"); 

--- Constructor: set tuning constants, create the optimizer objects and the
-- blueprint-facing pos_*/rot_* caches.
function ModelPose:new()
    -- Keypoints scoring below cutoffscore get zero optimizer weight and are
    -- marked inactive; scoreMin is the threshold the (weighted) total score
    -- must exceed for a pose to count as usable.
    self.cutoffscore = 0.1;
    self.scoreMin = 0.3;
    -- Default optimizer weights: 14 keypoints, 11 bones.
    -- bonepair holds 0-based keypoint indices (converted with +1 when used).
    self.reprojweight = {1,1,1,1,1,1,1,1,1,1,1,1,1,1};
    self.bonelengthweight = {0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05};
    self.bonepair = {{0,1},{1,2},{3,4},{4,5},{6,7},{7,8},{9,10},{10,11},{12,13},{0,9},{3,6}};
    self.mcweight = {0.06,0.0,0.0,0.06,0.0,0.0,0.06,0.0,0.0,0.06,0.0,0.0,0.0,0.0};
    self.mzweight = {1.5,0.03,0.03,1.5,0.03,0.03,1.5,0.03,0.03,1.5,0.03,0.03,0.03,0.03};
    -- state machine against flickering (consumed by IsLessScore)
    self.cntFullbody = 0;
    self.cntTolerance = 3;
    self.lastFrameDetected = false;
    -- full body judgement: keypoints marked 1 must all score above cutoffscore
    -- when NeedFullBody is set
    self.NeedFullBody = false;
    self.maskFullbody = {1,0,0, 1,0,0, 1,1,0, 1,1,0, 1,1};
    -- score weighting for multiple model levels
    self.modellevel = nil; -- 0-default/1-small/2-medium/3-large
    self.scoreWeightLarge  = {10,2,0, 10,2,0, 5,0,0, 5,0,0, 10,10};
    self.scoreWeightMedium = {10,2,0, 10,2,0, 5,0,0, 5,0,0, 10,10};
    self.scoreWeightSmallDefault = {10,2,0, 10,2,0, 5,0,0, 5,0,0, 10,10};
    self.scoreTweakL = 0.6; -- fine-tune weights for the specified model level
    self.scoreTweakM = 0.55;
    self.scoreTweakSD = 0.5;
    self:NormalizeWeights(); -- normalize so that sum(scoreweights) = 14 * scoreTweak

    self.RootEstimate = optimization.RootEstimate();
    self.ReprojCorrect = optimization.ReprojCorrect();
    self.ReprojCorrect:SetBoneLengthWeight(self.bonelengthweight);

    -- Recognition component, presumably assigned by the owner after
    -- construction (TODO confirm). This was previously misspelled
    -- `recogition`, so the field every other method reads was never
    -- initialised here.
    self.recognition = nil;
    self.camera = nil;

    -- 1-based parent keypoint index per keypoint; 0 = no parent bone.
    self.parents = {0,1,2,0,4,5,0,7,8,0,10,11,0,13};
    self.eurofilters = {};
    self:InitRigisterForBluePrint();
end

--- Allocate the blueprint-facing caches: a vector3 `pos_<joint>` for each
-- tracked joint plus the pelvis, and a Quaternion `rot_<joint>` for the same
-- joints plus spine and neck.
function ModelPose:InitRigisterForBluePrint()
  local jointNames = {
    "R_Shoulder", "R_Elbow", "R_Wrist",
    "L_Shoulder", "L_Elbow", "L_Wrist",
    "R_Hip",      "R_Knee",  "R_Ankle",
    "L_Hip",      "L_Knee",  "L_Ankle",
    "Pelvis",
  };

  -- positions
  for _, joint in ipairs(jointNames) do
    self["pos_" .. joint] = mathfunction.vector3();
  end

  -- rotations: same joints, then the two rotation-only entries
  jointNames[#jointNames + 1] = "Spine";
  jointNames[#jointNames + 1] = "Neck";
  for _, joint in ipairs(jointNames) do
    self["rot_" .. joint] = mathfunction.Quaternion();
  end
end

--- Copy the transforms from the last CaculateTransforms call into the
-- blueprint-facing pos_*/rot_* fields.
-- Entries 1..12 map to the individual joints. Entry 13's position goes to
-- pos_Pelvis and its rotation to rot_Spine. rot_Pelvis and rot_Neck are not
-- written here; GetPelvisTransform sets them.
-- Does nothing when no transforms have been computed yet. Previously this
-- case indexed a nil table.
function ModelPose:SetCachedPoseForBluePrint()
  local tf = self.transform;
  if tf == nil then
    return;
  end

  self.pos_R_Shoulder:Set(tf[1].position);
  self.pos_R_Elbow:Set(tf[2].position);
  self.pos_R_Wrist:Set(tf[3].position);
  self.pos_L_Shoulder:Set(tf[4].position);
  self.pos_L_Elbow:Set(tf[5].position);
  self.pos_L_Wrist:Set(tf[6].position);
  self.pos_R_Hip:Set(tf[7].position);
  self.pos_R_Knee:Set(tf[8].position);
  self.pos_R_Ankle:Set(tf[9].position);
  self.pos_L_Hip:Set(tf[10].position);
  self.pos_L_Knee:Set(tf[11].position);
  self.pos_L_Ankle:Set(tf[12].position);
  self.pos_Pelvis:Set(tf[13].position);

  self.rot_R_Shoulder:Set(tf[1].rotation);
  self.rot_R_Elbow:Set(tf[2].rotation);
  self.rot_R_Wrist:Set(tf[3].rotation);
  self.rot_L_Shoulder:Set(tf[4].rotation);
  self.rot_L_Elbow:Set(tf[5].rotation);
  self.rot_L_Wrist:Set(tf[6].rotation);
  self.rot_R_Hip:Set(tf[7].rotation);
  self.rot_R_Knee:Set(tf[8].rotation);
  self.rot_R_Ankle:Set(tf[9].rotation);
  self.rot_L_Hip:Set(tf[10].rotation);
  self.rot_L_Knee:Set(tf[11].rotation);
  self.rot_L_Ankle:Set(tf[12].rotation);
  -- entry 13 is the pelvis transform; its rotation feeds the spine
  self.rot_Spine:Set(tf[13].rotation);
end

--- Rescale each score-weight table in place so that it sums to
-- (#maskFullbody * its tweak factor), and record the number of keypoints
-- required for a full body in self.sumMask.
-- Raises an assertion error if any weight table sums to <= 0.
function ModelPose:NormalizeWeights()
  local count = #self.maskFullbody;
  local weightSets = { self.scoreWeightLarge, self.scoreWeightMedium, self.scoreWeightSmallDefault };
  local tweaks = { self.scoreTweakL, self.scoreTweakM, self.scoreTweakSD };
  local sums = { 0, 0, 0 };
  local maskTotal = 0;

  for j = 1, count do
    for k = 1, 3 do
      sums[k] = sums[k] + weightSets[k][j];
    end
    maskTotal = maskTotal + self.maskFullbody[j];
  end
  self.sumMask = maskTotal;

  assert((sums[1] > 0) and (sums[2] > 0) and (sums[3] > 0))

  for j = 1, count do
    for k = 1, 3 do
      local weights = weightSets[k];
      weights[j] = count * tweaks[k] * weights[j] / sums[k];
    end
  end
end

--- Decide whether the current pose should be treated as "not detected".
-- A frame fails when its weighted score total is below scoreMin, or when
-- NeedFullBody is set and not every keypoint in maskFullbody scores above
-- cutoffscore. A hysteresis state machine smooths the per-frame result:
-- one failing frame drops from "detected" to "not detected", and more than
-- cntTolerance consecutive passing frames are needed to get back.
-- @return true when the pose is currently considered not detected
function ModelPose:IsLessScore()
    if self.scores == nil then
      return true;
    end

    local weights = self:GetScoreWeight()
    local weighted = 0;
    local confident = 0;
    for idx = 1, #self.scores do
        local s = self.scores[idx];
        weighted = weighted + s * weights[idx];
        if self.maskFullbody[idx] == 1 and s > self.cutoffscore then
          confident = confident + 1;
        end
    end

    local failing = (self.scoreMin > weighted)
        or (self.NeedFullBody and (confident ~= self.sumMask));
    failing = failing and true or false;

    -- The streak counts consecutive frames that contradict the current
    -- state: failing frames while detected, passing frames while not.
    local detected = (self.lastFrameDetected == true);
    local streak = 0;
    if failing == detected then
      streak = self.cntFullbody + 1;
    end
    self.cntFullbody = streak;

    -- Dropping out is immediate; coming back requires > cntTolerance frames.
    local limit = detected and 0 or self.cntTolerance;
    if streak > limit then
      self.lastFrameDetected = not detected;
      self.cntFullbody = 0;
    end

    return not self.lastFrameDetected;
end

--- Pick the score-weight table matching the pose model level.
-- The level is resolved lazily from recognition:GetPoseModelLevel()[1] the
-- first time a recognition component is available. Level 3 selects the
-- large weights and level 2 the medium weights; anything else, including an
-- unresolved level, selects the small/default weights.
function ModelPose:GetScoreWeight()
  if self.modellevel == nil and self.recognition ~= nil then
    self.modellevel = self.recognition:GetPoseModelLevel()[1]
  end

  local level = self.modellevel
  if level == 3 then
    return self.scoreWeightLarge
  end
  if level == 2 then
    return self.scoreWeightMedium
  end
  return self.scoreWeightSmallDefault
end

--- Heuristic back-facing test on the cached 2D keypoints.
-- Returns true when there is no 2D pose yet, or when the left shoulder is
-- left of the right shoulder, or the left hip left of the right hip,
-- along image x.
function ModelPose:IsBack()
  local pts = self.pos2d;
  if pts == nil then
    return true;
  end

  local function point(k)
    return mathfunction.vector2(pts[k][1], pts[k][2]);
  end
  local rightShoulder = point(1);
  local leftShoulder  = point(4);
  local rightHip      = point(7);
  local leftHip       = point(10);

  return leftShoulder:x() < rightShoulder:x()
      or leftHip:x() < rightHip:x();
end

--- Heuristic side-on test on the cached 2D keypoints.
-- Returns true when there is no 2D pose yet. Otherwise it returns true when
-- the shoulders, or the hips, are closer than self.lateralErr along image x.
function ModelPose:IsLateral()
  if self.pos2d == nil then
    return true;
  end
  -- NOTE(review): lateralErr is never assigned in this file. Without a
  -- fallback, the comparisons below raised "attempt to compare number with
  -- nil". A missing threshold now means "never lateral"; confirm whether a
  -- different default is intended.
  local tolerance = self.lateralErr or 0;

  local RSPos = mathfunction.vector2( self.pos2d[1][1],  self.pos2d[1][2]);
  local LSPos = mathfunction.vector2( self.pos2d[4][1],  self.pos2d[4][2]);
  local RHPos = mathfunction.vector2( self.pos2d[7][1],  self.pos2d[7][2]);
  local LHPos = mathfunction.vector2( self.pos2d[10][1], self.pos2d[10][2]);

  return math.abs( LSPos:x() - RSPos:x() ) < tolerance
      or math.abs( LHPos:x() - RHPos:x() ) < tolerance;
end

--- Fetch the latest recognition results and split them into per-keypoint tables.
-- Updates self.facedetected (true when results[1] is present). When pose
-- data is available, also caches self.pos2d and self.scores.
-- @return pose2d ({x, y} per keypoint), scores, pose3d ({x, y, z} per keypoint)
--         for 14 keypoints; nothing when there is no recognition component,
--         and nil, nil, nil when no pose data is available
function ModelPose:GetRecongition()
  if self.recognition == nil then
    return;
  end
  local results = self.recognition:GetResult();
  -- Check for nil before indexing. The original read results[1] first and
  -- crashed when GetResult returned nil.
  if results == nil then
    self.facedetected = false;
    return nil, nil, nil
  end
  self.facedetected = results[1] ~= nil;

  local pose2d_list = results[128];
  local pose3d_list = results[256];
  if pose2d_list == nil or pose3d_list == nil then
    return nil, nil, nil
  end
  local pose2d_raw = pose2d_list[1]; -- x1, y1, score1, x2, y2, score2, ...
  local pose3d_raw = pose3d_list[1]; -- x1, y1,     z1, x2, y2,     z2, ...
  -- An empty result list would otherwise crash when indexed in the loop below.
  if pose2d_raw == nil or pose3d_raw == nil then
    return nil, nil, nil
  end

  local pose2d_ret = {};
  local score_ret = {};
  local pose3d_ret = {};
  for i = 1, 14 do
    local base = (i - 1) * 3;
    pose2d_ret[i] = {pose2d_raw[base + 1], pose2d_raw[base + 2]};
    score_ret[i] = pose2d_raw[base + 3];
    pose3d_ret[i] = {pose3d_raw[base + 1], pose3d_raw[base + 2], pose3d_raw[base + 3]};
  end
  self.pos2d = pose2d_ret;
  self.scores = score_ret;
  return pose2d_ret, score_ret, pose3d_ret
end

--- Derive pinhole intrinsics from the bound camera and pass them to the
-- RootEstimate and ReprojCorrect solvers. Does nothing without a camera.
-- The principal point is the centre of CameraResolution. The focal lengths
-- are the projection matrix entries a11 / a22 scaled by cx / cy.
function ModelPose:UpdateCameraParam()
    local cam = self.camera;
    if cam == nil then
        return;
    end

    local res = cam.CameraResolution;
    local cx = 0.5*res:x();
    local cy = 0.5*res:y();

    local proj = cam:GetProject();
    local fx = proj.a11 * cx;
    local fy = proj.a22 * cy;

    -- Log when the two focal lengths differ by more than 0.01.
    if math.abs(fx - fy) > 0.01 then
        LOG("DIFF FOCUS DISTANCE"..fx.." "..fy);
    end

    local root = self.RootEstimate;
    root:SetCameraParam(fx, fy, cx, cy);
    root:SetContinueWeight(0.01);

    self.ReprojCorrect:SetCameraParam(fx, fy, cx, cy);
end

--- Derive this frame's optimizer weights from keypoint confidences.
-- Resets the default weight tables on self. It then zeroes the reprojection,
-- minimize-C and minimize-Z weights of every keypoint scoring below
-- cutoffscore, and the bone-length weight of every bone with an endpoint
-- below cutoffscore. The results are pushed into ReprojCorrect and
-- RootEstimate.
-- @param score array of per-keypoint confidences (one per reprojweight entry)
function ModelPose:GetWeightFromScore(score)
  -- Defaults (same literals as in new()), rebuilt on every call.
  self.reprojweight = {1,1,1,1,1,1,1,1,1,1,1,1,1,1};
  self.bonelengthweight = {0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05,0.05};
  self.mcweight = {0.06,0.0,0.0,0.06,0.0,0.0,0.06,0.0,0.0,0.06,0.0,0.0,0.0,0.0};
  self.mzweight = {1.5,0.03,0.03,1.5,0.03,0.03,1.5,0.03,0.03,1.5,0.03,0.03,0.03,0.03};

  local cutoff = self.cutoffscore;

  -- Per-keypoint weights: zero out low-confidence keypoints.
  local reprojweight = {};
  local mcweight = {};
  local mzweight = {};
  for i = 1, #self.reprojweight do
    local mul = 1;
    if score[i] < cutoff then
      mul = 0;
    end
    reprojweight[i] = self.reprojweight[i] * mul;
    mcweight[i] = self.mcweight[i] * mul;
    mzweight[i] = self.mzweight[i] * mul;
  end

  -- Per-bone weights: zero out bones with a low-confidence endpoint.
  -- bonepair stores 0-based keypoint indices, hence the +1.
  local bonelengthweight = {};
  for i = 1, #self.bonelengthweight do
    local pair = self.bonepair[i];
    local mul = 1;
    if score[pair[1] + 1] < cutoff or score[pair[2] + 1] < cutoff then
      mul = 0;
    end
    bonelengthweight[i] = self.bonelengthweight[i] * mul;
  end

  self.ReprojCorrect:SetReprojWeight(reprojweight);
  self.ReprojCorrect:SetBoneLengthWeight(bonelengthweight);
  self.ReprojCorrect:SetMinimizeCWeight(mcweight);
  self.ReprojCorrect:SetMinimizeZWeight(mzweight);
  self.RootEstimate:SetWeight({1,0,0,1,0,0,1,0,0,1,0,0,0,0});
end

--- Fetch the current pose and refresh the solver configuration.
-- Solver camera parameters and weights are updated whenever pose data is
-- present. The pose is only returned when the unweighted score sum exceeds
-- scoreMin.
-- @return pose2d, scores, pose3d, or nil, nil, nil
function ModelPose:GetHumanPose()
    local pose2d, scores, pose3d = self:GetRecongition();
    if pose2d == nil then
      return nil, nil, nil
    end

    local scoreSum = 0;
    for j = 1, 14 do
        scoreSum = scoreSum + scores[j];
    end

    self:UpdateCameraParam();
    self:GetWeightFromScore(scores);

    if scoreSum > self.scoreMin then
      return pose2d, scores, pose3d;
    end
    return nil, nil, nil
end

--- Smooth per-keypoint scores with one 1-euro filter per index.
-- When an index is seen for the first time, its filter is created and
-- seeded with the raw score, and that score passes through unchanged. On
-- later calls score_ret[i] is replaced by the filtered value. The shared
-- filter clock then advances by `time`.
-- NOTE(review): the first oneeurofilter argument is presumably the start
-- time, as in the reference 1-euro implementation (t0, x0, dx0, min_cutoff,
-- beta, d_cutoff); verify against math.lowpassfilter.oneeurofilter.
-- @param score_ret array of scores; modified in place and returned
-- @param time      elapsed time added to the filter clock after this call
function ModelPose:SmoothScores(score_ret,time)
  -- Start the clock exactly once. Previously it was reset to 0 whenever any
  -- new filter was created, which rewound time for filters already running.
  -- It also stayed nil, crashing the final addition, if the first call had
  -- no scores.
  if self.filtertime == nil then
    self.filtertime = 0;
  end
  for i = 1, #score_ret do
    local filter = self.eurofilters[i];
    if filter == nil then
      -- Seed at the current clock so later filter() calls see positive dt.
      self.eurofilters[i] = oneeurofilter(self.filtertime, score_ret[i], 0, 3, 0.1, 2);
    else
      score_ret[i] = filter:filter(self.filtertime, score_ret[i]);
    end
  end
  self.filtertime = self.filtertime + time;
  return score_ret;
end

-- atan(num / den), with den == 0 resolved from the sign of num instead of
-- dividing by zero. A -0.0 denominator would otherwise flip the sign of the
-- +/-inf quotient, and 0/0 would produce NaN. For den ~= 0 the result is
-- exactly math.atan(num / den).
local function AtanOfRatio(num, den)
  if den == 0 then
    if num > 0 then
      return math.pi / 2;
    elseif num < 0 then
      return -math.pi / 2;
    end
    return 0;
  end
  return math.atan(num / den);
end

--- Build self.transform from 3D keypoints and their scores.
-- Entries 1..12 each get a `position`, an `active` flag (false when the
-- score is below cutoffscore) and rotations:
--  * root joints (parents[i] == 0) get an identity `rotation`;
--  * child joints get `rotation`, a Z rotation from the bone's XY-plane
--    angle, and `rotationz`, a Y rotation from its XZ-plane angle. The bone
--    direction is parent -> joint.
-- Entry 13 is GetPelvisTransform(pos3d). It is active when the scores of
-- keypoints 1, 4, 7 and 10 sum above 4 * cutoffscore.
-- @param pos3d  table of {x, y, z} keypoints
-- @param scores table of per-keypoint confidences
function ModelPose:CaculateTransforms( pos3d,scores )
  self.transform = {};

  for i = 1, 12 do
    self.transform[i] = {};
    if scores[i] < self.cutoffscore then
      self.transform[i].active = false;
    else
      self.transform[i].active = true;
    end
    self.transform[i].position = mathfunction.vector3(pos3d[i][1],pos3d[i][2],pos3d[i][3]);
    if self.parents[i] == 0 then
      self.transform[i].rotation = mathfunction.Quaternion();
    else
      local paridx = self.parents[i];
      local parentpos = mathfunction.vector3(pos3d[paridx][1],pos3d[paridx][2],pos3d[paridx][3]);
      local forward = self.transform[i].position - parentpos;

      local xrot = forward:x();
      local yrot = forward:y();
      local zrot = forward:z();

      -- XY-plane angle: atan covers only (-pi/2, pi/2), so add pi when the
      -- bone points towards -X.
      local rotinz = AtanOfRatio(yrot, xrot);
      if xrot < 0 then
        rotinz = rotinz + math.pi;
      end

      -- XZ-plane angle. No quadrant correction is applied here; the
      -- original's +pi correction for this angle was commented out.
      local rotiny = AtanOfRatio(-zrot, xrot);

      local rotation = mathfunction.Quaternion();
      rotation:RotateXYZ(0,0,rotinz);
      self.transform[i].rotation = rotation;

      -- Despite the field name, this holds the rotation about Y.
      local rotationz = mathfunction.Quaternion();
      rotationz:RotateXYZ(0,rotiny,0);
      self.transform[i].rotationz = rotationz;
    end
  end

  self.transform[13] = self:GetPelvisTransform(pos3d);
  local pelvisscore = scores[1]+scores[4]+scores[7]+scores[10];
  if pelvisscore > self.cutoffscore*4 then
    self.transform[13].active = true;
  else
    self.transform[13].active = false;
  end
end

--- Look up one entry of the transforms built by CaculateTransforms.
-- @param index 1..13 (13 is the pelvis)
-- @return the transform table, or nil when none have been computed yet
function ModelPose:GetTransform(index)
  local transforms = self.transform;
  if transforms ~= nil then
    return transforms[index];
  end
  return nil;
end

--- Placeholder: currently does nothing and returns no value.
-- NOTE(review): `index` is unused; presumably intended to mirror
-- GetTransform for editor use. Confirm or remove.
function ModelPose:GetEditorTransform(index)
 
end

--- Build the pelvis transform from the shoulder and hip keypoints.
-- Position: halfway between the shoulder centre and the hip centre.
-- Rotation: halfway slerp between the hip-based and shoulder-based frames.
-- Side effects: stores the hip frame in self.rot_Pelvis and the shoulder
-- frame in self.rot_Neck.
-- @param pos3d table of {x, y, z}; reads indices 1, 4, 7 and 10
--        (right/left shoulder, right/left hip)
-- @return table with `position` (vector3) and `rotation`
function ModelPose:GetPelvisTransform(pos3d)
  local RSPos = mathfunction.vector3( pos3d[1][1],  pos3d[1][2],  pos3d[1][3]);
  local LSPos = mathfunction.vector3( pos3d[4][1],  pos3d[4][2],  pos3d[4][3]);
  local RHPos = mathfunction.vector3( pos3d[7][1],  pos3d[7][2],  pos3d[7][3]);
  local LHPos = mathfunction.vector3( pos3d[10][1], pos3d[10][2], pos3d[10][3]);

  -- Torso axis (hip centre -> shoulder centre) and the right -> left axes.
  -- Unused per-frame computations (shoulder/hip widths, body height and an
  -- averaged lateral vector) were removed.
  local bodyvector = (RSPos+LSPos)/2 - (RHPos+LHPos)/2;
  local shouldervector = LSPos-RSPos;
  local hipvector = LHPos-RHPos;

  local hiprotate      =  util:ConstractRotationFromVector(hipvector:Normalize(),bodyvector:Normalize());
  local shoulderrotate =  util:ConstractRotationFromVector(shouldervector:Normalize(),bodyvector:Normalize());
  local avgrotate      =  mathfunction.Mathutility:Slerp(hiprotate:ToQuaternion(),shoulderrotate:ToQuaternion(),0.5);
  self.rot_Pelvis:Set(hiprotate:ToQuaternion());
  self.rot_Neck:Set(shoulderrotate:ToQuaternion());

  local hippos = (RHPos+LHPos)/2;
  local sholderpos = (RSPos+LSPos)/2;
  local modelavgpos = self:Lerp(sholderpos,hippos,0.5);

  return { position = modelavgpos, rotation = avgrotate };
end
--- Component-wise linear interpolation: (to - from) * percent + from.
-- Reads and writes the raw mx/my/mz fields and returns a new vector3.
function ModelPose:Lerp(from,to,percent)
    local out = mathfunction.vector3();
    for _, axis in ipairs({"mx", "my", "mz"}) do
        out[axis] = (to[axis] - from[axis])*percent + from[axis];
    end
    return out;
end

-- Module export: the ModelPose behavior class.
return ModelPose;