local venuscore = require "venuscore"
local venusjson = require "venusjson"
local mnn = require "mnnfunction"
require "utility"
require "venusdebug"

local vnectdection = {}

--- Load the model and prepare everything needed to run pose estimation.
-- Creates the MNN interpreter and a session, looks up the input/output
-- tensors, builds the image pre-processor, and stores all of them (plus a few
-- size tables) on `self`.
-- @param modelpath path of the model file handed to Interpreter:CreateFromFile
function vnectdection:Initialize(modelpath)
  local count = venuscore.IServicesSystem:GetThreadCount();
  -- Use half of the reported threads, but always a whole number >= 1.
  -- A bare `count/2` is fractional for odd counts (0.5 for a single thread).
  -- NOTE(review): assumes the second CreateSession argument is a thread
  -- count -- confirm against the mnnfunction binding.
  local threads = math.max(1, math.floor(count / 2));

  self.Interpreter = mnn.Interpreter();
  self.Interpreter:CreateFromFile(modelpath);
  self.Session = self.Interpreter:CreateSession(0, threads);
  self.InputTensor = self.Interpreter:GetSessionInput(self.Session);
  -- Alternative output node kept for reference:
  --self.OutputTensor = self.Interpreter:GetSessionOutput( self.Session,"predict_bn2a/FusedBatchNorm");
  self.OutputTensor = self.Interpreter:GetSessionOutput(self.Session, "predict_conv2a/BiasAdd");

  -- Pre-processing settings handed to mnn.ImageProcesser.
  self.ImageConfig = {
    ["sourceFormat"] = mnn.ImageProcesser.RGB, -- previously YUV_NV21
    ["destFormat"] = mnn.ImageProcesser.BGR,
    ["filterType"] = mnn.ImageProcesser.BILINEAR,
    ["mean"] = {127.5, 127.5, 127.5},
    ["normal"] = {1/255, 1/255, 1/255},
  }

  self.ImageProcesser = mnn.ImageProcesser(self.ImageConfig);
  self.ImageProcesser:SetImageResize(184, 368, 92, 184) -- src, dst

  -- Joint ids requested from the output tensor in Estimate (14 entries).
  self.IdList = {0,1,2,3,4,5,6,7,8,9,10,11,12,13}

  self.heatmapsize = {12, 24}
  self.imagesize = {184, 368} -- size of the original (source) image
end

--- Run pose estimation on one RGB frame.
-- @param rgbtexture image buffer object; it must provide GetBufferSize() and
--   GetPixelSize(), and is handed to the image pre-processor as-is.
-- @return pos2d, pos3d, scores -- the parsed network output (see ParseOutput),
--   or the fixed fallback pose from GetHackedPose when the buffer is too small
--   or too many joints have a low score.
--   When the MNN session fails to run, a single `false` is returned instead.
function vnectdection:Estimate(rgbtexture)
  -- A joint counts as "low" below this score; more than MAX_LOW_JOINTS low
  -- joints is treated as "no person detected".
  local LOW_SCORE = 0.5
  local MAX_LOW_JOINTS = 10

  -- Refuse buffers that cannot hold a full imagesize[1] x imagesize[2] frame.
  local required = self.imagesize[1] * self.imagesize[2] * rgbtexture:GetPixelSize()
  if rgbtexture:GetBufferSize() < required then
    return self:GetHackedPose();
  end

  self.ImageProcesser:Convert(rgbtexture, self.InputTensor, self.imagesize[1], self.imagesize[2], 0);
  local succeed = self.Interpreter:RunSession(self.Session);
  if succeed == false then
    LOG("MNN RUN FAILED");
    return false;
  end

  -- All intermediates are local so nothing leaks into the global environment.
  -- NOTE(review): 0.3 and 5 are presumably a score threshold and the top-k
  -- count -- confirm against the tensor binding.
  local scoreduvxyz = self.OutputTensor:GaussianHeatMapToTopKScoredUVXYZ(self.IdList, 0.3, 5);
  local pos2d, pos3dret, scores = self:ParseOutput(scoreduvxyz)

  -- hack by jiasen 2019/12/27 16:39
  -- When no person is detected, fall back to a fixed pose.
  local lowCount = 0
  for i = 1, #scores do
    if scores[i] < LOW_SCORE then
      lowCount = lowCount + 1
    end
  end
  if lowCount > MAX_LOW_JOINTS then
    pos2d, pos3dret, scores = self:GetHackedPose();
  end
  -- end jiasen

  return pos2d, pos3dret, scores;
end
--[[]]
--- Convert the per-joint candidate lists into 2D/3D positions and scores.
-- @param scoreduvxyz array indexed [joint][candidate] where every candidate is
--   {score, u, v, x, y, z}
-- @return pos2d  per joint {u, v} of the first candidate
-- @return pos3d  per joint {x, y, z}: score-weighted mean over all candidates,
--   multiplied by 1000
-- @return scores per joint score of the first candidate
-- A joint with no candidates, or whose candidate scores sum to zero, yields
-- zeros instead of raising an error or producing NaN/inf coordinates.
function vnectdection:ParseOutput(scoreduvxyz)
  local pos2d = {}
  local pos3d = {}
  local scores = {}
  for j = 1, #scoreduvxyz do
    local candidates = scoreduvxyz[j]
    local best = candidates[1]
    if best == nil then
      -- No candidate for this joint: report it as undetected.
      scores[j] = 0
      pos2d[j] = {0, 0}
      pos3d[j] = {0, 0, 0}
    else
      scores[j] = best[1]
      pos2d[j] = {best[2], best[3]}

      -- Accumulate the score-weighted sum of the 3D coordinates.
      local sum_score = 0
      local wx, wy, wz = 0, 0, 0
      for k = 1, #candidates do
        local cand = candidates[k]
        local score = cand[1]
        sum_score = sum_score + score
        wx = wx + cand[4] * score
        wy = wy + cand[5] * score
        wz = wz + cand[6] * score
      end

      if sum_score ~= 0 then
        -- NOTE(review): the factor 1000 looks like a unit conversion -- confirm.
        pos3d[j] = {1000 * wx / sum_score,
                    1000 * wy / sum_score,
                    1000 * wz / sum_score}
      else
        -- Avoid dividing by zero when every candidate has score 0.
        pos3d[j] = {0, 0, 0}
      end
    end
  end
  return pos2d, pos3d, scores
end

--- Return a fixed, hard-coded pose used as a fallback result.
-- Fresh tables are built on every call, and the values are kept in locals so
-- the call does not create or overwrite any global variables.
-- @return pos2d  14 entries of {u, v}
-- @return pos3d  14 entries of {x, y, z}
-- @return scores 14 scores, all set to 10
function vnectdection:GetHackedPose()
  local pos2d = {
    { 9.0  , 17.2173913 },
    { 8.125, 22.43478261},
    { 8.125, 26.47826087},
    {13.875, 17.2173913 },
    {14.75 , 22.30434783},
    {14.75 , 25.56521739},
    {10.0  , 26.60869565},
    {10.0  , 32.86956522},
    {10.0  , 38.08695652},
    {12.875, 26.47826087},
    {12.875, 32.86956522},
    {12.875, 39.0       },
    {11.0  , 15.13043478},
    {11.0  , 11.86956522},
  }
  local pos3dret = {
    {-138, -461,  -67},
    {-210, -185,  -32},
    {-199,   30,  -53},
    { 149, -431, -103},
    { 210, -181,  -56},
    { 215,   52,  -27},
    { -83,   32,    0},
    { -86,  462,   46},
    { -83,  820,  133},
    { 108,   48,  -20},
    {  86,  472,   45},
    {  88,  819,   82},
    { -18, -552,  -83},
    {   7, -745,  -86},
  }
  local scores = {10,10,10,10,10,10,10,10,10,10,10,10,10,10}
  return pos2d, pos3dret, scores
end

return vnectdection;