local venusjson = require "venusjson"
local venuscore = require "venuscore"

local torch = require "torch"
local AutoTensor = require "torchutility.autotensor"
local Operator = require "torchutility.operator"


-- Operator subclass that selects whole rows or individual {row, col}
-- entries from its input tensor.
local SubTensor = Operator:extend()

--- Construct a SubTensor operator.
-- @param list index list: either plain row indices, or `{row, col}` pairs
-- @param ... forwarded unchanged to the parent (Operator) constructor
function SubTensor:new(list, ...)
  -- Initialise the parent part first, then record the selection.
  local parent = SubTensor.super
  parent.new(self, ...)
  self.indexlist = list
end

--- Gather the entries of `param1` selected by `self.indexlist`.
--
-- The first entry of `self.indexlist` picks the indexing mode, and every
-- entry must use that same mode:
--   * row mode (entries are plain indices): the result is a
--     `#indexlist x param1:size()[2]` tensor whose i-th row is a copy of
--     `param1[indexlist[i]]`.
--   * element mode (entries are `{row, col}` pairs): the result is a
--     `1 x #indexlist` tensor whose i-th value is `param1[row][col]`.
-- An empty index list takes row mode.
-- NOTE: the method keeps its original spelling ("Caculate"); it is
-- presumably looked up by name from the Operator base class.
-- @param param1 tensor to index; assumed 2-D since only `size()[2]` is
--   read in row mode (TODO confirm with callers)
-- @return a newly allocated tensor holding the selected values
-- @raise error if the index list mixes modes, or an element-mode entry is
--   not a 2-element table
function SubTensor:Caculate( param1 )
  local list = self.indexlist
  local count = #list
  local elementmode = type(list[1]) == "table"

  local rettensor
  if elementmode then
    rettensor = torch.Tensor(1, count)
  else
    rettensor = torch.Tensor(count, param1:size()[2])
  end

  for i = 1, count do
    local entry = list[i]
    if elementmode then
      -- Fail loudly: a skipped entry would leave uninitialized memory in the
      -- result, since torch.Tensor(...) does not zero its storage.
      if type(entry) ~= "table" or #entry ~= 2 then
        error(string.format(
          "SubTensor:Caculate: index %d must be a {row, col} pair", i))
      end
      rettensor[1][i] = param1[entry[1]][entry[2]]
    else
      -- A pair here would otherwise be written into row 1 of a row-mode
      -- result, silently corrupting it.
      if type(entry) == "table" then
        error(string.format(
          "SubTensor:Caculate: index %d is a table but index 1 is not; "
            .. "mixed index modes are not supported", i))
      end
      -- Copies the whole selected row into row i of the result.
      rettensor[i] = param1[entry]
    end
  end

  return rettensor
end

--- Jacobian of `Caculate` with respect to `param`.
--
-- Returns a 0/1 selection matrix `ret` in which `ret[r][c] == 1` when
-- flattened output element r is taken from flattened input element c, where
-- input element (row, col) is flattened as `(row - 1) * size + col` and
-- `size = param:size()[2]`.
--   * row mode (plain indices): `#indexlist * size` rows; output row i
--     occupies Jacobian rows `(i - 1) * size + 1 .. i * size`.
--   * element mode (`{row, col}` pairs): `#indexlist` rows, one per pair.
-- The column count is `self.paramcache[1]:R():size()[1] * size`.
-- NOTE(review): this assumes `param` has the same number of columns as the
-- cached input `self.paramcache[1]:R()` -- confirm with the Operator contract.
-- @param param the variable to differentiate with respect to
-- @return the selection matrix described above
-- @raise error when `self:HasParamForD(1, param)` is false, or when the
--   index list is malformed (mixed modes, or a table entry that is not a pair)
function SubTensor:Derivative(param)
  if not self:HasParamForD(1, param) then
    error("SubTensor:Derivative: HasParamForD(1, param) returned false")
  end

  local list = self.indexlist
  local count = #list
  local size = param:size()[2]
  local elementmode = type(list[1]) == "table"

  local rowsize
  if elementmode then
    rowsize = count
  else
    rowsize = count * size
  end

  local colsize = self.paramcache[1]:R():size()[1] * size
  local ret = torch.Tensor(rowsize, colsize):fill(0)

  for i = 1, count do
    local entry = list[i]
    if elementmode then
      -- Same validation as Caculate: a malformed entry would otherwise leave
      -- an all-zero Jacobian row for an output that was never computed.
      if type(entry) ~= "table" or #entry ~= 2 then
        error(string.format(
          "SubTensor:Derivative: index %d must be a {row, col} pair", i))
      end
      ret[i][(entry[1] - 1) * size + entry[2]] = 1
    else
      if type(entry) == "table" then
        error(string.format(
          "SubTensor:Derivative: index %d is a table but index 1 is not; "
            .. "mixed index modes are not supported", i))
      end
      -- Output row i is a copy of input row `entry`: map each column j.
      local rowbase = (i - 1) * size
      local colbase = (entry - 1) * size
      for j = 1, size do
        ret[rowbase + j][colbase + j] = 1
      end
    end
  end

  return ret
end



-- Module export: the SubTensor operator class.
return SubTensor