--
local venuscore = require "venuscore"
local venusjson = require "venusjson"
local mnn = require "mnnfunction"
require "utility"
require "venusdebug"


--- GCN-based 2D-to-3D pose lifting module.
-- Keeps a sliding window of normalized 2D keypoints plus a horizontally
-- mirrored copy of it, runs an MNN session over both, and averages the two
-- predictions into one set of 3D joints. Methods are defined with `:` and
-- keep their state on `self`; call Initialize before Estimate.
local gcndetection = {}

--- Configure the detector and create the MNN inference session.
-- @param modelpath path of the model file passed to Interpreter:CreateFromFile
-- @param n_frame   length of the sliding input window (positive integer)
-- @param delay     how many frames before the newest one ParseOutput3D reads
--                  its prediction from; integer with 0 <= delay < n_frame
function gcndetection:Initialize(modelpath, n_frame, delay)
  -- general config
  assert(type(n_frame) == "number" and n_frame >= 1
           and math.floor(n_frame) == n_frame,
         "gcndetection:Initialize: n_frame must be a positive integer")
  assert(type(delay) == "number" and math.floor(delay) == delay
           and delay >= 0 and delay < n_frame,
         "gcndetection:Initialize: delay must be an integer with 0 <= delay < n_frame")
  self.n_frame = n_frame
  self.delay = delay

  -- flip_idx[i] is the joint whose data lands in slot i after a horizontal
  -- flip: the two 3-joint groups in each limb pair are swapped and joints
  -- 13..15 map to themselves.
  self.flip_idx = {4,5,6, 1,2,3, 10,11,12, 7,8,9, 13,14,15}
  -- Coordinate frame of the incoming 2D keypoints, used by NormalizeInput2D
  -- (presumably heatmap width/height -- NOTE(review): confirm with the 2D model).
  self.hm_w = 24
  self.hm_h = 48
  self.scale = 0.5
  self.n_joint = #self.flip_idx
  -- The sliding windows are created lazily on the first UpdateInputQueue call.
  self.input_queue_ori = nil
  self.input_queue_flip = nil

  -- init network
  self.input_node = "input_1_1"
  self.output_node = "stgcn_1/private_stgcn__output/reshape_2/Reshape"
  self.Interpreter = mnn.Interpreter()
  self.Interpreter:CreateFromFile(modelpath)
  -- Use half of the available threads. `count / 2` is a float in Lua 5.3+
  -- (fractional for odd counts, 0.5 for a single thread), so round down and
  -- never ask for fewer than one thread.
  local count = venuscore.IServicesSystem:GetThreadCount()
  local n_threads = math.max(1, math.floor(count / 2))
  self.Session = self.Interpreter:CreateSession(0, n_threads)
  self.InputTensor = self.Interpreter:GetSessionInput(self.Session)
  self.OutputTensor = self.Interpreter:GetSessionOutput(self.Session, self.output_node)
end

--- Push one frame of 2D keypoints and run the network on the updated window.
-- The window is advanced before inference, so it also moves forward when the
-- run fails.
-- @param input_2d per-joint rows {x, y, ...}; NormalizeInput2D reads rows
--        1..n_joint-1 and synthesizes the last one. Not modified.
-- @return n_joint rows of {x, y, z} from ParseOutput3D, or false when
--         RunSession returns false
function gcndetection:Estimate(input_2d)
  local input = self:UpdateInputQueue(input_2d)
  self.InputTensor:Fill(input)
  local succeed = self.Interpreter:RunSession(self.Session)
  -- Only an explicit `false` counts as failure; nil does not.
  if succeed == false then
    LOG("BIGONN RUN FAILED")
    return false
  end
  -- Declared local so the results do not leak into the global table.
  local output_3d = self.OutputTensor:GetTensorValue()
  local output_3d_refined = self:ParseOutput3D(output_3d)
  return output_3d_refined
end

--- Extract one frame of 3D joints from the flat network output.
-- The flat output is indexed through GetIndex as (N, T, J, C) with C varying
-- fastest:
--   N - batch = 2, [original, flipx]
--   T - n_frame, [earliest, ..., latest]
--   J - n_joint, [R/L-arm, R/L-leg, torso/head]
--   C - 3, [x, y, z]
-- NOTE(review): an older note here described an N,C,T,J -> N,T,J,C
-- conversion; GetIndex assumes the tensor is already N,T,J,C -- confirm
-- against the reshape_2 output node.
-- The frame read is n_frame - delay. The flipped prediction is mirrored
-- back, averaged with the original one, and scaled by 1000 (presumably
-- metres -> millimetres -- NOTE(review): confirm units).
-- @param output_3d flat output values from OutputTensor:GetTensorValue()
-- @return n_joint rows of {x, y, z}
function gcndetection:ParseOutput3D(output_3d)
  local N_COORD = 3
  local OUTPUT_SCALE = 1000
  local n_joint = self.n_joint
  local t = self.n_frame - self.delay

  local output_ori = {}
  local output_flip = {}
  for i = 1, n_joint do
    local row_ori, row_flip = {}, {}
    for j = 1, N_COORD do
      row_ori[j] = output_3d[self:GetIndex(1, t, i, j)]
      row_flip[j] = output_3d[self:GetIndex(2, t, i, j)]
    end
    output_ori[i] = row_ori
    output_flip[i] = row_flip
  end
  -- Undo the horizontal flip so both predictions share one coordinate frame.
  output_flip = self:FlipSymmetricJoints(output_flip)

  -- refine (average the two predictions) and denormalize
  local output_refined = {}
  for i = 1, n_joint do
    local refined = {}
    for j = 1, N_COORD do
      refined[j] = (output_ori[i][j] + output_flip[i][j]) * OUTPUT_SCALE / 2
    end
    output_refined[i] = refined
  end
  return output_refined
end

--- Map (batch, frame, joint, coord) to a 1-based position in the flat
-- network output, laid out as (N, T=n_frame, J=n_joint, C=3) with C
-- varying fastest.
-- @param n batch index (1 = original input, 2 = flipped input)
-- @param h frame index in 1..n_frame
-- @param w joint index in 1..n_joint
-- @param c coordinate index in 1..3
-- @return 1-based index into the flat output
function gcndetection:GetIndex(n,h,w,c)
  local n_coord = 3
  local frame_stride = self.n_joint * n_coord
  local batch_stride = self.n_frame * frame_stride
  local zero_based = (n - 1) * batch_stride
                   + (h - 1) * frame_stride
                   + (w - 1) * n_coord
                   + (c - 1)
  return zero_based + 1
end

--- Normalize a new frame of 2D keypoints and append it to the sliding windows.
-- The caller's table is deep-copied first, so it is never mutated. The very
-- first call fills every slot of both windows with that single frame.
-- @param input_2d per-joint rows {x, y, ...}
-- @return flat network input built by ConvertNetworkInput
function gcndetection:UpdateInputQueue(input_2d)
  local frame = self:NormalizeInput2D(self:deep_copy(input_2d))
  if self.input_queue_ori ~= nil then
    self:PushInputQueue(frame)
  else
    self:InitInputQueue(frame)
  end
  -- Parenthesized to return exactly one value.
  return (self:ConvertNetworkInput())
end

--- Normalize one frame of 2D keypoints in place and root-center it.
-- Rows 1..n_joint-1 come from the caller. Row n_joint is overwritten with
-- the root joint, which is {0, 0} after centering.
-- @param input_2d per-joint rows {x, y, ...}; mutated in place
-- @return input_2d (the same table)
function gcndetection:NormalizeInput2D(input_2d)
  local n_observed = self.n_joint - 1
  local scale, hm_w = self.scale, self.hm_w
  -- y is divided by the width as well (not the height), so both axes share
  -- one unit; subtracting hm_h/hm_w maps y = hm_h/2 to 0, just as
  -- subtracting 1 maps x = hm_w/2 to 0.
  local y_offset = self.hm_h / self.hm_w

  -- normalize
  for i = 1, n_observed do
    local joint = input_2d[i]
    joint[1] = scale * (2 * (joint[1] / hm_w) - 1)
    joint[2] = scale * (2 * (joint[2] / hm_w) - y_offset)
  end

  -- The root is the midpoint of joints 7 and 10, the first joint of each leg
  -- group in flip_idx (presumably the hips -- NOTE(review): confirm order).
  local root_x = (input_2d[7][1] + input_2d[10][1]) / 2
  local root_y = (input_2d[7][2] + input_2d[10][2]) / 2
  for i = 1, n_observed do
    local joint = input_2d[i]
    joint[1] = joint[1] - root_x
    joint[2] = joint[2] - root_y
  end
  input_2d[self.n_joint] = {0, 0}
  return input_2d
end

--- Flatten the (original, flipped) windows into the network's input buffer.
-- The layout is (N=2, T=n_frame, J=n_joint, C=2) with C varying fastest, so
-- element (n, t, i, c) lands at the 1-based position
--   ((n-1)*n_frame + (t-1))*n_joint*2 + (i-1)*2 + c.
-- Also stores the pair of windows in self.input_batch.
-- @return flat array of x/y values
function gcndetection:ConvertNetworkInput()
  assert(self.input_queue_ori ~= nil)
  assert(self.input_queue_flip ~= nil)
  self.input_batch = {self.input_queue_ori, self.input_queue_flip}
  local flat = {}
  local pos = 0
  -- Walking the loops in N, T, J, C order visits the layout positions in
  -- increasing order, so a running counter replaces the offset formula.
  for n = 1, 2 do
    local queue = self.input_batch[n]
    for t = 1, self.n_frame do
      local frame = queue[t]
      for i = 1, self.n_joint do
        local joint = frame[i]
        flat[pos + 1] = joint[1]
        flat[pos + 2] = joint[2]
        pos = pos + 2
      end
    end
  end
  return flat
end

--- Slide both windows forward by one frame.
-- The oldest frame is dropped. The new frame and its mirrored copy go into
-- slot n_frame.
-- @param input_2d normalized frame (per-joint rows); stored by reference
function gcndetection:PushInputQueue(input_2d)
  local mirrored = self:FlipSymmetricJoints(input_2d)
  local last = self.n_frame
  local ori, flip = self.input_queue_ori, self.input_queue_flip
  for slot = 2, last do
    ori[slot - 1] = ori[slot]
    flip[slot - 1] = flip[slot]
  end
  ori[last] = input_2d
  flip[last] = mirrored
end

--- Create both sliding windows, filling every slot with the first frame.
-- All n_frame slots reference the same frame table, and likewise the same
-- mirrored table. This is safe because PushInputQueue only ever replaces
-- slots and never mutates them in place.
-- @param input_2d normalized frame (per-joint rows); stored by reference
function gcndetection:InitInputQueue(input_2d)
  local mirrored = self:FlipSymmetricJoints(input_2d)
  local ori, flip = {}, {}
  self.input_queue_ori, self.input_queue_flip = ori, flip
  for slot = 1, self.n_frame do
    ori[slot], flip[slot] = input_2d, mirrored
  end
end

--- Mirror a set of per-joint rows horizontally.
-- Row i of the result is built from row flip_idx[i] of `input`, which swaps
-- left/right joints. The first component is negated and the rest are copied.
-- Negating x mirrors about x = 0, which matches a horizontal flip only for
-- data centered there. NormalizeInput2D output is root-centered; the 3D
-- network output presumably is too -- NOTE(review): confirm.
-- @param input per-joint rows; not modified
-- @return new table with #self.flip_idx rows
function gcndetection:FlipSymmetricJoints(input)
  local flip_idx = self.flip_idx
  local input_flipx = {}
  for i = 1, #flip_idx do
    local src = input[flip_idx[i]]
    local row = {}
    -- Size the row from the row actually being copied, so rows of differing
    -- lengths neither drop components nor read missing ones.
    for j = 1, #src do
      if j == 1 then
        row[j] = -src[j]
      else
        row[j] = src[j]
      end
    end
    input_flipx[i] = row
  end
  return input_flipx
end

--- Print a prediction, one row per line.
-- @param output_3d sequence of rows. Table rows are printed as their columns;
--        non-table entries (e.g. a flat list of tensor values) are printed
--        as-is instead of raising an error.
function gcndetection:ShowLogs(output_3d)
  -- `unpack` is a global only in Lua 5.1/LuaJIT; 5.2+ moved it to table.unpack.
  local unpack_row = table.unpack or unpack
  print('3d prediction:')
  for i = 1, #output_3d do
    local row = output_3d[i]
    if type(row) == "table" then
      print(unpack_row(row))
    else
      print(row)
    end
  end
end

--- Recursively copy a value.
-- Keys, values and metatables are copied too, and non-table values are
-- returned unchanged. Entries are walked with raw `next`, so any __pairs
-- metamethod is bypassed. A table reachable twice without forming a cycle
-- is copied separately each time. A table that refers back to one of its
-- own ancestors maps to that ancestor's copy, which avoids infinite
-- recursion.
-- @param orig value to copy
-- @param in_progress (optional, internal) map from tables currently being
--        copied to their copies; callers normally omit it
-- @return the copy
function gcndetection:deep_copy(orig, in_progress)
  if type(orig) ~= "table" then
    return orig
  end
  in_progress = in_progress or {}
  local pending = in_progress[orig]
  if pending ~= nil then
    -- orig is an ancestor of itself: reuse the copy under construction.
    return pending
  end
  local copy = {}
  in_progress[orig] = copy
  for orig_key, orig_value in next, orig, nil do
    copy[self:deep_copy(orig_key, in_progress)] = self:deep_copy(orig_value, in_progress)
  end
  -- Set the metatable after filling so __newindex does not fire while copying.
  setmetatable(copy, self:deep_copy(getmetatable(orig), in_progress))
  -- Leave the current path: later, non-cyclic references get fresh copies.
  in_progress[orig] = nil
  return copy
end

return gcndetection;
--]]
--[[
local venuscore = require "venuscore"
local venusjson = require "venusjson"
local bigonn = require "bigonnfunction"
require "utility"
require "venusdebug"


local gcndetection = {}

function gcndetection:Initialize(modelpath)
  local count = venuscore.IServicesSystem:GetThreadCount();
  self.idx_L = {3,4,5,9,10,11};
  self.idx_R = {0,1,2,6,7,8};
  self.input_node = 'input_1'
  self.output_node = "stgcn/private_stgcn__output/predict/BiasAdd"
  
  self.Net = bigonn.NetBigoNN(modelpath);
  self.Session = self.Net:CreateSession(0,count/2);
  self.OutputTensor = bigonn.TensorBigoNN();
end

function gcndetection:Estimate(input_2d)

  local dummy_input = {};
  for i = 1, 2*3*14*2 do
    dummy_input[i] = 0.01 * i;
  end
  
  self.InputTensor = bigonn.TensorBigoNN(dummy_input, 0, 0);
  self.Session:SetSessionInput(self.input_node, self.InputTensor);
  self.Session:RunAllPaths();
  self.Session:GetSessioOutput(self.output_node, self.OutputTensor);
  output_3d = self.OutputTensor:GetTensorValue()
  
  self:ShowLogs(output_3d)
  return output_3d
end

function gcndetection:ShowLogs(output_3d)
  print('3d prediction:')
  for i = 1, #output_3d do
    print(unpack(output_3d[i]))
  end
end

return gcndetection;
--]]