local apolloengine = require "apolloengine"
local mathfunction = require "mathfunction"
local venuscore = require "venuscore"
local venusjson = require "venusjson"
local nodeutility = require "apolloutility.nodeutility"
local apollonode = require "apolloutility.apollonode"
local mnn = require "mnnfunction"
require "utility"
require "venusdebug"

-- Additive offsets used by vnectdection:Estimate when de-normalizing the
-- network's 3D output: value = (raw * std3d[k] + mean3d[k]) / 5.
-- The table is flat: 42 numbers read as 14 consecutive triples, one triple per
-- entry of the detector's id list (triple i occupies slots 3i-2 .. 3i).
-- NOTE(review): the triples are presumably (x, y, z) training-set means per
-- joint; units and joint naming are not visible in this file -- confirm.
local mean3d = {
   19.2636853,   -364.07221957, -49.70211423,   -- triple 1
   27.16492364,  -196.40941273, -20.37640802,   -- triple 2
   22.19212815,  -141.56468872,  -5.60397339,   -- triple 3
   -5.58979448,  -364.93143521, -55.38263611,   -- triple 4
  -20.33558863,  -190.03006973, -28.19176664,   -- triple 5
  -20.35292413,  -128.4411834,  -11.07210294,   -- triple 6
   10.95295143,    -3.93221042,   2.62632808,   -- triple 7
   -0.295091466,  297.266238,    65.4302828,    -- triple 8
   -6.60301005,   595.80278792, 104.0461003,    -- triple 9
  -10.95295143,     3.93221042,  -2.62632808,   -- triple 10 (negation of triple 7)
  -27.20607717,   304.02601631,  60.93665227,   -- triple 11
  -29.8390817,    601.00520868, 101.34394339,   -- triple 12
    7.82560758,  -399.63497744, -57.41391151,   -- triple 13
   11.73082648,  -614.31339237, -84.39472289,   -- triple 14
}
                
                
-- Multiplicative scales used by vnectdection:Estimate when de-normalizing the
-- network's 3D output: value = (raw * std3d[k] + mean3d[k]) / 5.
-- Same flat layout as mean3d: 42 numbers read as 14 consecutive triples
-- (triple i occupies slots 3i-2 .. 3i).
-- NOTE(review): presumably per-joint standard deviations of the training
-- data; units are not visible in this file -- confirm.
local std3d = {
  165.61815005, 128.59831801, 177.02192118,   -- triple 1
  250.74411554, 190.4404165,  248.59888738,   -- triple 2
  293.07497601, 296.80384374, 289.69961481,   -- triple 3
  168.92396867, 134.27708876, 179.31947835,   -- triple 4
  247.40264662, 187.61003002, 247.06318987,   -- triple 5
  286.40234624, 285.69366664, 290.44794477,   -- triple 6
   88.80125935,  29.06297271,  81.58489786,   -- triple 7
  217.54632543, 241.31781421, 227.17140183,   -- triple 8
  262.61936067, 322.53261843, 292.94816938,   -- triple 9
   88.80125935,  29.06297271,  81.58489786,   -- triple 10 (same as triple 7)
  208.94766691, 235.99653981, 232.40474048,   -- triple 11
  264.9637643,  316.55532318, 299.10806233,   -- triple 12
  146.69068404, 135.82975555, 165.10515741,   -- triple 13
  221.59899471, 193.711728,   247.66506011,   -- triple 14
}


-- Module table; returned at the end of the file.
local vnectdection = {}

--- Loads the MNN model and prepares everything Estimate needs.
-- Creates the interpreter and a session, caches the session's input tensor
-- and the named output tensor, and builds the image pre-processor.
-- @param modelpath path of the model file handed to Interpreter:CreateFromFile
function vnectdection:Initialize(modelpath)
  local count = venuscore.IServicesSystem:GetThreadCount();
  -- Use half of the reported threads. A bare count/2 is fractional for odd
  -- counts (and a float in Lua 5.3+) and drops below one for count == 1, so
  -- floor it and never request fewer than one thread.
  local threads = math.max(1, math.floor(count / 2));

  self.Interpreter = mnn.Interpreter();
  self.Interpreter:CreateFromFile(modelpath);
  -- NOTE(review): the first argument (0) is presumably the backend/forward
  -- type and the second the thread count -- confirm against the mnn binding.
  self.Session = self.Interpreter:CreateSession(0, threads);
  self.InputTensor = self.Interpreter:GetSessionInput(self.Session);
  self.OutputTensor = self.Interpreter:GetSessionOutput(self.Session, "vnect3d_delta_bonelength_connect/conv2d/Conv2D");

  -- Pre-processing settings handed to the image processor.
  -- NOTE(review): "mean"/"normal" look like a per-channel offset and scale
  -- applied during conversion; exact semantics live in the mnn binding.
  self.ImageConfig = {
    sourceFormat = mnn.ImageProcesser.RGB,
    destFormat = mnn.ImageProcesser.BGR,
    filterType = mnn.ImageProcesser.BILINEAR,
    mean = {102, 102, 102},
    normal = {1/255, 1/255, 1/255},
  }

  self.ImageProcesser = mnn.ImageProcesser(self.ImageConfig);

  -- Ids passed to GaussianHeatMapToScoredUV in Estimate (14 entries).
  self.IdList = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}

  -- heatmapsize is not read by the code visible in this file; imagesize is
  -- passed element-wise to ImageProcesser:Convert in Estimate.
  self.heatmapsize = {24, 46}
  self.imagesize = {184, 368}
end

--- Runs the network on one frame and decodes per-joint 2D and 3D estimates.
-- @param rgbtexture object exposing GetSourceStream(); the stream is fed to
--        the image processor configured in Initialize
-- @return on success three parallel arrays: pos2d (each entry {u, v}),
--         pos3dret (each entry a de-normalized 3-value triple) and scores;
--         `false` alone when the session run reports failure
function vnectdection:Estimate(rgbtexture)
  local stream = rgbtexture:GetSourceStream();
  self.ImageProcesser:Convert(stream, self.InputTensor, self.imagesize[1], self.imagesize[2], 0);
  local succeed = self.Interpreter:RunSession(self.Session);
  if succeed == false then
    LOG("MNN RUN FAILED");
    return false;
  end

  -- Kept local: these were previously assigned without `local`, which leaked
  -- them into the global environment and overwrote them on every call.
  -- Each entry is read as { [1] = id, [2] = score, [3] = u, [4] = v }.
  local scoreduv = self.OutputTensor:GaussianHeatMapToScoredUV(self.IdList, 0.3);

  local pos3dsample = {}
  local pos2d = {}
  local scores = {}
  for i = 1, #scoreduv do
    local entry = scoreduv[i]
    local id, u, v = entry[1], entry[3], entry[4]
    -- Three samples per detection, taken at the same (u, v) with the first
    -- coordinate offset by 14, 28 and 42.
    -- NOTE(review): presumably the three 14-channel location maps that follow
    -- the 14 heatmap channels in the output tensor -- confirm.
    pos3dsample[i * 3 - 2] = {id + 14,     u, v}
    pos3dsample[i * 3 - 1] = {id + 14 * 2, u, v}
    pos3dsample[i * 3]     = {id + 14 * 3, u, v}
    pos2d[i] = {u, v}
    scores[i] = entry[2]
  end

  local pos3d = self.OutputTensor:GetValue(pos3dsample);

  -- Never index past the detections or past the statistics tables; with the
  -- usual 14 detections this is the same 14 iterations as before, but a
  -- shorter result no longer raises a nil-arithmetic error.
  local jointcount = math.min(#scoreduv, math.floor(#mean3d / 3))
  local pos3dret = {}
  for i = 1, jointcount do
    local point = {}
    for k = i * 3 - 2, i * 3 do
      -- Undo the normalization, then apply the fixed 1/5 scale.
      point[#point + 1] = ((pos3d[k] * std3d[k]) + mean3d[k]) / 5
    end
    pos3dret[i] = point
  end

  return pos2d, pos3dret, scores;
end

return vnectdection;