local apolloengine = require "apolloengine"
local mathfunction = require "mathfunction"
local venuscore = require "venuscore"
local venusjson = require "venusjson"
local nodeutility = require "apolloutility.nodeutility"
local apollonode = require "apolloutility.apollonode"
local bigonn = require "bigonnfunction"
require "utility"
require "venusdebug"


-- Module table. Every method below is attached to it and the table itself is
-- returned at the end of the file; callers invoke the methods with ':' so that
-- per-detector state (Net, Session, tensors, last pose) is stored on the receiver.
local vnectdection = {}

--- Load the network and prepare all per-detector state on `self`.
-- Creates the BigoNN net and session, an empty output tensor, the image
-- preprocessing config, the joint id list and tensor names, and seeds the
-- cached result (pos2d / pos3d / score) with the built-in default pose so
-- that GetResult() returns something usable before the first estimate.
-- @param modelpath  value forwarded unchanged to bigonn.NetBigoNN
function vnectdection:Initialize(modelpath)
  local count = venuscore.IServicesSystem:GetThreadCount();
  self.Net = bigonn.NetBigoNN(modelpath);
  -- NOTE(review): count/2 is not an integer for odd thread counts (and is a
  -- float on Lua 5.3+); confirm CreateSession accepts that.
  self.Session = self.Net:CreateSession(0,count/2);
  self.OutputTensor = bigonn.TensorBigoNN();   --self.Session:GetSessionOutput();
  self.ImageConfig = {
    ["sourceFormat"] = bigonn.TensorBigoNN.RGB,
    ["destFormat"] = bigonn.TensorBigoNN.BGR,
    ["mean"] = {127.5,127.5,127.5},
    ["normal"] = {1/255,1/255,1/255},
    ["destWidth"] = 92,
    ["destHeight"] = 184,
  }
  self.IdList = {0,1,2,3,4,5,6,7,8,9,10,11,12,13}
  -- heatmapsize is only stored here; nothing in the visible source reads it.
  self.heatmapsize = {12,24}
  self.inputname = "input_image";
  self.outputname = "predict_conv2a/Conv2D";

  -- Declared local: the original assigned these without `local`, which
  -- created/overwrote the globals pos2d, pos3dret and scores.
  local pos2d, pos3dret, scores = self:GetHackedPose();
  self.pos2d = self:tbl_copy(pos2d);
  self.pos3d = self:tbl_copy(pos3dret);
  self.score = self:tbl_copy(scores);

  self.cropHW = 2;
end

--- Run the network on one frame and cache the resulting pose on `self`.
-- Builds an input tensor from `rgbtexture` using self.ImageConfig, runs the
-- session, extracts per-joint candidates from the output heat map, and reduces
-- them with ParseOutput(). If more than 10 joints score below 0.5 the built-in
-- default pose is stored instead. Results are written to self.pos2d,
-- self.pos3d and self.score; nothing is returned.
-- NOTE(review): despite the name, every call made here appears to be
-- synchronous from this file's point of view — confirm against the bindings.
-- @param rgbtexture  image source handed to bigonn.TensorBigoNN
function vnectdection:AsyncEstimate(rgbtexture)
  local LOW_SCORE = 0.5        -- a joint scoring below this counts as unreliable
  local MAX_LOW_JOINTS = 10    -- more unreliable joints than this => use default pose

  self.InputTensor = bigonn.TensorBigoNN(rgbtexture, self.ImageConfig);
  self.Session:SetSessionInput(self.inputname, self.InputTensor);
  self.Session:RunAllPaths();
  -- NOTE(review): "GetSessioOutput" looks like a misspelling, but it is the
  -- name exposed by the external session object, so it is kept as-is.
  self.Session:GetSessioOutput(self.outputname, self.OutputTensor);

  -- 0.3 and 5 are presumably the score threshold and top-k count — TODO confirm.
  -- All temporaries are local; the original leaked them as globals.
  local scoreduvxyz = self.OutputTensor:GaussianHeatMapToTopKScoredUVXYZ(self.IdList, 0.3, 5);
  local pos2d, pos3dret, scores = self:ParseOutput(scoreduvxyz);

  local low_count = 0
  for i = 1, #scores do
    if scores[i] < LOW_SCORE then
      low_count = low_count + 1
    end
  end
  if low_count > MAX_LOW_JOINTS then
    pos2d, pos3dret, scores = self:GetHackedPose();
  end

  self.pos2d = pos2d;
  self.pos3d = pos3dret;
  self.score = scores;
end

--- Run one estimate and hand back the freshly cached pose.
-- Convenience wrapper: AsyncEstimate() followed by GetResult().
-- @param rgbtexture  image source forwarded to AsyncEstimate
-- @return pos2d, pos3d, score  as produced by GetResult()
function vnectdection:Estimate(rgbtexture)
  self:AsyncEstimate(rgbtexture)
  -- The texture is still passed along, exactly as before; GetResult ignores it.
  local joints2d, joints3d, confidences = self:GetResult(rgbtexture)
  return joints2d, joints3d, confidences
end

--- Return the most recently cached pose.
-- The three values are the tables stored on `self` themselves, not copies.
-- @return self.pos2d, self.pos3d, self.score
function vnectdection:GetResult()
  local joints2d = self.pos2d
  local joints3d = self.pos3d
  local confidences = self.score
  return joints2d, joints3d, confidences
end

--- Recursively duplicate a value.
-- Non-table values are returned unchanged. For tables, a new table is built
-- whose keys and values are themselves copied recursively. Metatables are not
-- carried over and cyclic tables are not detected.
-- @param orig  any Lua value
-- @return the copy (or `orig` itself when it is not a table)
function vnectdection:tbl_copy(orig)
    if type(orig) ~= "table" then
        -- number, string, boolean, nil, function, ...
        return orig
    end

    local duplicate = {}
    local key, value = next(orig)
    while key ~= nil do
        duplicate[self:tbl_copy(key)] = self:tbl_copy(value)
        key, value = next(orig, key)
    end
    return duplicate
end

--- Reduce per-joint candidate lists to one 2D point, one 3D point and a score.
-- For every joint the score is that of the first candidate, and the positions
-- are the score-weighted average over all of that joint's candidates. The 3D
-- result is multiplied by 1000 (presumably a unit conversion — TODO confirm).
--
-- Edge cases:
--   * candidate scores summing to zero: the weighted average is undefined, so
--     the first candidate's coordinates are used instead of producing NaN;
--   * a joint with no candidates: score 0 and all-zero positions are emitted
--     so the returned arrays stay gap-free.
--
-- @param scoreduvxyz  num_joints * top_k * 6 (score, u, v, x, y, z)
-- @return pos2d   array of {u, v}
-- @return pos3d   array of {x, y, z}, scaled by 1000
-- @return scores  array of first-candidate scores
function vnectdection:ParseOutput(scoreduvxyz)
  local pos2d = {}
  local pos3d = {}
  local scores = {}
  for j = 1, #scoreduvxyz do
    local candidates = scoreduvxyz[j]
    local best = candidates[1]
    if best == nil then
      -- No candidate for this joint: report it as undetected.
      scores[j] = 0
      pos2d[j] = {0, 0}
      pos3d[j] = {0, 0, 0}
    else
      scores[j] = best[1]

      local sum_score = 0
      local wu, wv = 0, 0
      local wx, wy, wz = 0, 0, 0
      for k = 1, #candidates do
        local cand = candidates[k]
        local weight = cand[1]
        sum_score = sum_score + weight
        wu = wu + cand[2] * weight
        wv = wv + cand[3] * weight
        wx = wx + cand[4] * weight
        wy = wy + cand[5] * weight
        wz = wz + cand[6] * weight
      end

      if sum_score ~= 0 then
        pos2d[j] = {wu/sum_score,
                    wv/sum_score}
        pos3d[j] = {1000*wx/sum_score,
                    1000*wy/sum_score,
                    1000*wz/sum_score}
      else
        -- All weights cancel out: fall back to the top-ranked candidate.
        pos2d[j] = {best[2], best[3]}
        pos3d[j] = {1000*best[4], 1000*best[5], 1000*best[6]}
      end
    end
  end
  return pos2d, pos3d, scores
end

--- Return a hard-coded default pose for 14 joints.
-- Used to seed the cached result in Initialize() and as the fallback in
-- AsyncEstimate() when too many joints score poorly. Fresh tables are built
-- on every call, so callers may keep or mutate the result freely.
-- The tables are local: the original wrote them to the globals pos2d,
-- pos3dret and scores on every call.
-- @return pos2d   14 entries of {u, v}
-- @return pos3d   14 entries of {x, y, z}
-- @return scores  14 entries, all 10
function vnectdection:GetHackedPose()
  local pos2d = {{4.5    , 8.6},
                 {4.0625 , 11.22},
                 {4.0625 , 13.239},
                 {6.9375 , 17.2173913},
                 {7.375  , 11.152},
                 {7.375  , 12.7825},
                 {5.0    , 13.304},
                 {5.0    , 18.4345},
                 {5.0    , 19.04347},
                 {6.4375 , 13.239},
                 {6.4375 , 16.4348261},
                 {6.4375 , 19.5},
                 {5.5    , 7.56},
                 {5.5    , 5.93}}
  local pos3dret = {{-138, -461,  -67},
                    {-210, -185,  -32},
                    {-199,   30,  -53},
                    { 149, -431, -103},
                    { 210, -181,  -56},
                    { 215,   52,  -27},
                    { -83,   32,    0},
                    { -86,  462,   46},
                    { -83,  820,  133},
                    { 108,   48,  -20},
                    {  86,  472,   45},
                    {  88,  819,   82},
                    { -18, -552,  -83},
                    {   7, -745,  -86}}
  local scores = {10,10,10,10,10,10,10,10,10,10,10,10,10,10}
  return pos2d, pos3dret, scores
end
return vnectdection;