--[[ Powell's dogleg trust-region method for nonlinear least squares

ARGS:

- `funcs`  : a table of residual objects; each provides `fx(x)`, returning a
             residual tensor at point x, and `dx(x)`, returning the matching
             Jacobian rows (stacked along dimension 1)
- `x`      : the initial point (overwritten in place with the final point)
- `config` : a table with configuration parameters for the optimizer
- config.maxiter / config.max_fevals : iteration and residual-evaluation limits
- config.e_N : are stopping conditions:
- config.e_1 : is the gradient magnitude threshold
- config.e_2 : is the step size magnitude threshold (relative to the norm of x)
- config.e_3 : is the improvement threshold (as a ratio; 0.1 means it must improve by 10% at each step; 0 disables it)

RETURN:
- `x`     : the new x vector (no f(x) value is returned)
]]

local Object = require "classic"


-- Module-level solver state and active configuration. Both are replaced on
-- every call to `dogleg`: `state` is rebuilt by `_init` and `config` is its
-- return value.
local state, config

--- One candidate step: the residual at the trial point plus its gain ratio.
local Trial = Object:extend()

--- Compare a proposed residual vector against the current one (state.r).
-- rho = (actual reduction in SSE) / (reduction predicted by the linear model).
-- If the actual reduction is not positive, rho keeps that raw (<= 0) value
-- rather than the ratio, so the trial is not treated as an improvement.
-- @param proposed_r residual column evaluated at the trial point
function Trial:new(proposed_r)
  self.r = proposed_r
  local gain = state.r:norm()^2 - proposed_r:norm()^2
  if gain > 0 then
    local step = state.d_dl
    -- model reduction ||r||^2 - ||r + J d||^2 = 2 g'd - d'A d, since g = -J'r
    local model = 2 * state.g:t() * step - step:t() * state.A * step
    gain = gain / model[{1, 1}]
  end
  self.rho = gain
end

--- @return true when the gain ratio is positive (the step is accepted)
function Trial:is_improvement()
  return self.rho > 0
end

--- @return relative SSE reduction (||r||^2 - ||r_trial||^2) / ||r||^2
function Trial:improvement()
  local old_sse = state.r:norm()^2
  return (old_sse - self.r:norm()^2) / old_sse
end


--- Reset the solver state and fill in configuration defaults.
-- Mutates and returns the given table (or a fresh one when nil).
-- Defaults: max_fevals = math.huge, maxiter = 200, e_1 = 1e-15,
-- e_2 = 1e-15, e_3 = 0 (disabled).
-- ub = 0.9 and lb = 0.05 are the gain-ratio thresholds above/below which the
-- trust radius grows/shrinks. They were previously forced to these values
-- regardless of user input; they are now overridable.
-- @param cfg optional user configuration table
-- @return the completed configuration table
local function _init(cfg)
  state = { fevals = 0, iteration = 0 }
  cfg = cfg or {}
  cfg.max_fevals = cfg.max_fevals or math.huge
  cfg.maxiter = cfg.maxiter or 200
  cfg.e_1 = cfg.e_1 or 1e-15
  cfg.e_2 = cfg.e_2 or 1e-15
  cfg.e_3 = cfg.e_3 or 0
  cfg.ub = cfg.ub or 0.9
  cfg.lb = cfg.lb or 0.05
  return cfg
end



--- Rebuild the stacked Jacobian at state.x into state.j.
-- Each func.dx(state.x) block is appended below the previous ones (dim 1);
-- state.j is nil when state.funcs is empty.
local function _updateJ()
  local stacked
  for _, f in pairs(state.funcs) do
    local block = f.dx(state.x)
    if stacked == nil then
      stacked = block
    else
      stacked = torch.cat(stacked, block, 1)
    end
  end
  state.j = stacked
end

--- Evaluate and stack every residual at point `x`.
-- Each func.fx(x) result is flattened to a column and concatenated along
-- dim 1. Row alignment with the Jacobian relies on `pairs(state.funcs)`
-- visiting entries in the same order as in `_updateJ`; that holds in practice
-- for an unmodified table but is not guaranteed by the Lua manual.
-- @param x point at which to evaluate (previously ignored in favour of state.x)
-- @return column tensor of stacked residuals, or nil if state.funcs is empty
local function _evaluateR(x)
  local r = nil
  for _, func in pairs(state.funcs) do
    local pr = func.fx(x)
    -- reshape through a view instead of resize() so the tensor returned by
    -- fx (possibly caller-owned) keeps its shape
    pr = pr:contiguous():view(pr:numel(), 1)
    r = r and torch.cat(r, pr, 1) or pr
  end
  return r
end

--- Set the current residual and count one function evaluation.
-- @param nr optional residual already evaluated at state.x; when absent the
--   residual is evaluated here
local function _updateR(nr)
  state.r = nr or _evaluateR(state.x)
  state.fevals = state.fevals + 1
end


--- Form the Gauss-Newton quantities from state.j and state.r:
-- state.g = -J'r (descent direction of 0.5*||r||^2) and state.A = J'J
-- (the Gauss-Newton Hessian approximation).
local function _updateAg()
  local jac = state.j
  local jac_t = jac:t()
  state.g = -jac_t * state.r
  state.A = jac_t * jac
end

--- Compute the Gauss-Newton step state.d_gn by solving A * d = g.
-- The LU solve (torch.gesv) is tried first. If it raises (e.g. a singular A),
-- torch.gelsd is used instead.
-- NOTE(review): torch.gelsd does not appear among core torch7's LAPACK
-- bindings (gesv, gels, ...); confirm the project provides it, otherwise this
-- fallback itself raises.
local function _updateGN()
  local ok, solution = pcall(torch.gesv, state.g, state.A)
  if not ok then
    solution = torch.gelsd(state.g, state.A)
  end
  state.d_gn = solution
end

--- Start a new outer iteration.
-- Computes the steepest-descent (Cauchy) step
-- d_sd = ||g||^2 / ||J g||^2 * g and invalidates the cached Gauss-Newton step.
-- Since g = -J'r, ||J g|| = 0 only when g = 0. In that case the step is set to
-- zero instead of evaluating 0/0 = NaN.
local function _begin_iteration()
  state.iteration = state.iteration + 1
  local gn = state.g:norm()
  local jg_sq = (state.j * state.g):norm()^2
  if jg_sq > 0 then
    state.d_sd = gn^2 / jg_sq * state.g
  else
    state.d_sd = state.g:clone():zero()
  end
  state.d_gn = nil
end

--- Find beta such that ||d_sd + beta * (d_gn - d_sd)|| = delta.
-- This is the positive root of the quadratic along the dogleg segment. It is
-- only meaningful when ||d_sd|| < delta < ||d_gn||, which is the branch of
-- `_update_step` that calls it.
-- The discriminant is written as c^2 + ||b-a||^2 (delta^2 - ||a||^2), a sum of
-- non-negative terms. The previous form, (b.a)^2 - ||b||^2 ||a||^2 + ...,
-- cancels catastrophically and could round below zero, giving NaN.
-- The root formula is picked by the sign of c to avoid cancellation
-- (Madsen, Nielsen & Tingleff, 2004, sec. 3.3).
-- @return beta in [0, 1]
local function _beta_multiplier()
  local a = state.d_sd
  local diff = state.d_gn - a
  local slack = state.delta^2 - a:norm()^2
  local diff_sq = diff:norm()^2
  local c = diff:dot(a)
  local root = math.sqrt(math.max(c * c + diff_sq * slack, 0))
  if c <= 0 then
    return (root - c) / diff_sq
  end
  return slack / (c + root)
end

--- Choose the dogleg step state.d_dl for the current trust radius state.delta.
-- 1. Cauchy step on/outside the region: steepest descent scaled to the boundary.
-- 2. First step (no radius yet): full Gauss-Newton step, and the radius is
--    initialised to its length.
-- 3. Gauss-Newton step fits inside the region: take it.
-- 4. Otherwise: the point where the d_sd -> d_gn segment meets the boundary.
-- The Gauss-Newton step is solved lazily and cached for the iteration.
local function _update_step()
  local sd_len = state.d_sd:norm()
  if state.delta ~= nil and sd_len >= state.delta then
    state.d_dl = state.d_sd * (state.delta / sd_len)
    return
  end

  if state.d_gn == nil then
    _updateGN()
  end
  local gn_len = state.d_gn:norm()

  if state.delta == nil then
    state.d_dl = state.d_gn:clone()
    state.delta = gn_len
  elseif gn_len <= state.delta then
    state.d_dl = state.d_gn:clone()
  else
    state.d_dl = state.d_sd + _beta_multiplier() * (state.d_gn - state.d_sd)
  end
end

--- Adapt the trust radius from the gain ratio of the last trial.
-- A very good ratio (> config.ub) lets the radius grow to at least 2.5x the
-- last step length. A poor one (< config.lb) shrinks it to a quarter.
-- Anything in between leaves it unchanged.
-- @param rho gain ratio of the last trial
local function _updateRadius(rho)
  local radius = state.delta
  if rho > config.ub then
    local stretched = 2.5 * state.d_dl:norm()
    state.delta = math.max(radius, stretched)
  elseif rho < config.lb then
    state.delta = radius * 0.25
  end
end

--- Minimise the sum of squared residuals of `funcs` with Powell's dogleg
-- trust-region method, which blends Gauss-Newton and steepest-descent steps.
-- Stopping reasons are reported through the global LOG function (assumed to
-- be provided by the host project).
-- @param funcs table of residual objects; each provides `fx(x)` (residuals)
--   and `dx(x)` (matching Jacobian rows)
-- @param x initial point; overwritten in place with the final point
-- @param c optional configuration table (defaults are filled in by `_init`)
-- @return the final point
local function dogleg(funcs, x, c)
  config = _init(c)
  state.x = x -- point where residuals/Jacobian are evaluated
  state.p = x -- last accepted point
  state.funcs = funcs

  _updateJ()
  _updateR()
  _updateAg()

  local done = false
  -- Log only the first stopping reason, then flag both loops to exit.
  local function _stop(fmt, ...)
    if not done then
      local input = string.format(fmt, ...)
      LOG(string.format("total iterate %d, message: %s", state.iteration, input))
    end
    done = true
  end

  -- Apply the same gradient test as inside the loop to the starting point,
  -- so an already-stationary start does not iterate at all.
  if torch.norm(state.g, 10) < config.e_1 then --pseudo inf-norm
    _stop('stopping because norm(g, 10) < %.2e', config.e_1)
  end

  while not done do
    _begin_iteration()
    local searching = true
    while searching do
      local trial
      _update_step()
      local min_step = config.e_2 * state.p:norm()
      if state.d_dl:norm() <= min_step then
        _stop('stopping because of small step size (norm_dl < %.2e)', min_step)
      else
        state.x = state.p + state.d_dl -- d_dl is the trial step
        local nr = _evaluateR(state.x)
        trial = Trial(nr)
        if trial:is_improvement() then
          state.p = state.p + state.d_dl
          if config.e_3 > 0 and trial:improvement() < config.e_3 then
            _stop('stopping because improvement < %f', config.e_3)
          else
            _updateJ()
            _updateR(nr) -- reuses nr and counts this evaluation
            _updateAg()
            if torch.norm(state.g, 10) < config.e_1 then --pseudo inf-norm
              _stop('stopping because norm(g, 10) < %.2e', config.e_1)
            end
          end
        else
          -- Rejected step: roll back. The residual evaluation still counts
          -- toward max_fevals (it was previously not counted).
          state.x = state.p
          state.fevals = state.fevals + 1
        end
        _updateRadius(trial.rho)
        if state.delta <= config.e_2 * state.p:norm() then
          _stop('stopping because trust region is too small')
        end
      end
      -- `trial` is nil only on the small-step path, where `done` is already true.
      if done or trial:is_improvement() or state.fevals >= config.max_fevals then
        searching = false
      end
    end
    if state.iteration >= config.maxiter then
      _stop('stopping because max number of user-specified iterations (%d) has been met', config.maxiter)
    elseif state.fevals >= config.max_fevals then
      _stop('stopping because max number of user-specified func evals (%d) has been met', config.max_fevals)
    end
  end
  x:copy(state.x)
  return state.x
end

return dogleg