local Object = require "classic"

--- Temporal stabilizer for poses.
-- Keeps a cached pose and blends each new pose into it row by row, weighting
-- the cached row more heavily when the new row is close to it.
-- Call Initialize before the first Update.
local PoseStablizer = Object:extend()

--- Constructor: only delegates to the base class; no fields are set here.
function PoseStablizer:new()
  local base = PoseStablizer.super
  base.new(self)
end

--- Configure the stabilizer and discard any cached pose.
-- @param th distance threshold, expressed as an absolute distance value
-- @param k sigmoid steepness; the larger k is, the closer the blending gets
--   to a hard threshold at th
-- @param metric optional distance identifier resolved by _GetDist
--   ('l2' or 'cos'; nil falls back to L2, anything else warns and uses L2)
function PoseStablizer:Initialize(th, k, metric)
  self.th, self.k = th, k
  self.cache = nil
  -- Store the distance function itself; _Dist invokes it as self:dist(a, b).
  self.dist = self:_GetDist(metric)
end

--- Feed a new pose and return the smoothed pose.
-- The first call just caches a copy of `pose`. Later calls compute per-row
-- distances to the cache, turn them into per-row weights with _Ratio, and
-- replace the cache with the weighted blend.
-- @param pose table of rows, each a sequence of numbers (indexed pose[i][j])
-- @return a deep copy of the updated cache, so the caller can mutate it freely
function PoseStablizer:Update(pose)
  if self.cache ~= nil then
    local weights = self:_Ratio(self:_Dist(pose, self.cache))
    -- weights[i] is the share kept from the cached row i.
    self.cache = self:_WeightedAverage(self.cache, pose, weights)
  else
    self.cache = self:deep_copy(pose)
  end

  return self:deep_copy(self.cache)
end

--- Return a deep copy of the current cached pose (nil before the first Update).
-- A copy is handed out so callers cannot mutate the internal state.
function PoseStablizer:GetCached()
  return self:deep_copy(self.cache)
end

--- Clear the cached pose and optionally change the blending parameters.
-- Omitted (nil) arguments keep their current values. Before this change,
-- `Reset()` wiped th/k to nil and the next Update crashed in _Ratio.
-- @param th optional new distance threshold (absolute distance value)
-- @param k optional new sigmoid steepness
function PoseStablizer:Reset(th, k)
  if th ~= nil then
    self.th = th
  end
  if k ~= nil then
    self.k = k
  end
  self.cache = nil
end

-- utils
--- Row-wise weighted average of two poses.
-- Row i of the result is p1[i] * w1[i] + p2[i] * (1 - w1[i]), element-wise.
-- @param p1 first pose (table of numeric rows)
-- @param p2 second pose, with the same number of rows and row lengths as p1
-- @param w1 per-row weights for p1; p2 implicitly gets 1 - w1[i]
-- @return a new table, so neither input is modified
function PoseStablizer:_WeightedAverage(p1, p2, w1)
  assert(#p1 == #p2, "PoseStablizer: poses have different row counts")
  local p_avg = {}
  for i = 1, #p1 do
    local row1, row2 = p1[i], p2[i]
    -- Check every row, not just the first. This also lets an empty pose
    -- through instead of failing on #p1[1].
    assert(#row1 == #row2, "PoseStablizer: row length mismatch")
    local wa = w1[i]
    local wb = 1 - wa
    local row = {}
    for j = 1, #row1 do
      row[j] = row1[j] * wa + row2[j] * wb
    end
    p_avg[i] = row
  end
  return p_avg
end

--- Map per-row distances to per-row weights for the cached pose.
-- weight = 1 - sigmoid(k * (d / th - 1)). A distance well below th gives a
-- weight near 1 (keep the cache), well above th gives near 0 (follow the new
-- pose), and d == th gives exactly 0.5.
-- @param dist sequence of non-negative distances
-- @return sequence of weights in [0, 1]
function PoseStablizer:_Ratio(dist)
  local th, k = self.th, self.k
  local weights = {}
  for idx = 1, #dist do
    local logistic = 1 / (1 + math.exp(-1 * k * (dist[idx] / th - 1)))
    weights[idx] = 1 - logistic
  end
  return weights
end

--- Per-row distances between two poses.
-- Uses the metric function stored in self.dist by Initialize.
-- @param p1 pose (table of rows)
-- @param p2 pose with the same number of rows
-- @return sequence where entry i is the distance between p1[i] and p2[i]
function PoseStablizer:_Dist(p1, p2)
  local n = #p1
  assert(n == #p2)
  local out = {}
  for row = 1, n do
    out[row] = self:dist(p1[row], p2[row])
  end
  return out
end

--- Resolve a metric identifier to a distance method.
-- @param identifier 'l2', 'cos', or nil (nil means L2)
-- @return the method (expects self as its first argument). Unknown
--   identifiers print a warning and fall back to L2.
function PoseStablizer:_GetDist(identifier)
  if identifier == 'cos' then
    return self._CosineDist
  end
  local known = (identifier == nil) or (identifier == 'l2')
  if not known then
    print('PoseStablizer: Invalid metric, use default L2 distance')
  end
  return self._L2Dist
end


--- Euclidean (L2) distance between two equal-length numeric sequences.
function PoseStablizer:_L2Dist(p1, p2)
  assert(#p1 == #p2)
  local sum_sq = 0
  for idx = 1, #p1 do
    local diff = p1[idx] - p2[idx]
    sum_sq = sum_sq + diff ^ 2
  end
  return math.sqrt(sum_sq)
end

--- Minkowski (Lp) distance between two equal-length numeric sequences.
-- Uses the `^` operator because math.pow was removed in Lua 5.3+ (it is only
-- available there under compatibility flags).
-- @param p1 numeric sequence
-- @param p2 numeric sequence of the same length
-- @param p order of the norm; must be a number >= 1
-- @return (sum |p1[i] - p2[i]|^p)^(1/p)
function PoseStablizer:_LpDist(p1, p2, p)
  assert(#p1 == #p2, "PoseStablizer: vectors have different lengths")
  assert(type(p) == "number" and p >= 1, "PoseStablizer: p must be a number >= 1")
  local total = 0
  for i = 1, #p1 do
    total = total + math.abs(p1[i] - p2[i]) ^ p
  end
  return total ^ (1 / p)
end

--- Cosine distance (1 - cosine similarity) between two numeric sequences.
-- Used as a distance by _Ratio, so small must mean "similar". The previous
-- version returned the similarity itself, which inverted the smoothing.
-- @param p1 numeric sequence
-- @param p2 numeric sequence of the same length
-- @return a value in [0, 2]. Two zero vectors give 0. If exactly one is zero
--   the angle is undefined and the result is 1, the same as treating the
--   similarity as 0 (the original fallback).
function PoseStablizer:_CosineDist(p1, p2)
  assert(#p1 == #p2)
  local product, sq1, sq2 = 0, 0, 0
  for i = 1, #p1 do
    local a, b = p1[i], p2[i]
    product = product + a * b
    sq1 = sq1 + a * a
    sq2 = sq2 + b * b
  end
  local norm1, norm2 = math.sqrt(sq1), math.sqrt(sq2)
  if norm1 == 0 and norm2 == 0 then
    return 0
  elseif norm1 == 0 or norm2 == 0 then
    return 1
  end
  local d = 1 - product / (norm1 * norm2)
  -- Rounding can push the similarity a hair outside [-1, 1]; clamp it.
  if d < 0 then
    d = 0
  elseif d > 2 then
    d = 2
  end
  return d
end

--- Recursively copy a value.
-- Non-table values are returned as-is. Tables have their keys, values and
-- metatable copied recursively, as before. A memo of already-copied tables
-- makes cycles terminate; previously a class-style metatable with
-- `__index = itself` recursed until stack overflow. Shared sub-tables now
-- map to a single shared copy.
-- @param orig value to copy
-- @param seen internal memo (original table -> copy); callers omit it
-- @return the copy
function PoseStablizer:deep_copy(orig, seen)
  if type(orig) ~= "table" then
    return orig
  end
  seen = seen or {}
  local existing = seen[orig]
  if existing ~= nil then
    return existing
  end

  local copy = {}
  -- Register before recursing so self-references resolve to this copy.
  seen[orig] = copy
  -- Raw traversal with next, ignoring any __pairs metamethod.
  for orig_key, orig_value in next, orig, nil do
    copy[self:deep_copy(orig_key, seen)] = self:deep_copy(orig_value, seen)
  end

  -- getmetatable can return a non-table __metatable value for protected
  -- metatables; setmetatable would reject it, so only copy real tables.
  local mt = getmetatable(orig)
  if type(mt) == "table" then
    setmetatable(copy, self:deep_copy(mt, seen))
  end
  return copy
end

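--[[
Example usage (construct an instance rather than configuring the class table):
local stb = PoseStablizer()
stb:Initialize(1, 30, 'cos')
local p1 = {{0, 0, 0}}
local p2 = {{10, 100, 100}}
local pp1 = stb:Update(p1)
local pp2 = stb:Update(p2)
--]]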

-- Module export: require() of this file yields the PoseStablizer class table.
return PoseStablizer