local apolloengine = require "apolloengine"
local mathfunction = require "mathfunction"
local venuscore = require "venuscore"
local nodeutility = require "apolloutility.nodeutility"
local apollonode = require "apolloutility.apollonode"
local bigonn = require "bigonnfunction"
local stb_bbox = require "avatar.filter.pose_stablizer_soft"
local stb_2d = require "avatar.filter.pose_stablizer_soft_bilateral"
--debug
local debugbbox = require "avatar.posedection.debugbbox"
local venusjson = require "venusjson"
require "utility"
require "venusdebug"


local vnectdection = {}

--- Load the pose network and set up every piece of detector state.
-- Creates the bigonn net/session, the image preprocessing config, the
-- bounding-box bookkeeping, both stabilizers and the debug bbox drawer.
-- Note: self.stb_bbox, self.stb_2d and self.debugbbox reference the required
-- module tables themselves, so their state is shared with any other user of
-- those modules.
-- @param modelpath  value passed to bigonn.NetBigoNN to load the model
function vnectdection:Initialize(modelpath)
  -- Network + session; the second CreateSession argument is half of the
  -- thread count reported by the services system.
  local thread_count = venuscore.IServicesSystem:GetThreadCount()
  local net = bigonn.NetBigoNN(modelpath)
  self.Net = net
  self.Session = net:CreateSession(0, thread_count / 2)
  self.OutputTensor = bigonn.TensorBigoNN()

  -- Preprocessing options handed to InputTensor:ConvertNormalize in Estimate.
  -- NOTE(review): exact meaning of "mean"/"normal" is defined by bigonn - confirm.
  self.ImageConfig = {
    sourceFormat = bigonn.TensorBigoNN.RGB,
    destFormat   = bigonn.TensorBigoNN.BGR,
    mean         = { 127.5, 127.5, 127.5 },
    normal       = { 1 / 255, 1 / 255, 1 / 255 },
    destWidth    = 128,
    destHeight   = 256,
  }

  -- Joint ids 0..14, passed to GaussianHeatMapToTopKScoredUV.
  local joint_ids = {}
  for id = 0, 14 do
    joint_ids[#joint_ids + 1] = id
  end
  self.IdList = joint_ids

  self.heatmapsize = { 32, 64 } -- previously {24, 48}
  self.inputname = "input_1"
  self.outputname = "pose_net/pose_net_output/output/Conv2D"
  self.max_score = 3 -- score value range; ParseOutput divides raw scores by this
  self.cache = nil

  -- Bounding box state; imagesize and bbox are filled in by the first Estimate.
  self.imagesize = nil
  self.bbox = nil
  self.bbox_dilate_scale = 1.2 -- previously 1.3
  self.F_TARGET_ABSCENT = true

  -- Bounding-box stabilizer. (45, 1) is the setting noted for 1080 x 1920
  -- input; (5, 0.1) was noted for 540 x 960.
  self.stb_bbox = stb_bbox
  self.stb_bbox:Initialize(45, 1)

  -- Pose stabilizer. (30, 9999, 1) is the setting noted for 1080 x 1920
  -- input; (10, 9999, 1) was noted for 540 x 960.
  self.stb_2d = stb_2d
  self.stb_2d:Initialize(30, 9999, 1)

  -- Debug drawer used by SetBoundingBox.
  self.debugbbox = debugbbox
  self.debugbbox:Initialize()
end

--- Run one inference pass on a frame and return the stabilized 2D keypoints.
-- The frame is cropped to the current bounding box, pushed through the
-- network, and the heatmap output is parsed into image-space keypoints.
-- When the target is judged present the keypoints are smoothed, cached and
-- used to derive the next frame's bounding box; otherwise the bounding box
-- and both stabilizers are reset.
-- @param rgbtexture  input frame; must provide GetSize() whose result has x() and y()
-- @return pos2d   deep copy of self.cache. On frames where the target is
--                 absent the cache is not refreshed, so this is the last
--                 cached pose (nil if nothing has been cached yet).
-- @return scores  per-joint scores of the current frame, as returned by ParseOutput
function vnectdection:Estimate(rgbtexture)
  local stream = rgbtexture

  -- Start from the full frame on the first call or whenever the target was lost.
  if self.imagesize == nil or self.F_TARGET_ABSCENT then
    local size = stream:GetSize()
    self.imagesize = { size:x(), size:y() }
    self:SetBoundingBox(0, 0, self.imagesize[1], self.imagesize[2])
  end

  -- Inference.
  self.InputTensor = bigonn.TensorBigoNN(stream, unpack(self.bbox))
  self.InputTensor:ConvertNormalize(self.ImageConfig)
  self.Session:SetSessionInput(self.inputname, self.InputTensor)
  self.Session:RunAllPaths()
  self.Session:GetSessioOutput(self.outputname, self.OutputTensor)

  -- Parse the heatmaps. ParseOutput reads self.bbox / self.heatmapsize, so it
  -- must be invoked on this instance rather than on the module table.
  local scoreduv = self.OutputTensor:GaussianHeatMapToTopKScoredUV(self.IdList, 0.1, 3)
  local pos2d, scores = self:ParseOutput(scoreduv)

  self:UpdateFlags(scores)
  if self.F_TARGET_ABSCENT then
    -- Target lost: fall back to the full frame and clear the filters.
    -- NOTE(review): self.cache is intentionally left untouched here, so the
    -- caller keeps receiving the last cached pose - confirm this is desired.
    self:SetBoundingBox(0, 0, self.imagesize[1], self.imagesize[2])
    self.stb_2d:Reset()
    self.stb_bbox:Reset()
    print("reset bbox due to target absecent")
  else
    -- Smooth the keypoints, then derive the next bbox from the raw
    -- (unsmoothed) keypoints of this frame.
    self.cache = self.stb_2d:Update(pos2d)
    self:UpdateBoundingBox(pos2d, self.bbox_dilate_scale, 2)
  end

  -- Normalization to (0, 1) via self:Normalize is currently not applied.
  return self:GetCached(), scores
end

--- Map keypoints from image pixel coordinates to the (0, 1) range.
-- x is divided by the image width (self.imagesize[1]) and y by the image
-- height (self.imagesize[2]).
-- @param pos2d  sequence of {x, y} pairs in original-image coordinates
-- @return new sequence of {x / width, y / height}; the input is not modified
function vnectdection:Normalize(pos2d)
  local width, height = self.imagesize[1], self.imagesize[2]
  local pos2d_norm = {}
  for i = 1, #pos2d do
    local point = pos2d[i]
    -- Each axis uses its own extent; dividing y by the width would push
    -- values outside (0, 1) for non-square images.
    pos2d_norm[i] = { point[1] / width, point[2] / height }
  end
  return pos2d_norm
end

--- Decide from the per-joint scores whether the target is present.
-- A joint counts as detected when its score exceeds 0.4. The target is
-- flagged absent (self.F_TARGET_ABSCENT = true) when fewer than 10 joints
-- overall, fewer than 3 of the leg joints, or fewer than 3 of the arm joints
-- are detected.
-- @param scores  sequence of numeric scores, indexed by joint (1-based)
function vnectdection:UpdateFlags(scores)
  -- 1-based joint indices that the original code labels as legs / arms.
  local is_leg = { [8] = true, [9] = true, [11] = true, [12] = true }
  local is_arm = { [1] = true, [2] = true, [4] = true, [5] = true }

  local total, legs, arms = 0, 0, 0
  for idx = 1, #scores do
    if scores[idx] > 0.4 then
      total = total + 1
      if is_leg[idx] then
        legs = legs + 1
      end
      if is_arm[idx] then
        arms = arms + 1
      end
    end
  end

  -- Any one shortfall is enough to declare the target absent.
  self.F_TARGET_ABSCENT = (total < 10) or (legs < 3) or (arms < 3)
end


--- Convert top-k heatmap candidates into image-space keypoints and scores.
-- For every joint the position is the score-weighted mean of its candidates,
-- shifted by half a heatmap cell, scaled from heatmap cells to the cropped
-- region (self.bbox) and offset by the bbox origin into original-image
-- coordinates. The joint score is the first candidate's score divided by
-- self.max_score.
-- @param scoreduv  num_joints x top_k x 3 nested sequence; each innermost
--                  entry is {score, u, v}
-- @return pos2d   sequence of {x, y} in original-image coordinates
-- @return scores  sequence of normalized scores, one per joint
function vnectdection:ParseOutput(scoreduv)
  local bbox = self.bbox
  local heatmap_w, heatmap_h = self.heatmapsize[1], self.heatmapsize[2]
  -- scale: heatmap -> cropped region; offset: cropped region -> original image
  local scale_x, scale_y = bbox[3] / heatmap_w, bbox[4] / heatmap_h
  local offset_x, offset_y = bbox[1], bbox[2] -- bbox upper-left x, y

  local pos2d, scores = {}, {}
  for j = 1, #scoreduv do
    local candidates = scoreduv[j]
    local best = candidates[1]
    -- A joint with no candidate at all gets score 0 instead of raising a
    -- nil-index error.
    scores[j] = best and (best[1] / self.max_score) or 0

    local sum_score, sum_u, sum_v = 0, 0, 0
    for k = 1, #candidates do
      local candidate = candidates[k]
      local score = candidate[1]
      sum_score = sum_score + score
      sum_u = sum_u + candidate[2] * score
      sum_v = sum_v + candidate[3] * score
    end

    local u, v
    if sum_score ~= 0 then
      u, v = sum_u / sum_score, sum_v / sum_score
    else
      -- No usable weight: 0/0 would produce NaN and poison the downstream
      -- stabilizers, so fall back to the centre of the heatmap (which maps
      -- to the centre of the current bbox after the +0.5 shift below).
      u, v = heatmap_w / 2 - 0.5, heatmap_h / 2 - 0.5
    end

    -- heatmap coord -> original image coord
    pos2d[j] = { offset_x + scale_x * (u + 0.5),
                 offset_y + scale_y * (v + 0.5) }
  end
  return pos2d, scores
end

--- Store the detection bounding box and hand its corners to the debug drawer.
-- No aspect-ratio check is performed (an h == 2 * w assertion is disabled).
-- @param x  left edge of the box
-- @param y  top edge of the box
-- @param w  box width
-- @param h  box height
function vnectdection:SetBoundingBox(x, y, w, h)
  self.bbox = { x, y, w, h }

  -- Corner order: top-left, top-right, bottom-left, bottom-right.
  local right, bottom = x + w, y + h
  local corners = {
    { x,     y      },
    { right, y      },
    { x,     bottom },
    { right, bottom },
  }
  self.debugbbox:Draw(corners)
end

--- Derive the next frame's detection bounding box from the current keypoints.
-- Computes the tight box around the keypoints, pads it with a margin, forces
-- the requested height/width ratio, dilates it by `scale`, and stores the
-- result through SetBoundingBox.
-- @param pos2d  sequence of {x, y} keypoints in original-image coordinates
-- @param scale  dilation factor: the next detection box is `scale` times the
--               padded box derived from the current tight box
-- @param hwr    height / width ratio of the resulting box (the caller passes 2)
function vnectdection:UpdateBoundingBox(pos2d, scale, hwr)
  -- Tight bbox of the current frame.
  local u_min, u_max = self.imagesize[1], 0
  local v_min, v_max = self.imagesize[2], 0
  for i = 1, #pos2d do
    -- pos2d entries are read as {x, y}.
    local u, v = pos2d[i][1], pos2d[i][2]
    u_min = math.min(u_min, u)
    u_max = math.max(u_max, u)
    v_min = math.min(v_min, v)
    v_max = math.max(v_max, v)
  end
  local cx = (u_min + u_max) / 2
  local cy = (v_min + v_max) / 2
  local w_tight = u_max - u_min
  local h_tight = v_max - v_min

  -- Feed the bbox stabilizer so its internal state advances.
  -- NOTE(review): the smoothed box returned by Update is not used below; the
  -- raw cx / cy / w_tight / h_tight are. This matches the previous behaviour,
  -- but it looks unintended - confirm the return shape of stb_bbox:Update and
  -- whether the smoothed values should replace the raw ones.
  self.stb_bbox:Update({ { cx, cy }, { w_tight, h_tight } })

  -- Pad by 5% of the mean side length, but never less than min_margin pixels.
  local min_margin = 30
  local margin = math.max((w_tight + h_tight) / 2 * 0.05, min_margin)

  -- Pick the width so that the hwr-shaped box covers the tight box.
  -- box_w / box_h are locals: they used to leak into the global table.
  local box_w
  if h_tight <= hwr * w_tight then
    box_w = w_tight + 2 * margin
  else
    box_w = h_tight / hwr + 2 * margin
  end
  local box_h = hwr * box_w

  local w_loose = scale * box_w
  local h_loose = scale * box_h
  local u_loose = cx - w_loose / 2
  local v_loose = cy - h_loose / 2

  -- If the box is wider than the texture, reset to the full image.
  -- A negative upper-left x / y is allowed: per the original note, the engine
  -- copies the intersection of the bbox and the image.
  if w_loose > self.imagesize[1] then
    self:SetBoundingBox(0, 0, self.imagesize[1], self.imagesize[2])
  else
    self:SetBoundingBox(u_loose, v_loose, w_loose, h_loose)
  end
end

--- Return an independent copy of the cached pose.
-- @return deep copy of self.cache, so callers can modify the result without
--         touching the cache; nil when nothing has been cached
function vnectdection:GetCached()
  return self:deep_copy(self.cache)
end

--- Recursively copy a value.
-- Non-table values are returned as-is. Tables are copied key by key (keys and
-- values are both deep-copied) and the copy receives a deep copy of the
-- original's metatable. Tables that are reachable more than once, including
-- cycles such as a metatable whose __index points at itself, are copied
-- exactly once, so the copy has the same sharing structure as the original
-- and the recursion always terminates.
-- @param orig  value to copy
-- @param seen  (optional, internal) map from already-visited tables to their
--              copies; callers normally omit it
-- @return the copy
function vnectdection:deep_copy(orig, seen)
  if type(orig) ~= "table" then
    return orig
  end

  seen = seen or {}
  local existing = seen[orig]
  if existing ~= nil then
    return existing
  end

  local copy = {}
  -- Register before recursing so self-references resolve to this copy.
  seen[orig] = copy
  for orig_key, orig_value in next, orig, nil do
    copy[self:deep_copy(orig_key, seen)] = self:deep_copy(orig_value, seen)
  end

  -- getmetatable may return a non-table when the metatable is protected via
  -- __metatable; setmetatable would raise on that, so only tables are applied.
  local mt = getmetatable(orig)
  if type(mt) == "table" then
    setmetatable(copy, self:deep_copy(mt, seen))
  end
  return copy
end

return vnectdection;