--local torch = require "torch"
local torch = require "torch"

-- Module table for the outlier filter. All methods are defined with ':' and
-- keep their state on whatever table they are invoked on; e.g. calling
-- OutliersFilter:Initialize(...) directly stores the state on this module
-- table itself.
local OutliersFilter = {}

function OutliersFilter:Initialize(bonepair, window, joints, th, th_high)
  --[[
      Reset all filter state.

      bonepair: 2d lua table, {{p, c}, ...}; each entry records that joint
                id p is the parent of joint id c.
      window:   integer, length of the per-joint history kept for the
                regression (also the warm-up length before any prediction).
      joints:   list of joint ids on which outlier prediction is performed.
      th:       optional table indexed by joint id. Filter treats joint c as
                an outlier candidate when its score is below th[c].
                Defaults to the hand-tuned table below.
      th_high:  optional table indexed by joint id. A score >= th_high[i]
                counts as a "good" frame for the per-joint counter
                self.count. Defaults to th (the same table).
  --]]
  self.bonepair = bonepair
  self.len_q = window
  self.joints = joints
  self.H = {}
  self.S = {}
  self.fid = {}
  self.p = {}
  -- Hand-tuned per-joint thresholds (hack). Previously tried values:
  --   {3,1,5, 3,4,5, 2,4,5, 1,3,5, 3,4}
  --   {3,1,2, 3,4,2, 2,4,2, 1,3,2, 3,4}
  --   {7,7,9, 7,7,9, 7,7,9, 7,7,9, 7,7}
  self.th = th or {3,4,2, 3,4,2, 2,4,2, 2,4,2, 2,3}
  -- previously tried th_high: {4,2,3, 4,5,3, 3,5,3, 2,4,3, 4,5}
  self.th_high = th_high or self.th
  -- self.p maps child joint id -> parent joint id.
  -- self.count is read by Filter using joint ids, but is initialised here
  -- for ids 1..#bonepair.
  -- NOTE(review): assumes every joint id that Filter looks up lies in
  -- 1..#bonepair; confirm against the skeleton definition.
  self.count = {}
  for i = 1, #bonepair do
    self.p[bonepair[i][2]] = bonepair[i][1]
    self.count[i] = 0
  end
  -- Per-joint (indexed by position in `joints`) sliding histories.
  for i = 1, #self.joints do
    self.H[i] = {}
    self.S[i] = {}
    self.fid[i] = {}
  end
  self.time = 0
  self.cur_3d = {}
  self.cur_fid = {}
end


function OutliersFilter:Update(pos3d, fid)
  --[[
      Advance one frame: bump the frame counter, move the current frame into
      the "last" slot, and store the new frame as current.

      The new frame is deep-copied. Otherwise the filter's state would
      alias the caller's tables, and reusing or mutating pos3d/fid after the
      call would silently change last_3d/last_fid.

      pos3d: table indexed by joint id (entries are {x, y, z}).
      fid:   table indexed by joint id (per-joint scores).
  --]]
  self.time = self.time + 1

  self.last_3d = self.cur_3d
  self.last_fid = self.cur_fid

  self.cur_3d = self:deep_copy(pos3d)
  self.cur_fid = self:deep_copy(fid)
end

function OutliersFilter:Filter(pos3d, fid)
  --[[
      Process one frame and return the filtered pose.

      pos3d: table indexed by joint id; pos3d[j] = {x, y, z}.
      fid:   table indexed by joint id; fid[j] is the joint's score,
             compared against self.th / self.th_high.
      Returns pos3d, fid: the same tables passed in, with the entries for
      every joint in self.joints overwritten by the filtered values.

      Steps:
        1. Once more than len_q frames have been seen, each joint c (parent
           p = self.p[c]) whose score is below self.th[c] is replaced by a
           least-squares prediction. This only happens if both c and p have
           self.count >= len_q. The regression is fitted on the window
           S ~= H * theta, where a row of H is [parent xyz at t,
           child xyz at t-1] and a row of S is the child xyz at t. The
           replaced score is the mean of the parent's current score and the
           child's previous score.
        2. The (possibly predicted) frame is appended to the per-joint
           histories.
        3. Filtered values are written back into pos3d / fid.
        4. self.count[i] counts consecutive frames with fid[i] >= th_high[i]
           and resets to 0 otherwise.
  --]]
  self:Update(pos3d, fid)
  local tunpack = table.unpack or unpack  -- Lua 5.2+ vs Lua 5.1/LuaJIT
  print('OF:===================================')
  print('OF:fid_counter:')
  print('OF:', tunpack(self.count))

  -- 1. Predict outliers once the history window is full.
  if self.time > self.len_q then
    for id_j = 1, #self.joints do
      local c = self.joints[id_j]
      local p = self.p[c]

      if self.cur_fid[c] < self.th[c] and
          self.count[c] >= self.len_q and
          self.count[p] >= self.len_q then
        print('OF:')
        print('OF:cur_joint:', c)
        print('OF:cur_fid:', self.cur_fid[c])
        print('OF:cur3d:', self.cur_3d[c][1], self.cur_3d[c][2], self.cur_3d[c][3])

        local H = torch.Tensor(self.H[id_j])
        local S = torch.Tensor(self.S[id_j])
        local Ht = H:transpose(1, 2)
        -- Normal-equation least squares: theta = (H^T H)^+ H^T S
        local theta = self:pinv(Ht * H) * Ht * S
        local feature = torch.Tensor({{self.cur_3d[p][1], self.cur_3d[p][2], self.cur_3d[p][3],
                                       self.last_3d[c][1], self.last_3d[c][2], self.last_3d[c][3]}})
        local pred = feature * theta
        -- Update predicted coordinates and score.
        self.cur_3d[c] = {pred[1][1], pred[1][2], pred[1][3]}
        self.cur_fid[c] = (self.cur_fid[p] + self.last_fid[c]) / 2

        print('OF:cur3d predict:', self.cur_3d[c][1], self.cur_3d[c][2], self.cur_3d[c][3])
        print('OF:')
        print('OF:===================================')
      end
    end
  end

  -- 2. Record this frame in the per-joint sliding histories.
  for id_j = 1, #self.joints do
    local c = self.joints[id_j]
    local p = self.p[c]
    local parent = self.cur_3d[p]
    -- On the first frame there is no previous child position; use zeros.
    local prev
    if self.time == 1 then
      prev = {0, 0, 0}
    else
      prev = self.last_3d[c]
    end
    self:tbl_push(self.H[id_j], {parent[1], parent[2], parent[3],
                                 prev[1], prev[2], prev[3]})
    -- alternative once tried: (self.cur_fid[p] + self.last_fid[c]) / 2
    self:tbl_push(self.fid[id_j], self.cur_fid[p])
    self:tbl_push(self.S[id_j], self.cur_3d[c])
  end

  -- 3. Write the filtered values back to the caller's tables.
  for id_j = 1, #self.joints do
    local c = self.joints[id_j]
    pos3d[c] = self:deep_copy(self.cur_3d[c])
    fid[c] = self.cur_fid[c]
  end

  -- 4. Update the consecutive-good-frame counters.
  -- NOTE(review): this stops at #self.bonepair - 1, while Initialize creates
  -- counters for 1..#self.bonepair; confirm the intended joint range.
  for i = 1, #self.bonepair - 1 do
    if fid[i] >= self.th_high[i] then
      self.count[i] = self.count[i] + 1
    else
      self.count[i] = 0
    end
  end

  return pos3d, fid
end

function OutliersFilter:tbl_push(tbl, val)
  -- Append a deep copy of `val` to the sequence `tbl`. If that makes `tbl`
  -- longer than the window length self.len_q, evict the oldest
  -- (first) element.
  tbl[#tbl + 1] = self:deep_copy(val)
  local too_long = #tbl > self.len_q
  if too_long then
    table.remove(tbl, 1)
  end
end

function OutliersFilter:get_latest(tbl)
  -- Deep copy of the last element of the sequence `tbl`
  -- (nil when `tbl` is empty).
  return self:deep_copy(tbl[#tbl])
end

function OutliersFilter:deep_copy(orig, seen)
  --[[
      Recursively copy `orig`.

      Non-table values are returned unchanged. For a table, every key and
      value is deep-copied, and the copy gets a deep copy of orig's
      metatable.

      seen: optional (internal) map from already-visited tables to their
            copies. It makes cyclic structures terminate instead of
            recursing forever, such as a metatable whose __index points back
            at itself. It also keeps tables that are shared inside `orig`
            shared inside the copy. Callers normally omit it.
  --]]
  if type(orig) ~= "table" then
    return orig
  end
  seen = seen or {}
  local existing = seen[orig]
  if existing ~= nil then
    return existing
  end
  local copy = {}
  -- Register before recursing so self-references resolve to `copy`.
  seen[orig] = copy
  for orig_key, orig_value in next, orig, nil do
    copy[self:deep_copy(orig_key, seen)] = self:deep_copy(orig_value, seen)
  end
  setmetatable(copy, self:deep_copy(getmetatable(orig), seen))
  return copy
end


function OutliersFilter:pinv(A, tol)
  --[[
      Moore-Penrose pseudo-inverse of a 2D torch.Tensor, computed via SVD.

      A:   2D torch.Tensor.
      tol: optional number, default 1e-8. Singular values with magnitude
           <= tol are treated as zero, so their reciprocal is set to 0
           instead of being inverted.
      Returns V * diag(S+) * U^T, where A = U * diag(S) * V^T and S+ holds
      1/s for |s| > tol and 0 otherwise.
  --]]
  tol = tol or 1e-8
  local U, S, V = torch.svd(A)

  for i = 1, S:size(1) do
    if math.abs(S[i]) > tol then
      S[i] = 1 / S[i]
    else
      -- Numerically-zero directions must be dropped (0), not kept as a
      -- tiny value, for this to be the pseudo-inverse.
      S[i] = 0
    end
  end

  -- diag(S+) is square and symmetric, so it needs no transpose.
  local A_inv = V * torch.diag(S) * U:transpose(1, 2)
  return A_inv
end

--[[
outfilter = OutliersFilter
outfilter:Initialize({{1,2},{2,3}}, 3, {1,2})

a = outfilter:deep_copy({})
print(a)
--]]
-- Export the module table to require() callers.
return OutliersFilter
