local venusjson = require "venusjson"
local venuscore = require "venuscore"
local object = require "classic"
local torch = require "torch"
local AutoTensor = require "torchutility.autotensor"
local Operator = require "torchutility.operator"

--- Element-wise multiplication operator built on `Operator`.
-- The forward pass (`Caculate`) is `torch.cmul(param1, param2)`; `Derivative`
-- returns the Jacobian of that product with respect to one operand.
local NumMultiply = Operator:extend();

--- Constructor; forwards all arguments unchanged to `Operator.new`.
function NumMultiply:new(...)
  NumMultiply.super.new(self, ...);
end

--- Forward pass: element-wise product of the two operands.
-- (The method name is misspelled but kept as-is: it is part of the interface.)
-- @param param1 first operand (torch tensor)
-- @param param2 second operand; presumably must have the same number of
--   elements as `param1`, as torch.cmul requires — TODO confirm callers
-- @return a new tensor holding `param1 .* param2`
function NumMultiply:Caculate(param1, param2)
  return torch.cmul(param1, param2);
end

--- Jacobian of an element-wise product with respect to one operand.
-- For z = a .* b with both operands flattened to 1-D, dz/da = diag(b), so this
-- returns the diagonal matrix of `other` flattened. This is the same formula
-- the 1-D and 2-D cases always used; it holds unchanged for any rank >= 1.
-- @param other the operand that is NOT being differentiated (raw torch tensor)
-- @return an (n x n) matrix where n = other:nElement()
local function elementwiseJacobian(other)
  local ndim = other:nDimension()
  if ndim == 0 then
    error("NumMultiply:Derivative: cannot build a Jacobian for an empty (0-dimensional) tensor")
  end
  -- contiguous() returns the tensor itself when already contiguous; it only
  -- copies when needed, because view(-1) fails on non-contiguous storage.
  -- For a 1-D tensor, view(-1) is a no-op.
  return torch.diag(other:contiguous():view(-1));
end

--- Jacobian of the product with respect to one of the two cached operands.
-- `HasParamForD(1, param)` decides whether `param` is operand 1; the Jacobian
-- w.r.t. operand 1 depends only on operand 2 and vice versa.
-- NOTE(review): assumes `self.paramcache[i]:R()` yields the raw torch tensor of
-- the i-th cached input — confirm against Operator/AutoTensor.
-- @param param the parameter to differentiate with respect to
-- @return diag of the other operand flattened, an (n x n) matrix
function NumMultiply:Derivative(param)
  -- d(a .* b)/da = diag(b);  d(a .* b)/db = diag(a)
  local otherIndex = self:HasParamForD(1, param) and 2 or 1
  return elementwiseJacobian(self.paramcache[otherIndex]:R());
end

return NumMultiply;