--- Temporal pose stabilizer.
-- Each call to Update blends the incoming pose with the previously smoothed
-- (cached) pose. A pose is a sequence of rows, each a sequence of numbers.
-- Blending is done row by row, with a weight derived from how far that row
-- moved (Euclidean distance) passed through two sigmoid soft thresholds.
-- NOTE(review): there is no constructor here; all methods read and write
-- state on `self`, so calling them on this module table directly makes it a
-- single shared instance. Confirm how callers create separate instances.
local PoseStablizer = {}

--- Configure the stabilizer and drop any previously cached pose.
-- @param th_low lower distance threshold: either a single number applied to
--   every row of the pose, or a table holding one threshold per row
-- @param th_high upper distance threshold, in the same form as `th_low`
-- @param k steepness of the sigmoid soft thresholds used when weighting rows
function PoseStablizer:Initialize(th_low, th_high, k)
  -- Start from an empty history, then record the configuration.
  self.cache = nil
  self.th_low, self.th_high, self.k = th_low, th_high, k
end

--- Feed a new pose and get the smoothed result.
-- The first pose after Initialize/Reset is stored as-is. Each later pose is
-- blended into the cache row by row. Per-row weights for the cached pose come
-- from `ratio` applied to the row distances between the new and cached pose.
-- @param pose sequence of rows, each a sequence of numbers
-- @return a fresh deep copy of the updated smoothed pose (safe to mutate)
function PoseStablizer:Update(pose)
  local previous = self.cache
  if previous ~= nil then
    local row_distances = self:dist(pose, previous)
    local keep_weights = self:ratio(row_distances)
    self.cache = self:weighted_avg(previous, pose, keep_weights)
  else
    -- No history yet: take a private copy so later caller edits to `pose`
    -- cannot leak into the cache.
    self.cache = self:deep_copy(pose)
  end

  local snapshot = self:deep_copy(self.cache)
  return snapshot
end

--- Return the current smoothed pose without updating it.
-- @return a deep copy of the cached pose, or nil when no pose has been fed
--   since Initialize/Reset
function PoseStablizer:GetCached()
  return (self:deep_copy(self.cache))
end

--- Discard the smoothing history.
-- The next Update stores its input pose as the new starting point.
function PoseStablizer:Reset()
  self["cache"] = nil
end

-- utils: internal helpers (pose blending, distance, weighting, copying) used by the methods above
--- Blend two poses row by row using per-row weights.
-- Element (i, j) of the result is `p1[i][j] * w1[i] + p2[i][j] * (1 - w1[i])`.
-- So `w1[i] == 1` keeps row i of `p1`, and `w1[i] == 0` takes row i of `p2`.
-- @param p1 first pose: a sequence of rows, each a sequence of numbers
-- @param p2 second pose with the same row count and row lengths as `p1`
-- @param w1 sequence of weights for `p1`, one per row (`p2` gets `1 - w1[i]`)
-- @return a new pose table; neither input is modified
function PoseStablizer:weighted_avg(p1, p2, w1)
  local n_rows = #p1
  assert(n_rows == #p2, "weighted_avg: poses have different row counts")
  local p_avg = {}
  for i = 1, n_rows do
    local row1, row2, w = p1[i], p2[i], w1[i]
    -- Validate every row (not only the first). Mismatched shapes then fail
    -- with a clear message instead of a nil-arithmetic error or silent
    -- truncation. Also avoids indexing p1[1] when the pose is empty.
    assert(#row1 == #row2, "weighted_avg: rows have different lengths")
    assert(w ~= nil, "weighted_avg: missing weight for a row")
    local w2 = 1 - w
    local row_avg = {}
    for j = 1, #row1 do
      row_avg[j] = row1[j] * w + row2[j] * w2
    end
    p_avg[i] = row_avg
  end
  return p_avg
end

--- Per-row weight to give the cached pose, from sigmoid soft thresholds.
-- For a row that moved distance d, with thresholds tl (low) and th (high):
--   ratio_low  = 1 - 1 / (1 + exp(-k * (d / tl - 1)))
--   ratio_high = 1 - 1 / (1 + exp( k * (d / th - 1)))
--   ratio      = min(1, ratio_low + ratio_high)
-- For positive k, ratio_low falls from ~1 to ~0 as d passes tl, and
-- ratio_high rises from ~0 to ~1 as d passes th. Update passes this ratio as
-- the weight of the cached pose. Assuming tl < th, small jitters and very
-- large jumps keep the cache, while motion in between follows the new pose.
-- @param dist sequence of per-row distances
-- @return sequence of weights in [0, 1], one per row
function PoseStablizer:ratio(dist)
  local n_rows = #dist
  local cache = self.cache
  assert(cache == nil or n_rows == #cache,
    "ratio: distance count does not match cached pose")

  local k = self.k
  local th_low, th_high = self.th_low, self.th_high
  -- Each threshold may be one number for all rows or a per-row table. It is
  -- resolved per call rather than overwriting self.th_low/self.th_high with
  -- a list sized to the first pose seen. That old approach lost the scalar
  -- and crashed on a later pose with more rows (e.g. after Reset).
  local low_is_scalar = type(th_low) == "number"
  local high_is_scalar = type(th_high) == "number"

  local ratio = {}
  for i = 1, n_rows do
    local d = dist[i]
    -- Numbers are never false/nil, so this and/or selection is safe.
    local tl = low_is_scalar and th_low or th_low[i]
    local th = high_is_scalar and th_high or th_high[i]
    -- Locals: these were previously leaked as globals.
    local ratio_low = 1 - 1/(1 + math.exp(-1*k*(d/tl-1)))
    local ratio_high = 1 - 1/(1 + math.exp(k*(d/th-1)))
    ratio[i] = math.min(1, ratio_low + ratio_high)
  end
  return ratio
end

--- Euclidean (L2) distance between corresponding rows of two poses.
-- @param p1 pose: a sequence of rows, each a sequence of numbers
-- @param p2 pose with the same row count and row lengths as `p1`
-- @return sequence whose element i is the distance between `p1[i]` and `p2[i]`
--   (empty when the poses have no rows)
function PoseStablizer:dist(p1, p2)
  local n_rows = #p1
  assert(n_rows == #p2, "dist: poses have different row counts")
  local dist = {}
  for i = 1, n_rows do
    local row1, row2 = p1[i], p2[i]
    -- Checked per row. This catches mismatches beyond the first row and
    -- avoids indexing p1[1] when the pose is empty.
    assert(#row1 == #row2, "dist: rows have different lengths")
    local sum_sq = 0
    for j = 1, #row1 do
      sum_sq = sum_sq + (row1[j] - row2[j]) ^ 2
    end
    dist[i] = math.sqrt(sum_sq)
  end
  return dist
end

--- Recursively copy a value.
-- Non-table values are returned unchanged. Tables are copied entry by entry
-- (keys and values are deep-copied too), and the copy gets a deep copy of the
-- original's metatable.
-- @param orig value to copy
-- @param in_progress optional, internal: maps tables currently being copied
--   (ancestors in the recursion) to their copies. Callers should omit it.
-- @return the copy
function PoseStablizer:deep_copy(orig, in_progress)
  if type(orig) ~= "table" then
    return orig
  end
  in_progress = in_progress or {}

  -- Meeting a table that is still being copied further up the recursion
  -- means a cycle (e.g. a metatable whose __index is itself). Reuse its copy
  -- instead of recursing until the stack overflows.
  local pending = in_progress[orig]
  if pending ~= nil then
    return pending
  end

  local copy = {}
  in_progress[orig] = copy
  -- `next` rather than `pairs`, so any __pairs metamethod is bypassed.
  for orig_key, orig_value in next, orig, nil do
    copy[self:deep_copy(orig_key, in_progress)] =
      self:deep_copy(orig_value, in_progress)
  end
  setmetatable(copy, self:deep_copy(getmetatable(orig), in_progress))

  -- Only ancestors count as cycles. Clearing the entry keeps sub-tables that
  -- are merely shared (reached twice without a cycle) duplicated, as before.
  in_progress[orig] = nil
  return copy
end


-- Export the module table; `require` returns this table to callers.
return PoseStablizer