local mathfunction = require "mathfunction"
local apolloengine = require "apolloengine"


-- Module table: helpers that query the physics system for collisions on a
-- skeleton node and estimate how the collision penetration changes with the
-- skeleton's pose (see CaculatePD).
local collisionhandler = {}

--- Initialise the handler state for one skeleton node.
-- @param skenode skeleton node; the handler later calls GetCurrentPose,
--   UpdateSkeleton and GetNameByID on it and reads its `bonenamemap`.
function collisionhandler:Init(skenode)
  -- Tracked body parts; MassList holds one mass per entry of PartList
  -- (the two arrays are parallel, see GetMass).
  local parts = {
    "Bip001 Head",
    "Bip001 R Forearm",
    "Bip001 L Forearm",
    "Bip001 L Calf",
    "Bip001 R Calf",
    "Bip001 Spine1",
  }
  self.PartList = parts
  self.MassList = { 0, 1, 1, 1, 1, 0 }
  self.skenode = skenode

  -- Part name -> result slot. Not read by any live code visible in this file
  -- (only by commented-out lines in GetCollisionResult).
  self.ResultIdx = {
    ["Bip001 R Forearm"] = 1,
    ["Bip001 L Forearm"] = 2,
    ["Bip001 L Calf"] = 3,
    ["Bip001 R Calf"] = 4,
  }

  -- Step size for the forward finite differences in CaculatePD.
  self.JDelta = 0.001
end

--- Recursively deep-copy a value.
-- Non-table values are returned unchanged. For tables, both keys and values
-- are copied recursively; metatables are not carried over.
-- @param orig value to copy
-- @return the copy
function collisionhandler:tbl_copy(orig)
  if type(orig) ~= "table" then
    -- number, string, boolean, userdata, ...: returned as-is
    return orig
  end

  local duplicate = {}
  -- Raw traversal with `next` (not `pairs`), so no __pairs metamethod is used.
  for key, value in next, orig do
    duplicate[self:tbl_copy(key)] = self:tbl_copy(value)
  end
  return duplicate
end

--- Estimate how the collision penetration vectors change with bone pose.
--
-- Uses forward finite differences. A collision pass is run at the current pose.
-- Then, for every bone index reported by GetCollisionResult:
--   * each of its three pose components is offset by `self.JDelta`;
--   * the skeleton is updated and collision is re-run, restricted to the
--     (objAId, objBId) pairs found in the base pass;
--   * the difference quotient is stored in the Jacobian.
-- The skeleton is put back to the unperturbed pose before returning.
--
-- Row j of the Jacobian is component ((j-1) % 3 + 1) of collision row
-- floor((j-1)/3) + 1. Column (bone-1)*3 + k is pose component k of that bone.
-- @return torch.Tensor Jacobian and the base collision tensor when at least
--   one collision was found; nothing otherwise.
function collisionhandler:CaculatePD()
  -- NOTE(review): 42 columns = 3 per bone, i.e. presumably 14 bones; confirm
  -- against the skeleton's bonenamemap.
  local JACOBIAN_COLS = 42
  local physics = apolloengine.IPhysicSystem

  -- One collision pass with the fixed simulation parameters used throughout.
  -- Only the first value of GetPhysicsCollisionInfo is returned.
  local function run_collision()
    physics:UpdateCollision(0.01, 10, 1.0 / 60.0)
    local info = physics:GetPhysicsCollisionInfo()
    return info
  end

  local Jacobian = {}
  local baseInfo = run_collision()
  local basecollision, collisionIdx, collisionfilter = self:GetCollisionResult(baseInfo)
  local basemodelpose = self.skenode:GetCurrentPose()
  local hascollision = false

  for index, _ in pairs(collisionIdx) do
    hascollision = true
    for i = 1, 3 do
      -- Perturb a deep copy so that basemodelpose itself stays untouched.
      local offsetmodelpose = self:tbl_copy(basemodelpose)
      offsetmodelpose[index][i] = offsetmodelpose[index][i] + self.JDelta
      self.skenode:UpdateSkeleton(offsetmodelpose)

      local offsetInfo = run_collision()
      local offsetcollision = self:GetCollisionResult(offsetInfo, collisionfilter)

      local deltacollision = offsetcollision - basecollision
      local column = (index - 1) * 3 + i
      for j = 1, deltacollision:size()[1] * 3 do
        if Jacobian[j] == nil then
          local row = {}
          for c = 1, JACOBIAN_COLS do
            row[c] = 0
          end
          Jacobian[j] = row
        end
        local collisionRow = math.floor((j - 1) / 3) + 1
        local component = (j - 1) % 3 + 1
        Jacobian[j][column] = deltacollision[collisionRow][component] / self.JDelta
      end
    end
  end

  if not hascollision then
    return
  end

  -- The finite-difference loop leaves the skeleton at the last perturbed pose.
  -- Restore the unperturbed pose so callers see an unmodified skeleton.
  self.skenode:UpdateSkeleton(basemodelpose)
  return torch.Tensor(Jacobian), basecollision
end



--- Return the penetration distance and normal of the deepest contact point.
-- Only points with a negative `distance` count. If there are none, the result
-- is (0, nil).
-- @param pointList array of contact points with `distance` and `worldNormalOnB`
local function deepest_contact(pointList)
  local distance, normal = 0, nil
  for j = 1, #pointList do
    local point = pointList[j]
    if point.distance < distance then
      distance = point.distance
      normal = point.worldNormalOnB
    end
  end
  return distance, normal
end

--- Return the index of the first `collisionfilter` entry whose pair equals
-- (aid, bid) in that order, or nil if no entry matches.
local function find_filter_index(collisionfilter, aid, bid)
  for cfi = 1, #collisionfilter do
    local pair = collisionfilter[cfi]
    if pair[1] == aid and pair[2] == bid then
      return cfi
    end
  end
  return nil
end

--- Convert raw physics collision info into penetration vectors.
--
-- A collision is accepted when its B object maps to a known bone. Its deepest
-- contact (most negative distance) becomes the vector distance * worldNormalOnB.
--
-- Without `collisionfilter`, accepted collisions are appended in order.
--
-- With a filter (the third return value of a previous call):
--   * each collision goes to the slot of its matching (objAId, objBId) pair;
--   * if several collisions map to the same slot, the deepest one is kept;
--   * collisions with no matching pair are ignored;
--   * empty slots are filled with {0, 0, 0}, so output rows line up with the
--     filter entries.
-- @param collisionInfo array of collision records (objAId, objBId, collisionPointList)
-- @param collisionfilter optional array of {aid, bid} pairs
-- @return torch.Tensor of penetration vectors (nil if there are none); a table
--   whose keys are the bonenamemap indices involved in accepted collisions;
--   and the {aid, bid} pair recorded for each result slot.
function collisionhandler:GetCollisionResult(collisionInfo, collisionfilter)
  local result = {}
  local collisionIdx = {}
  local collisionResult = {}

  for i = 1, #collisionInfo do
    local info = collisionInfo[i]
    local aid = info["objAId"]
    local bid = info["objBId"]

    local resultidx
    if collisionfilter ~= nil then
      resultidx = find_filter_index(collisionfilter, aid, bid)
    else
      resultidx = #result + 1
    end

    if resultidx ~= nil then
      local bname = self.skenode:GetNameByID(bid)
      local bboneidx, bboneidx2 = self:GetIdxInBone(bname)
      local distance, normal = deepest_contact(info["collisionPointList"])

      -- Compare against what is already stored in THIS slot, so that only the
      -- deepest collision per slot is kept.
      local bdepths = self:GetDepthS(result[resultidx])

      if bboneidx ~= 0 and normal ~= nil and bdepths < distance * distance then
        result[resultidx] = {
          distance * normal[1],
          distance * normal[2],
          distance * normal[3],
        }
        collisionIdx[bboneidx] = 1
        -- GetIdxInBone returns 0 when the bone has no paired bone. Index 0 is
        -- not a valid bone and must not be reported.
        if bboneidx2 ~= 0 then
          collisionIdx[bboneidx2] = 1
        end
        collisionResult[resultidx] = { aid, bid }
      end
    end
  end

  -- Keep rows aligned with the filter: pad empty slots with a zero vector.
  if collisionfilter ~= nil then
    for idx = 1, #collisionfilter do
      if result[idx] == nil then
        result[idx] = { 0, 0, 0 }
      end
    end
  end

  local torchret = nil
  if #result > 0 then
    torchret = torch.Tensor(result)
  end
  return torchret, collisionIdx, collisionResult
end

--- Squared Euclidean length of a 3-component vector.
-- @param vec3 array-like {x, y, z}, or nil
-- @return x*x + y*y + z*z, or 0 when vec3 is nil
function collisionhandler:GetDepthS(vec3)
  if vec3 ~= nil then
    local x, y, z = vec3[1], vec3[2], vec3[3]
    return x * x + y * y + z * z
  end
  return 0
end


-- Fixed pairs of bonenamemap indices. When a collision is reported on one
-- index of a pair, GetCollisionResult also marks its partner.
-- NOTE(review): the meaning of these pairings is not visible here; confirm
-- against the skeleton layout.
local PAIRED_BONE = {
  [2] = 3, [3] = 2,
  [5] = 6, [6] = 5,
  [8] = 9, [9] = 8,
  [11] = 12, [12] = 11,
}

--- Look up a bone's index in `self.skenode.bonenamemap` and its paired index.
-- Logs a message when the name is not found. A nil name is logged as "nil"
-- instead of raising a concatenation error.
-- @param name bone name (may be nil)
-- @return idx position in bonenamemap, or 0 if not found
-- @return idx2 paired bone index, or 0 if the bone has no pair
function collisionhandler:GetIdxInBone(name)
  local bonenamemap = self.skenode.bonenamemap
  local idx = 0
  for i = 1, #bonenamemap do
    if name == bonenamemap[i] then
      idx = i
      break
    end
  end
  if idx == 0 then
    LOG(tostring(name) .. "   not found!")
  end
  return idx, PAIRED_BONE[idx] or 0
end

--- Position of a body-part name in `self.PartList`.
-- @param name part name
-- @return 1-based index of the first match, or 0 if the name is not listed
function collisionhandler:GetIdx(name)
  local parts = self.PartList
  for position = 1, #parts do
    if parts[position] == name then
      return position
    end
  end
  return 0
end

--- Mass configured for a body part.
-- @param name part name (looked up in PartList)
-- @return the matching MassList entry, or nothing when the part is unknown
function collisionhandler:GetMass(name)
  local position = self:GetIdx(name)
  if position == 0 then
    return
  end
  return self.MassList[position]
end


-- Export the module table to `require` callers.
return collisionhandler;