-- Geman-McClure robustifier.
-- Forward (Robustify), element-wise: rho(x) = sign(x) * sigma^2 * x^2 / (sigma^2 + x^2)

local venusjson = require "venusjson"
local venuscore = require "venuscore"
local object = require "classic"
local torch = require "torch"
local Operator = require "torchutility.operator"

-- Robustifer: an Operator subclass (created via Operator:extend) implementing the
-- Geman-McClure robustifier and its element-wise derivatives.
local Robustifer = Operator:extend();

--- Constructor.
-- @param sigma scalar scale of the robustifier; InitParam expands it to the shape of each input
-- @param ... forwarded unchanged to the Operator base constructor
function Robustifer:new(sigma,...)
    -- Store sigma first, then hand the remaining arguments to the parent class.
    self.sigmavalue = sigma
    local base = Robustifer.super
    base.new(self, ...)
end

--- Element-wise signed square root: sign(v) * sqrt(|v|).
-- @param value tensor
-- @return new tensor with the same shape as value
function Robustifer:SignedSqrt(value)
    local magnitude = torch.abs(value)
    return torch.cmul(torch.sqrt(magnitude), torch.sign(value))
end

--- Cache the input and the element-wise quantities shared by Robustify and its derivatives.
-- Sets self.x (the input itself, not a copy), self.xsquare (x .* x),
-- self.sigma (a tensor shaped like x, filled with the scalar sigma) and
-- self.sigmasquare (sigma .* sigma).
-- @param x input tensor
function Robustifer:InitParam(x)
    assert(self.sigmavalue ~= nil,
        "Robustifer:InitParam: sigma is not set (pass sigma to the constructor)")

    self.x = x;
    self.xsquare = torch.cmul(self.x, self.x)

    -- Allocate sigma with x's own tensor type rather than the global default
    -- type (torch.Tensor), so later element-wise ops against xsquare never mix
    -- tensor types (e.g. DoubleTensor vs FloatTensor).
    self.sigma = x.new():resizeAs(x):fill(self.sigmavalue);
    self.sigmasquare = torch.cmul(self.sigma, self.sigma);
end

--- Evaluate the operator: returns Robustify(param).
-- (Method name spelling is kept as-is; it is part of the public interface.)
-- @param param input tensor
function Robustifer:Caculate( param )
    local robustified = self:Robustify(param)
    return robustified
end

--- Derivative of Caculate with respect to its (single) input.
-- @param param input tensor
-- @return diagonal Jacobian matrix from PD_X_Robustify(param, true)
-- @raise error when HasParamForD(1, param) is false
function Robustifer:Derivative(param)
    if not self:HasParamForD(1, param) then
        -- Previously a bare assert(false) ("assertion failed!"); report why, at the caller.
        error("Robustifer:Derivative: no derivative available for the given param "
            .. "(HasParamForD(1, param) returned false)", 2)
    end
    return self:PD_X_Robustify(param, true);
end

--[[
function Robustifer:R(x)
   self:Robustify(x);
end  

function Robustifer:DR(x)
    self:PD_X_Robustify(x,true);
end  ]]

--- Geman-McClure robustification, element-wise:
--   sign(x) * sigma^2 * x^2 / (sigma^2 + x^2)
-- Side effect: refreshes the cached x/sigma tensors on self via InitParam.
-- @param x input tensor
-- @return new tensor with the same shape as x
function Robustifer:Robustify(x)
    self:InitParam(x)
    local sigma2, x2 = self.sigmasquare, self.xsquare

    -- sigma^2 * (x^2 / (sigma^2 + x^2))
    local ratio = torch.cdiv(x2, torch.add(sigma2, x2))
    local magnitude = torch.cmul(sigma2, ratio)

    -- Re-attach the sign of x so the output is odd in x.
    return torch.cmul(magnitude, torch.sign(self.x))
end

--- Signed square root of the robustified input: SignedSqrt(Robustify(x)).
-- @param x input tensor
function Robustifer:SignedSqrt_Robustify(x)
    local robustified = self:Robustify(x)
    return self:SignedSqrt(robustified)
end

--[[function Robustifer:PD_X()
    local dx = self.sigmasquare/(self.sigmasquare+self.xsquare)-self.sigmasquare*(self.xsquare/math.pow((self.sigmasquare+self.xsquare),2) );
    dx = 2 * self.x * dx
    return dx;
end]]--

--- Element-wise derivative of Robustify with respect to x.
--   d/dx [sign(x) * sigma^2 x^2 / (sigma^2 + x^2)] = 2|x| sigma^4 / (sigma^2 + x^2)^2
-- Side effect: refreshes the cached x/sigma tensors on self via InitParam.
-- @param x input tensor
-- @param tomat when exactly true, return the (n x n) diagonal Jacobian instead of the vector
-- @return flattened derivative vector of length n = x:nElement(), or its diagonal matrix
function Robustifer:PD_X_Robustify(x,tomat)
    self:InitParam(x);

    local denom = torch.add(self.sigmasquare, self.xsquare)
    -- Product rule on (sigma^2 x^2) * (sigma^2 + x^2)^-1, with the common factor 2x pulled out:
    --   sigma^2/(sigma^2+x^2) - sigma^2 x^2/(sigma^2+x^2)^2
    local udv = torch.cdiv(self.sigmasquare, denom)
    local duv = torch.cmul(self.sigmasquare, torch.cdiv(self.xsquare, torch.pow(denom, 2)))
    local dx = torch.csub(udv, duv)
    dx = torch.cmul(self.x, dx)
    dx = torch.mul(dx, 2)
    -- Derivative of the sign(x) factor in the forward pass.
    dx = torch.cmul(torch.sign(self.x), dx)
    dx = dx:view(-1)

    if tomat == true then
        -- Diagonal Jacobian built in one call, with dx's tensor type, instead of an
        -- element-by-element Lua loop.
        return torch.diag(dx)
    end
    return dx
end

--- Element-wise derivative of SignedSqrt, evaluated at value: 1 / (2 * sqrt(|value|)).
-- Where value == 0 the derivative is infinite; those entries are set to 0.
-- @param value tensor at which the derivative is evaluated
-- @param tomat when exactly true, return the (n x n) diagonal matrix of the flattened result
-- @return tensor shaped like value, or the diagonal matrix
function Robustifer:PD_X_SignedSqrt(value,tomat)
    local result = torch.sqrt(torch.abs(value))
    result = torch.pow(torch.mul(result, 2), -1)

    -- 1/(2*sqrt(0)) = inf; zero those entries instead of propagating infinities.
    result[torch.eq(result, math.huge)] = 0

    if tomat == true then
        -- Flatten first so multi-dimensional input yields an n x n diagonal with
        -- n = value:nElement(), instead of indexing sub-tensors row by row.
        return torch.diag(result:view(-1))
    end
    return result
end


--- Element-wise derivative of SignedSqrt_Robustify with respect to x.
-- The composition simplifies to g(x) = sign(x) * sqrt(sigma^2 x^2 / (sigma^2 + x^2))
--                                    = |sigma| * x / sqrt(sigma^2 + x^2),
-- so g'(x) = |sigma|^3 / (sigma^2 + x^2)^(3/2).
-- Note: this is SignedSqrt'(Robustify(x)) * Robustify'(x) (chain rule), NOT
-- SignedSqrt'(Robustify'(x)). The closed form is also finite at x = 0 (g'(0) = 1
-- for sigma ~= 0), where the factored chain rule would hit 0 * inf.
-- Side effect: refreshes the cached x/sigma tensors on self via InitParam.
-- @param x input tensor
-- @param tomat when exactly true, return the (n x n) diagonal Jacobian instead of the vector
-- @return flattened derivative vector of length n = x:nElement(), or its diagonal matrix
function Robustifer:PD_X_SignedSqrt_Robustify(x,tomat)
    self:InitParam(x)

    local denom = torch.pow(torch.add(self.sigmasquare, self.xsquare), 1.5)
    local numer = torch.pow(torch.abs(self.sigma), 3)
    local grad = torch.cdiv(numer, denom):view(-1)

    if tomat == true then
        return torch.diag(grad)
    end
    return grad
end


-- Module export: the Robustifer class.
return Robustifer;

