local venusjson = require "venusjson"
local venuscore = require "venuscore"
local object = require "classic"
local torch = require "torch"
local AutoTensor = require "torchutility.autotensor"
local Operator = require "torchutility.operator"

local sequencialdot = Operator:extend()

--- Construct a sequencialdot operator.
-- Forwards every constructor argument to the parent-class constructor
-- (sequencialdot.super.new), then calls self:R().
-- NOTE(review): R is defined outside this file (presumably on Operator) —
-- confirm what it registers/initialises.
function sequencialdot.new(self, ...)
  local parent = sequencialdot.super
  parent.new(self, ...)
  self:R()
end


--- Multiply all arguments together and remember them for PartialD.
-- Each later argument is multiplied on the LEFT of the running product, so
-- Caculate(p1, p2, ..., pn) returns pn * ... * p2 * p1.
-- The argument list is stored in self.paramlist, which PartialD reads.
-- @param ... factors to multiply (anything supporting `*`; presumably tensors).
-- @return the product; p1 itself for a single argument; nil when called
--   with no arguments.
function sequencialdot:Caculate( ... )
  local params = {...}
  self.paramlist = params

  local product = params[1]
  for i = 2, #params do
    -- later factors go on the left: product = p[i] * (p[i-1] * ... * p[1])
    product = params[i] * product
  end
  return product
end


-- Ordered product params[last] * ... * params[first]: each later factor is
-- multiplied on the LEFT, matching the order used by Caculate.
-- Returns nil when the range is empty (first > last).
local function orderedProduct(params, first, last)
  local acc
  for i = first, last do
    if acc == nil then
      acc = params[i]
    else
      acc = params[i] * acc
    end
  end
  return acc
end

-- Number of elements of tensor `t` (product of all of its sizes).
local function elementCount(t)
  local sizes = t:size()
  local n = 1
  for d = 1, t:nDimension() do
    n = n * sizes[d]
  end
  return n
end

--- Jacobian of the last Caculate() product with respect to one factor.
-- Caculate(p1, ..., pn) returns pn * ... * p1, so factor `index` sits
-- between right = pn * ... * p(index+1) and left = p(index-1) * ... * p1:
--   product = right * p[index] * left
-- Jacobians come from DOfMul; interior factors use the chain rule.
-- @param index 1-based position (in the Caculate argument list) of the factor
--   to differentiate against.
-- @return Jacobian tensor; for a single-factor product, the identity matrix
--   sized to that factor's element count.
function sequencialdot:PartialD(index)
  local params = self.paramlist
  assert(params ~= nil, "sequencialdot:PartialD called before Caculate")
  local count = #params

  if count == 1 then
    -- The product is the factor itself, so its Jacobian is the identity over
    -- all of its elements (formerly hard-coded to 16, i.e. a 4x4 factor).
    return torch.eye(elementCount(params[1]))
  end

  assert(index >= 1 and index <= count,
         "sequencialdot:PartialD index out of range")

  local factor = params[index]
  local left = orderedProduct(params, 1, index - 1)
  local right = orderedProduct(params, index + 1, count)

  if left == nil then
    -- first factor: product = right * factor (factor is the right operand)
    return self:DOfMul(2, right, factor)
  elseif right == nil then
    -- last factor: product = factor * left (factor is the left operand)
    return self:DOfMul(1, factor, left)
  end

  -- Interior factor: product = mid * left with mid = right * factor, so
  -- d(product)/d(factor) = d(mid*left)/d(mid) * d(right*factor)/d(factor).
  local mid = right * factor
  local dMidDFactor = self:DOfMul(2, right, factor)
  local dProductDMid = self:DOfMul(1, mid, left)
  return dProductDMid * dMidDFactor
end

--- Jacobian of the matrix product C = paraml * paramr w.r.t. one operand.
-- Uses row-major flattening (the index==1 result I_m (x) B^T fixes this
-- convention), giving, for A (m x n) and B (n x p):
--   d vec(AB) / d vec(A) = I_m (x) B^T
--   d vec(AB) / d vec(B) = A (x) I_p
-- A 1-D paraml is treated as a row vector; a paraml with more than two
-- dimensions is flattened to (-1, last dim). A 1-D paramr is treated as a
-- column vector.
-- @param index 1 to differentiate w.r.t. paraml, 2 w.r.t. paramr.
-- @param paraml left operand tensor.
-- @param paramr right operand tensor.
-- @return Jacobian tensor. For index 2: paraml itself when paramr is 1-D;
--   nil when paramr has neither 1 nor 2 dimensions.
-- @raise error when index is neither 1 nor 2.
function sequencialdot:DOfMul(index,paraml,paramr)
    -- Left operand as a 2-D matrix (contiguous so that view() cannot fail).
    local ar
    local ldim = paraml:nDimension()
    if ldim < 2 then
        ar = paraml:contiguous():view(1, -1)
    else
        ar = paraml:contiguous():view(-1, paraml:size()[ldim])
    end

    -- Right operand as a 2-D matrix.
    local br
    local rdim = paramr:nDimension()
    if rdim < 2 then
        br = paramr:contiguous():view(-1, 1)
    else
        br = paramr:contiguous():view(paramr:size()[1], -1)
    end

    if index == 1 then
        -- I_m (x) B^T, m = rows of A
        local eye = torch.eye(ar:size()[1])
        local bt = br:t()
        return torch.kron(eye:contiguous(), bt:contiguous())
    elseif index == 2 then
        if rdim == 1 then
            -- C = A b is linear in b with coefficient matrix A
            return paraml
        elseif rdim == 2 then
            -- A (x) I_p, p = COLUMNS of B (rows only coincide for square B)
            local ret = torch.kron(ar, torch.eye(br:size()[2]))
            return ret
        end
        return nil
    end
    error("sequencialdot:DOfMul: index must be 1 or 2")
end

--- Unimplemented stub: ignores `param` and returns no values.
-- NOTE(review): confirm whether any caller (e.g. the Operator base class)
-- invokes this before filling it in.
function sequencialdot.Derivative(self, param)
  -- intentionally empty
end

return sequencialdot