
local torch = require "torch"
local Object = require "classic"

--- One Euro filter: a first-order low-pass filter whose cutoff frequency adapts
-- to the signal's filtered rate of change (cutoff = min_cutoff + beta * |dx_hat|).
-- All smoothing is done element-wise on torch tensors.
-- Defined as a `classic` class; NOTE(review): instances are presumably created by
-- calling the class (classic then runs `new` on the instance) — confirm with callers.
local oneeurofilter = Object:extend();

--- Scalar smoothing factor (alpha) of a first-order low-pass filter.
-- With w = 2*pi*t_e*cutoff, alpha = w / (w + 1), so alpha moves toward 1
-- as the elapsed time or the cutoff frequency grows.
-- @tparam number t_e time elapsed since the previous sample
-- @tparam number cutoff cutoff frequency
-- @treturn number alpha
function oneeurofilter:smoothing_factor(t_e, cutoff)
    local w = 2 * math.pi * t_e * cutoff
    local denom = w + 1
    return w / denom
end

--- Element-wise smoothing factor for a tensor of cutoff frequencies.
-- Same formula as `smoothing_factor`, alpha = w / (w + 1), evaluated per element
-- with w = 2*pi*t_e*cutoff.
-- @tparam number t_e time elapsed since the previous sample
-- @param cutoff tensor of cutoff frequencies
-- @return tensor of alphas, shaped like `cutoff`
function oneeurofilter:smoothing_factor_tensor(t_e, cutoff)
    local scale = 2 * math.pi * t_e
    local w = torch.mul(cutoff, scale)
    local denom = w + 1
    return torch.cdiv(w, denom)
end

--- Exponential smoothing with one scalar weight for every element.
-- Returns a * x + (1 - a) * x_prev as a new tensor.
-- @tparam number a weight given to the new value
-- @param x new value (tensor)
-- @param x_prev previous smoothed value (tensor)
-- @return blended tensor
function oneeurofilter:exponential_smoothing(a, x, x_prev)
    local keep = 1 - a
    local fresh = torch.mul(x, a)
    local history = torch.mul(x_prev, keep)
    return fresh + history
end

--- Exponential smoothing with a separate weight per element.
-- Returns a .* x + (1 - a) .* x_prev, where .* is element-wise multiplication.
-- @param a tensor of weights for the new value
-- @param x new value (tensor)
-- @param x_prev previous smoothed value (tensor)
-- @return blended tensor
function oneeurofilter:exponential_smoothing_tensor(a, x, x_prev)
    local fresh = torch.cmul(x, a)
    local keep = 1 - a
    local history = torch.cmul(x_prev, keep)
    return fresh + history
end

--- Initialise the filter state from a first sample.
-- @param t0 timestamp of the initial sample (number)
-- @param x0 initial signal value; a torch tensor (its shape and tensor type are
--   used for the internal state)
-- @param[opt] dx0 initial derivative estimate; anything `torch.Tensor` accepts
--   (e.g. a tensor or a Lua table). Defaults to zeros shaped like `x0`.
-- @param[opt=1] min_cutoff minimum cutoff frequency
-- @param[opt=0] beta gain on |derivative| added to the cutoff frequency
-- @param[opt=1] d_cutoff cutoff frequency used when smoothing the derivative
function oneeurofilter:new(t0, x0, dx0, min_cutoff, beta, d_cutoff)
    if min_cutoff == nil then
        min_cutoff = 1
    end
    if beta == nil then
        beta = 0
    end
    if d_cutoff == nil then
        d_cutoff = 1
    end

    -- Filter parameters.
    self.min_cutoff = min_cutoff
    self.beta = beta
    self.d_cutoff = d_cutoff

    -- Previous values. Take copies so that later in-place changes to the
    -- caller's buffers cannot silently alter the filter's initial state.
    self.x_prev = x0:clone()
    if dx0 == nil then
        -- Allocate with x0's own tensor type (not the global default type), so
        -- that filter() never mixes tensor types, e.g. Float vs Double.
        self.dx_prev = x0.new():resizeAs(x0):zero()
    else
        self.dx_prev = torch.Tensor(dx0):clone()
    end
    self.t_prev = t0
end

--- Filter one sample and update the internal state.
-- Timestamps must strictly increase. A sample whose timestamp is not after the
-- previous one is ignored: the last filtered value is returned unchanged and no
-- state is updated.
-- @param t timestamp of this sample (number)
-- @param torchx new sample; NOTE(review): assumed to match x0 in shape and
--   tensor type — confirm with callers
-- @return the filtered tensor (the same object is kept as internal state)
function oneeurofilter:filter(t, torchx)
    local t_e = t - self.t_prev

    -- A repeated timestamp (t_e == 0) would divide by zero below. The resulting
    -- Inf/NaN in dx survives the 0-weighted smoothing (0 * Inf = NaN) and would
    -- poison dx_prev and x_prev for every later call. A negative t_e gives
    -- weights outside [0, 1]. Either way, skip the sample and keep the state.
    if t_e <= 0 then
        return self.x_prev
    end

    -- The filtered derivative of the signal.
    local a_d = self:smoothing_factor(t_e, self.d_cutoff)
    local dx = (torchx - self.x_prev) / t_e
    local dx_hat = self:exponential_smoothing(a_d, dx, self.dx_prev)

    -- The filtered signal. The per-element cutoff grows with |dx_hat|, so
    -- fast-moving elements are smoothed less.
    local cutoff = self.min_cutoff + torch.mul(torch.abs(dx_hat), self.beta)
    local a = self:smoothing_factor_tensor(t_e, cutoff)
    local x_hat = self:exponential_smoothing_tensor(a, torchx, self.x_prev)

    -- Memorize the previous values.
    self.x_prev = x_hat
    self.dx_prev = dx_hat
    self.t_prev = t

    return x_hat
end

-- Module export: the One Euro filter class.
return oneeurofilter;