local venusjson = require "venusjson"
local venuscore = require "venuscore"
local object = require "classic"
local torch = require "torch"
local AutoTensor = require "torchutility.autotensor"
local Operator = require "torchutility.operator"

--- Multiply: operator node representing the product of two operands.
-- Instances are built as `Multiply(no1, no2)` (see `AutoTensor.__mul`).
local Multiply = Operator:extend()

--- Constructor. Hands every argument unchanged to the Operator base
-- constructor; Multiply keeps no state of its own.
function Multiply:new(...)
    local base = Multiply.super
    base.new(self, ...)
end

--- Forward value of the node: the product of the two operand values,
-- delegated to whatever `*` means for their runtime types.
-- NOTE(review): the name `Caculate` (sic) is kept as-is; presumably the
-- Operator base class dispatches to it by this exact name — confirm.
-- @param param1 left operand value
-- @param param2 right operand value
-- @return param1 * param2
function Multiply:Caculate(param1, param2)
    local product = param1 * param2
    return product
end

--- Jacobian of the product C = A * B with respect to one of its operands.
--
-- A is read from `self.paramcache[1]:R()` and B from
-- `self.paramcache[2]:R()`. Both are viewed as 2-D matrices:
--   * A -> (-1, last dim of A); a 1-D A becomes a single row (1 x k).
--   * B -> (first dim of B, -1); a 1-D B becomes a single column (k x 1).
-- With an m x k A and a k x n B, the Jacobians returned are those of
-- row-major vectorisation:
--   dC/dA = kron(I_m, B^T)   -- (m*n) x (m*k)
--   dC/dB = kron(A, I_n)     -- (m*n) x (k*n)
-- For a 1-D B, C = A * b is a matrix-vector product and dC/db is A itself.
--
-- @param param the variable to differentiate against; matched against the
--   operands via `self:HasParamForD(index, param)`.
-- @return the Jacobian tensor, or nil when the derivative w.r.t. B is
--   requested and B has more than two dimensions (not implemented).
-- @raise error when `param` matches neither operand.
function Multiply:Derivative(param)
    local a = self.paramcache[1]:R()
    local b = self.paramcache[2]:R()

    -- 2-D view of A: leading dimensions are folded into rows, the last
    -- dimension is kept as columns.
    local adim = a:nDimension()
    local ar
    if adim < 2 then
        ar = a:view(1, -1)
    else
        ar = a:view(-1, a:size(adim))
    end

    -- 2-D view of B: its own first dimension is the contraction dimension k
    -- shared with A's last dimension. (Using A's first dimension here gives
    -- wrongly-shaped Jacobians whenever A is not square.)
    local bdim = b:nDimension()
    local br
    if bdim < 2 then
        br = b:view(-1, 1)
    else
        br = b:view(b:size(1), -1)
    end

    if self:HasParamForD(1, param) then
        -- `ar` is always 2-D after the view above, so m = ar:size(1).
        return torch.kron(torch.eye(ar:size(1)), br:t())
    elseif self:HasParamForD(2, param) then
        if bdim == 1 then
            return a
        elseif bdim == 2 then
            local n = br:size(2)
            return torch.kron(ar, torch.eye(n, n))
        end
        -- Higher-rank B: no Jacobian implemented.
        return nil
    end

    error("Multiply:Derivative: param is neither operand of this node", 2)
end

--- `*` metamethod for AutoTensor: builds a Multiply node over the two
-- operands, in the order Lua passes them to `__mul`.
AutoTensor.__mul = function(no1, no2)
    return Multiply(no1, no2)
end

return Multiply