local venusjson = require "venusjson"
local venuscore = require "venuscore"
local object = require "classic"
local torch = require "torch"



--- Gaussian mixture model over row-vector inputs, built on rxi/classic.
local GaussianMixtureModel = object:extend()

--- Constructor.
-- Only creates the memoisation table shared by MinAPriori and dx; the model
-- parameters (n, means, covrsinvlow, consts) are populated by Load, which must
-- be called before any evaluation method.
function GaussianMixtureModel:new()
  local memo = {}
  self.cache = memo
end

--- Load the mixture parameters from a file written by `torch.save`.
--
-- The loaded value must be a table with the keys "means", "covars" and
-- "weights", one entry per mixture component. `means` and `covars` are walked
-- with `ipairs`, so they must be Lua sequences; each mean is viewed as a 1xD
-- row vector and each covariance must be a DxD matrix accepted by
-- torch.inverse / torch.potrf. NOTE(review): `weights[i]` is used as a plain
-- number, so `weights` is presumably a sequence of numbers -- confirm against
-- the exporter.
--
-- Fields written:
--   self.n           number of components (#weights)
--   self.means       component means as 1xD views
--   self.covrsinvlow lower Cholesky factor L_i of inverse(cov_i), computed
--                    with torch.potrf(..., 'L') so that inverse(cov_i) = L_i * L_i^T
--   self.consts      weights[i] / ((2*pi)^(D/2) * s_i / min_j s_j), where
--                    s_i = sqrt(det(cov_i))
--   self.cache       reset, so results memoised for a previous model are
--                    never returned for this one
--
-- @tparam string tensorfile path resolved with venuscore.IFileSystem:PathAssembly
-- @raise if the file holds no components, or if means/covars/weights differ
--        in length
function GaussianMixtureModel:Load(tensorfile)
  local fullpath = venuscore.IFileSystem:PathAssembly(tensorfile);
  local params = torch.load(fullpath);
  local means = params["means"];
  local covars = params["covars"];
  local weights = params["weights"];

  -- Validate before touching self so a bad file leaves the model unchanged.
  local n = #weights;
  assert(n > 0,
    "GaussianMixtureModel:Load: no mixture components in " .. tostring(tensorfile));
  assert(#means == n and #covars == n,
    "GaussianMixtureModel:Load: means, covars and weights differ in length");
  self.n = n;

  -- MinAPriori / dx memoise on the input only; drop anything computed with
  -- previously loaded parameters.
  self.cache = {};

  -- Means as 1xD row vectors so that (x - mean) * L is a row-matrix product.
  self.means = {};
  for i, mean in ipairs(means) do
    self.means[i] = mean:view(1, -1);
  end

  -- Per component: Cholesky factor of the inverse covariance (precision)
  -- matrix, and sqrt(det(cov)) for the normalisation constant below.
  self.covrsinvlow = {};
  local sqrtdets = {};
  local minsqrtdet = nil;
  for i, cov in ipairs(covars) do
    self.covrsinvlow[i] = torch.potrf(torch.inverse(cov), 'L');
    local sqrtdet = torch.sqrt(torch.det(cov));
    sqrtdets[i] = sqrtdet;
    if minsqrtdet == nil or minsqrtdet > sqrtdet then
      minsqrtdet = sqrtdet;
    end
  end

  -- Each sqrt-determinant is taken relative to the smallest one before being
  -- folded into the Gaussian normalisation (2*pi)^(D/2) * sqrt(det(cov)).
  local dim = covars[1]:size(1);
  local part_const = torch.pow(2 * math.pi, dim / 2);
  self.consts = {};
  for i, sqrtdet in ipairs(sqrtdets) do
    self.consts[i] = weights[i] / (part_const * sqrtdet / minsqrtdet);
  end
end

--- Sum of the component means, each scaled by its entry in self.consts.
--
-- NOTE(review): self.consts[i] is weights[i] divided by (2*pi)^(D/2) times a
-- relative sqrt-determinant (>= 1), as built in Load. The coefficients
-- therefore do not sum to 1, and for large D the result is shrunk far
-- towards zero rather than being the mixture mean sum_i weights[i] * mean_i.
-- Confirm whether the raw mixture weights were intended here.
--
-- @return 1xD tensor, or nil when the model has no components
function GaussianMixtureModel:MeanPose()
  local acc
  for k = 1, self.n do
    local term = self.consts[k] * self.means[k]
    if acc == nil then
      acc = term
    else
      acc = acc + term
    end
  end
  return acc
end

--- Energy of x under mixture component i.
--
-- With L_i the lower Cholesky factor of inverse(cov_i) (see Load), this forms
-- the residual r = (x - mean_i) * L_i * sqrt(0.5). Its squared norm r.r
-- equals 0.5 * (x - mean_i) * inverse(cov_i) * (x - mean_i)^T, so the
-- returned energy r.r - log(consts[i]) is -log(consts[i] * exp(-r.r)).
-- Lower values mean the component explains x better.
--
-- @param x input tensor; presumably a 1xD row like the stored means, since it
--          is matrix-multiplied by a DxD factor -- TODO confirm with callers
-- @tparam number i component index in [1, self.n]
-- @return number energy
-- @return residual tensor r (same shape as x - mean_i)
function GaussianMixtureModel:_CaculatePriori(x, i)
  local residual = (x - self.means[i]) * self.covrsinvlow[i] * torch.sqrt(0.5);
  local energy = residual:dot(residual) - torch.log(self.consts[i]);
  return energy, residual;
end

--- Sum of the per-component energies of x (first result of _CaculatePriori
-- for every component).
--
-- NOTE(review): the negative log-density of a mixture would be
-- -log(sum_i exp(-E_i)), and MinAPriori approximates it with min_i E_i.
-- Summing the E_i is neither; confirm this is the intended quantity.
--
-- @param x input tensor (see _CaculatePriori)
-- @return number; 0 when the model has no components
function GaussianMixtureModel:APriori(x)
  local total = 0;
  for k = 1, self.n do
    local energy = self:_CaculatePriori(x, k);
    total = total + energy;
  end
  return total;
end

--- Find the component with the lowest energy for x (uncached).
--
-- Ties keep the earliest component. The chosen component's residual gets one
-- extra entry, sqrt(-log(consts[index])), appended. The squared norm of the
-- extended residual is then r.r - log(consts[index]), which is exactly the
-- returned energy. A least-squares solver minimising the residual therefore
-- minimises the energy, including the weight term.
--
-- @param x input tensor (see _CaculatePriori)
-- @return index of the minimum-energy component
-- @return its energy
-- @return its residual with the weight entry appended
function GaussianMixtureModel:_MinAPriori(x)
  local best_index, best_energy, best_residual;
  for k = 1, self.n do
    local energy, residual = self:_CaculatePriori(x, k);
    if best_energy == nil or energy < best_energy then
      best_index, best_energy, best_residual = k, energy, residual;
    end
  end

  local weight_term = math.sqrt(-math.log(self.consts[best_index]));
  local weight_entry = torch.Tensor(1, 1):fill(weight_term);
  best_residual = best_residual:cat(weight_entry);
  return best_index, best_energy, best_residual;
end

--- Memoised wrapper around _MinAPriori.
--
-- The results are stored together with a clone of x. A later call whose x
-- compares equal (Tensor:equal) to that clone returns the stored results
-- without recomputing. The key is recorded only after _MinAPriori returns.
-- An error inside it therefore cannot leave a key that points at results
-- computed for a different x.
--
-- @param x input tensor (see _CaculatePriori)
-- @return index of the minimum-energy component
-- @return its energy
-- @return its residual with the weight entry appended (see _MinAPriori)
function GaussianMixtureModel:MinAPriori(x)
  local cache = self.cache;
  if not (cache.x and cache.x:equal(x)) then
    local mi, mp, mll = self:_MinAPriori(x);
    -- The field is named "maxindex" but holds the minimum-energy index;
    -- the name is kept for compatibility.
    cache.maxindex = mi;
    cache.minprob = mp;
    cache.minloglike = mll;
    cache.x = x:clone();
  end
  return cache.maxindex, cache.minprob, cache.minloglike;
end

--- Jacobian of the MinAPriori residual with respect to x (uncached).
--
-- For the selected component, residual = sqrt(0.5) * (x - mean) * L. Its
-- derivative is sqrt(0.5) * L^T (rows = residual entries, columns = x
-- entries). An earlier, commented-out version scaled this by
-- 2 * sqrt(0.5), which does not match that derivative. The appended weight
-- entry sqrt(-log(consts[index])) does not depend on x, so its row is zero.
-- The choice of component is treated as fixed.
--
-- @param x input tensor (see _CaculatePriori)
-- @return (D+1)xD tensor
function GaussianMixtureModel:_dx(x)
  local index = self:MinAPriori(x);
  local jac = math.sqrt(0.5) * self.covrsinvlow[index]:t();
  local zero_row = torch.Tensor(1, jac:size(2)):zero();
  return torch.cat(jac, zero_row, 1);
end

--- Memoised wrapper around _dx.
--
-- The Jacobian is stored together with a clone of x in self.cache.dx. A later
-- call whose x compares equal (Tensor:equal) to that clone returns the stored
-- Jacobian. The key is recorded only after _dx returns. An error inside it
-- therefore cannot leave a key paired with a Jacobian computed for a
-- different x.
--
-- @param x input tensor (see _CaculatePriori)
-- @return Jacobian tensor (see _dx)
function GaussianMixtureModel:dx(x)
  local cache = self.cache;
  if not (cache.dx and cache.dx:equal(x)) then
    cache.j = self:_dx(x);
    cache.dx = x:clone();
  end
  return cache.j;
end

return GaussianMixtureModel;
