
local venusjson = require "venusjson"
local venuscore = require "venuscore"
local object = require "classic"
local torch = require "torch"

local Operator = require "torchutility.operator"

--- SignedSqrt operator: element-wise f(x) = sign(x) * sqrt(|x|) together with
-- its derivative. Declared as a subclass of the `torchutility.operator` class
-- via `Operator:extend()` (classic-style subclassing).
local SignedSqrt = Operator:extend();


--- Constructor: forwards every argument unchanged to the parent class
-- constructor (`SignedSqrt.super`, set up by `extend`). No additional state
-- is initialised here.
-- @param ... arguments accepted by the parent constructor
function SignedSqrt:new(...)
    local parent = SignedSqrt.super
    parent.new(self, ...)
end

--- Forward evaluation of the operator: element-wise signed square root.
-- NOTE(review): the method name looks like a misspelling of "Calculate" but is
-- kept as-is, presumably because the Operator base or callers invoke it by
-- this name — confirm before renaming.
-- @param param input forwarded unchanged to `SignedSqrt:SignedSqrt`
-- @return whatever `SignedSqrt:SignedSqrt` returns for `param`
function SignedSqrt:Caculate(param)
    local signedRoot = self.SignedSqrt
    return signedRoot(self, param)
end

--- Derivative of the operator with respect to its input.
-- Delegates to `PD_X_SignedSqrt` with `tomat = true`, i.e. it returns the
-- Jacobian of sign(x)*sqrt(|x|) as a diagonal matrix.
-- @param param input at which the derivative is evaluated
-- @return diagonal Jacobian matrix produced by `PD_X_SignedSqrt(param, true)`
-- @raise error if `self:HasParamForD(1, param)` is false
function SignedSqrt:Derivative(param)
    -- NOTE(review): HasParamForD is not defined in this file; presumably it is
    -- inherited from the Operator base and checks whether argument #1 can be
    -- differentiated — confirm.
    if not self:HasParamForD(1, param) then
        -- Level 2 reports the error at the caller's position, with a message
        -- instead of the bare "assertion failed!" that assert(false) gave.
        error("SignedSqrt:Derivative: HasParamForD(1, param) returned false", 2)
    end
    return self:PD_X_SignedSqrt(param, true)
end

--- Element-wise signed square root: sign(x) * sqrt(|x|).
-- Keeps the sign of each element while compressing its magnitude.
-- @param value input tensor (assumed to be a torch Tensor — TODO confirm
--   callers never pass plain numbers)
-- @return new tensor holding sign(value) .* sqrt(|value|)
function SignedSqrt:SignedSqrt(value)
    local signs = torch.sign(value)
    local magnitudes = torch.sqrt(torch.abs(value))
    return torch.cmul(signs, magnitudes)
end

--- Partial derivative of sign(x) * sqrt(|x|) with respect to x.
-- For every element the derivative is 1 / (2 * sqrt(|x|)). At x == 0 that
-- expression evaluates to +inf, and such entries are replaced by 0.
-- @param value input tensor (any shape; the derivatives are flattened)
-- @param tomat when exactly `true`, return an n x n diagonal Jacobian matrix
--   of the default tensor type instead of the flat derivative vector
-- @return 1-D tensor of n derivatives, or an n x n diagonal matrix
function SignedSqrt:PD_X_SignedSqrt(value, tomat)
    -- 1 / (2 * sqrt(|x|)), flattened to a vector of n elements.
    local result = torch.pow(torch.mul(torch.sqrt(torch.abs(value)), 2), -1)
    result = result:view(-1)
    -- x == 0 yields 1/0 = +inf; clamp those entries to 0.
    result[torch.eq(result, math.huge)] = 0

    if tomat == true then
        -- Copy into a default-type vector first so the matrix keeps the same
        -- tensor type the previous torch.Tensor(size, size) construction had,
        -- then build the diagonal matrix with a single torch.diag call rather
        -- than n per-element select-and-assign operations.
        local size = result:nElement()
        return torch.diag(torch.Tensor(size):copy(result))
    end
    return result
end

-- Module value: the SignedSqrt operator class.
return SignedSqrt;