local AutoTensor = require "torchutility.autotensor"
local Operator = require "torchutility.operator"
local venusjson = require "venusjson"
local venuscore = require "venuscore"
local torch = require "torch"

-- Rodrigues rotation-vector operator, derived from torchutility.operator.
local Rodrigues = Operator:extend()

-- Rotation angles (radians) at or below this magnitude are treated as zero.
local EPSILON = 1e-5

--- Constructor; forwards every argument unchanged to the Operator base class.
function Rodrigues:new(...)
  local base = Rodrigues.super
  base.new(self, ...)
end

--- Shared precomputation for Caculate and Derivative.
-- Splits the rotation vector into its angle theta = |r| and unit axis, and
-- builds the terms used by Rodrigues' formula.
-- @param r rotation vector; must be a 1-D tensor of 3 elements (torch.ger
--   requires 1-D inputs and r[k] is read as a number).
-- @return c      cos(theta)
-- @return c1     1 - cos(theta)
-- @return s      sin(theta)
-- @return theta  |r| (radians)
-- @return itheta 1/theta, or 0 when theta <= EPSILON
-- @return r      unit axis (all zeros when theta <= EPSILON)
-- @return rrt    outer product r*r^T (3x3)
-- @return r_x    skew-symmetric cross-product matrix [r]_x (3x3)
function Rodrigues:_BaseCaculate(r)
  local theta = r:norm();
  local itheta = theta > EPSILON and 1/theta or 0;
  r = r * itheta;

  -- Always recompute. The former cache test `not self.theta == theta` parsed
  -- as `(not self.theta) == theta` (always false), so the first call's
  -- results were returned forever. Caching on theta alone would also be
  -- wrong, because rrt and r_x depend on the axis direction, not the angle.
  local c = math.cos(theta);
  local s = math.sin(theta);
  local c1 = 1 - c;

  local rrt = torch.ger(r, r);
  local r_x = torch.Tensor({
      {    0, -r[3],  r[2]},
      { r[3],     0, -r[1]},
      {-r[2],  r[1],     0}});

  -- Keep the latest values on the instance, as before, in case other code
  -- reads these fields.
  self.c = c;
  self.c1 = c1;
  self.s = s;
  self.theta = theta;
  self.itheta = itheta;
  self.rrt = rrt;
  self.r_x = r_x;

  return c, c1, s, theta, itheta, r, rrt, r_x;
end

local eye = torch.eye(3)

--- Rotation vector -> 3x3 rotation matrix (Rodrigues' formula).
-- @param pr rotation vector (axis scaled by angle), as accepted by _BaseCaculate
-- @return 3x3 rotation matrix; a fresh identity when |pr| is not above EPSILON
function Rodrigues:Caculate(pr)
  local cosT, oneMinusCos, sinT, angle, _, _, outer, skew = self:_BaseCaculate(pr)

  -- Written as `not (angle > EPSILON)` so a NaN angle takes this branch too.
  if not (angle > EPSILON) then
    return torch.eye(3, 3)
  end

  -- R = cos(theta)*I + (1 - cos(theta))*r*r^T + sin(theta)*[r]_x
  return cosT * eye + oneMinusCos * outer + sinT * skew
end

-- Flattened 3x3 identity (9 elements).
local I = torch.eye(3):view(-1)
-- Row k is d([r]_x)/dr_k with [r]_x flattened row-major to 9 elements.
local d_r_x = torch.Tensor({
  {0, 0, 0, 0, 0, -1, 0, 1, 0},
  {0, 0, 1, 0, 0, 0, -1, 0, 0},
  {0, -1, 0, 1, 0, 0, 0, 0, 0},
})

--- Jacobian of the rotation matrix with respect to the rotation vector.
-- Row k holds dR/dpr_k with R flattened row-major, giving a 3x9 tensor. This
-- is the same formulation OpenCV's cvRodrigues2 uses for its jacobian.
-- @param pr rotation vector (axis scaled by angle), as accepted by _BaseCaculate
-- @return 3x9 tensor
function Rodrigues:Derivative(pr)
  local c, c1, s, theta, itheta, u, uut, ux = self:_BaseCaculate(pr)

  -- Written as `not (theta > EPSILON)` so a NaN angle takes this branch too.
  if not (theta > EPSILON) then
    -- Near-zero rotation: R ~ I + [r]_x, so dR/dr reduces to d([r]_x)/dr.
    return torch.Tensor({
      {0,0,0,0,0,-1,0,1,0},
      {0,0,1,0,0,0,-1,0,0},
      {0,-1,0,1,0,0,0,0,0}})
  end

  local uutFlat = uut:view(-1)
  local uxFlat = ux:view(-1)
  -- Row k is d(r*r^T)/dr_k with r*r^T flattened row-major.
  local dUut = torch.Tensor({
    {2*u[1], u[2], u[3], u[2], 0, 0, u[3], 0, 0,},
    {0, u[1], 0, u[1], 2*u[2], u[3], 0, u[3], 0,},
    {0, 0, u[1], 0, 0, u[2], u[1], u[2], 2*u[3]},
  })

  -- These two coefficients do not depend on the row.
  local kOuter = c1*itheta
  local kSkew = s*itheta

  local jac = torch.Tensor(3, 9)
  for row = 1, 3 do
    local ui = u[row]
    jac[{{row}, {}}] = (-s*ui) * I
                     + ((s - 2*c1*itheta)*ui) * uutFlat
                     + kOuter * dUut[{{row}, {}}]
                     + ((c - s*itheta)*ui) * uxFlat
                     + kSkew * d_r_x[{{row}, {}}]
  end

  return jac
end

--- Static function: convert a rotation matrix to a Rodrigues rotation vector
-- (unit axis scaled by the rotation angle in radians).
-- @param r 3x3 rotation matrix. It is assumed orthonormal; no
--   re-orthonormalization is done. If the input may be noisy, project it first,
--   e.g. `local u, _, v = torch.svd(r); R = u * v:t()`. torch.svd returns V,
--   not V^T.
-- @return 3-element tensor; all zeros for the identity rotation
function Rodrigues.Decompose(r)
  local R = r;

  -- Skew-symmetric part of R: |axis| = 2*sin(theta).
  local axis = torch.Tensor({
      R[{3, 2}] - R[{2, 3}],
      R[{1, 3}] - R[{3, 1}],
      R[{2, 1}] - R[{1, 2}]});
  local axisNorm = axis:norm();

  -- trace(R) = 1 + 2*cos(theta). Rounding can push c slightly outside [-1, 1],
  -- where math.acos would return NaN, so clamp it.
  local c = (R[{1, 1}] + R[{2, 2}] + R[{3, 3}] - 1) * 0.5;
  c = math.max(-1, math.min(1, c));
  local theta = math.acos(c);

  if axisNorm < EPSILON and c <= 0 then
    -- theta is (close to) pi. R is (nearly) symmetric, so the skew part holds
    -- no axis information. Recover the unit axis n from R + I = 2*n*n^T.
    local x = math.sqrt(math.max((R[{1, 1}] + 1) * 0.5, 0));
    local y = math.sqrt(math.max((R[{2, 2}] + 1) * 0.5, 0));
    local z = math.sqrt(math.max((R[{3, 3}] + 1) * 0.5, 0));
    -- With x taken non-negative, R12 = 2*x*y and R13 = 2*x*z give the signs
    -- of y and z.
    if R[{1, 2}] < 0 then y = -y; end
    if R[{1, 3}] < 0 then z = -z; end
    -- When x is the smallest component those signs are unreliable, so make
    -- y*z agree in sign with R23 = 2*y*z.
    if math.abs(x) < math.abs(y) and math.abs(x) < math.abs(z)
        and (R[{2, 3}] > 0) ~= (y * z > 0) then
      z = -z;
    end
    local n = torch.Tensor({x, y, z});
    return n * (theta / n:norm());
  end

  if axisNorm == 0 then
    -- Identity rotation. Return a tensor, like the general branch does.
    return torch.zeros(3);
  end

  return (axis / axisNorm) * theta;
end

return Rodrigues;