local AutoTensor = require "torchutility.autotensor"
local Operator = require "torchutility.operator"
local venusjson = require "venusjson"
local venuscore = require "venuscore"
local torch = require "torch"


-- Operator subclass that turns rotation vectors plus translations into 4x4
-- transform matrices.
local TransFormMat = Operator:extend()

-- Threshold on the rotation angle (the norm of the rotation vector) at or
-- below which the rotation is treated as (near) zero.
local EPSILON = 1e-5

--- Constructor: forwards every argument unchanged to the parent class
-- constructor (presumably Operator's, via `extend`).
function TransFormMat:new(...)
  local parent = TransFormMat.super
  parent.new(self, ...)
end

--- Precompute the terms used by Rodrigues' rotation formula for vector `r`.
-- The results are also cached on `self` (fields c, c1, s, theta, itheta,
-- rrt, r_x).
-- @param r 1-D tensor with 3 entries (presumably an axis-angle vector whose
--          norm is the rotation angle — TODO confirm with callers)
-- @return c      cos(theta)
-- @return c1     1 - cos(theta)
-- @return s      sin(theta)
-- @return theta  norm of r
-- @return itheta 1/theta, or 0 when theta <= EPSILON
-- @return r      r * itheta: a new tensor (the zero vector when theta <= EPSILON)
-- @return rrt    3x3 outer product of the normalized r with itself
-- @return r_x    3x3 skew-symmetric cross-product matrix of the normalized r
function TransFormMat:_BaseCaculate(r)
  local theta = r:norm();
  -- Use the shared EPSILON so this threshold stays the exact complement of
  -- Drotone's `theta <= EPSILON` branch (Drotone relies on itheta ~= 0 there).
  local itheta = theta > EPSILON and 1/theta or 0;
  r = r * itheta; -- allocates a new tensor; the caller's r is not modified

  local c = math.cos(theta);
  local s = math.sin(theta);
  local c1 = 1 - c;

  local rrt = torch.ger(r, r);
  local r_x = torch.Tensor({
      {    0, -r[3],  r[2]},
      { r[3],     0, -r[1]},
      {-r[2],  r[1],     0}});

  self.c = c;
  self.c1 = c1;
  self.s = s;
  self.theta = theta;
  self.itheta = itheta;
  self.rrt = rrt;
  self.r_x = r_x;

  return c, c1, s, theta, itheta, r, rrt, r_x;
end

-- Shared 3x3 identity used in Rodrigues' formula; treated as read-only.
local eye = torch.eye(3)
--- Build one 4x4 homogeneous transform per row of `pr` / `pos`.
-- @param pr  tensor whose k-th row is passed to CaculateOne as the rotation
--            vector (assumed N x 3 — TODO confirm)
-- @param pos tensor whose k-th row is the matching translation (assumed N x 3)
-- @return N x 4 x 4 tensor, where N is the size of pr's first dimension
function TransFormMat:Caculate(pr,pos)
  local count = pr:size(1)
  local out = torch.zeros(count, 4, 4)
  for k = 1, count do
    -- out[k] is a 4x4 view into out, filled in place
    self:CaculateOne(pr[k], pos[k], 1, 1, out[k])
  end
  return out
end

--- Write the 4x4 homogeneous transform for one rotation vector / translation.
-- The rotation block is R = c*I + (1-c)*r*r^T + s*[r]x (Rodrigues' formula),
-- using the terms from _BaseCaculate.
-- @param pr     1-D rotation vector (3 entries)
-- @param pos    translation; entries 1..3 go into the last column
-- @param starti row of mat44 where the 4x4 block starts (default 1)
-- @param startj column of mat44 where the 4x4 block starts (default 1)
-- @param mat44  target tensor, written in place; when nil a zeroed
--               (starti+3) x (startj+3) tensor is allocated
-- @return mat44 (the tensor that was written)
function TransFormMat:CaculateOne(pr,pos,starti,startj,mat44)
  starti = starti or 1;
  startj = startj or 1;
  mat44 = mat44 or torch.Tensor(starti + 3, startj + 3):zero();

  local c, c1, s, _, _, _, rrt, r_x = self:_BaseCaculate(pr);
  local rot = c * eye + c1 * rrt + s * r_x;

  -- 3x3 rotation block
  mat44:narrow(1, starti, 3):narrow(2, startj, 3):copy(rot);

  -- translation column
  mat44[starti][startj+3] = pos[1];
  mat44[starti+1][startj+3] = pos[2];
  mat44[starti+2][startj+3] = pos[3];

  -- homogeneous bottom row [0 0 0 1]
  mat44[starti+3][startj] = 0;
  mat44[starti+3][startj+1] = 0;
  mat44[starti+3][startj+2] = 0;
  mat44[starti+3][startj+3] = 1;

  return mat44;
end

-- 3x3 identity flattened row-major into 9 entries.
local I = torch.eye(3):view(-1)

-- Row k holds d([r]x)/dr_k, where the skew-symmetric matrix
--   [r]x = {{0, -r3, r2}, {r3, 0, -r1}, {-r2, r1, 0}}
-- is flattened row-major into 9 entries. Shared constant: read-only.
local d_r_x = torch.Tensor({
  { 0,  0,  0,  0,  0, -1,  0,  1,  0 }, -- d/dr1
  { 0,  0,  1,  0,  0,  0, -1,  0,  0 }, -- d/dr2
  { 0, -1,  0,  1,  0,  0,  0,  0,  0 }, -- d/dr3
})

--- Jacobian of the stacked transforms with respect to the rotation vectors.
-- Builds a block-diagonal (3N) x (16N) matrix, where block i is filled by
-- Drotone for row i of `pr`. The transpose, (16N) x (3N), is returned.
-- @param pr tensor whose rows are rotation vectors (assumed N x 3)
-- @return (16N) x (3N) tensor (a transposed view)
-- @raise error when self:HasParamForD(1, pr) is false
function TransFormMat:Derivative(pr)
  if not self:HasParamForD(1,pr) then
    error("TransFormMat:Derivative: HasParamForD(1, pr) is false; no derivative for this input", 2);
  end

  local matcount = pr:size(1);
  local result = torch.zeros(3*matcount, 16*matcount);
  for i=1,matcount do
    -- 0-based offsets of block i inside result
    self:Drotone(pr[i], (i-1)*3, (i-1)*16, result);
  end
  return result:t();
end

--- 3x9 Jacobian of the 3x3 rotation built from `pr`.
-- Row i presumably holds the derivative of the row-major flattened rotation
-- matrix with respect to pr[i]; the layout matches OpenCV's Rodrigues
-- Jacobian — verify against callers.
-- @param pr      1-D rotation vector (3 entries)
-- @param starti  0-based row offset into fillmat (default 0)
-- @param startj  0-based column offset into fillmat (default 0)
-- @param fillmat optional tensor; when given, each 9-entry row of J is
--                scattered into the 16 columns of a row-major flattened 4x4
--                matrix, skipping every 4th (translation) column
-- @return J, a fresh 3x9 tensor owned by the caller
function TransFormMat:Drotone(pr,starti,startj,fillmat)
  local c, c1, s, theta, itheta, r, rrt, r_x = self:_BaseCaculate(pr);
  local J;

  if theta <= EPSILON then
    -- Near-zero angle: use d([r]x)/dr. Clone it so the shared module constant
    -- is never handed out (a caller mutating J would corrupt later calls).
    J = d_r_x:clone();
  else
    local rrt_f = rrt:view(-1);
    local r_x_f = r_x:view(-1);
    -- drrt[i] = d(r r^T)/dr_i, flattened row-major
    local drrt = torch.Tensor({
        {2*r[1], r[2], r[3], r[2], 0, 0, r[3], 0, 0},
        {0, r[1], 0, r[1], 2*r[2], r[3], 0, r[3], 0},
        {0, 0, r[1], 0, 0, r[2], r[1], r[2], 2*r[3]}});

    J = torch.Tensor(3,9):zero();
    for i=1,3 do
      local ri = r[i];
      local a0 = -s*ri;
      local a1 = (s - 2*c1*itheta)*ri;
      local a2 = c1*itheta;
      local a3 = (c - s*itheta)*ri;
      local a4 = s*itheta;

      J[{{i}, {}}] = a0 * I
                   + a1 * rrt_f
                   + a2 * drrt[{{i}, {}}]
                   + a3 * r_x_f
                   + a4 * d_r_x[{{i}, {}}];
    end
  end

  if fillmat ~= nil then
    starti = starti or 0;
    startj = startj or 0;
    for i=1,3 do
      for row=0,2 do
        for col=1,3 do
          -- entry (row, col) of the 3x3 block: index row*3+col in J,
          -- index row*4+col in a row-major flattened 4x4
          fillmat[starti+i][startj + row*4 + col] = J[i][row*3 + col];
        end
      end
    end
  end

  return J;
end

-- Module export.
return TransFormMat