local apollonode = require "apolloutility.apollonode"
local apolloengine = require "apolloengine"
local venusjson = require "venusjson"
local mathfunction = require "mathfunction"
local Object = require "classic"

local meshmodel = require "emoji.model"

local likeapp = require "likeapp"
local libmasquerade = require "libmasque"
local mlbvt = require "machinelearningservice"
local videodecet = require "videodecet"
local vc = require "venuscore"
local videodefined = require"videodecet.defined"

local utility = require "liveportrait.utility"

-- Number of face keypoints copied per frame in ReceiveInputDatas (indices 1..landmarkNum).
local landmarkNum = 106
-- Render sequence value shared by the tracking camera (_CreateCamera), the face
-- model, the video render and every point render (ParseConfig).
local SEQUENCE = 1000
-- Face-tracking class, created with the `classic` library (Object:extend()).
local facetracking = Object:extend()

--- Constructor (invoked by `classic`).
-- Starts with an empty model list, models hidden, and caches the current
-- video size reported by videodecet.
function facetracking:new()
  local video_size = videodecet:GetVideoSize()
  self.models, self.show = {}, false
  self.size = video_size
end

--- Create, configure and activate the camera used to render the tracked face.
-- The projection is perspective when `self.persp` is set, otherwise orthographic.
-- @param n first projection size value (ParseConfig passes the video width)
-- @param f second projection size value (ParseConfig passes the video height)
-- @param pos camera position
-- @param lookat point the camera looks at
-- @param up camera up vector
function facetracking:_CreateCamera(n, f, pos, lookat, up)
  local cam = apollonode.CameraNode()
  self.camera = cam
  cam:SetName("TrackingCamera")
  if not self.persp then
    -- NOTE(review): argument meaning depends on the engine's orthographic API; confirm.
    cam:CreateOrthographiProjection(f / 2, 2 * n / f, -1000, 1000)
  else
    cam:CreatePerspectiveProjection(0.01, 1000)
  end
  cam:LookAt(pos, lookat, up)
  cam:Recalculate()
  cam:SetClearColor(mathfunction.Color(0.0, 1.0, 1.0, 1))
  cam:Activate()
  cam:SetSequence(SEQUENCE)
end


--- Create the optional directional and ambient light nodes.
-- A directional light is created only when `lightdirection` is present; its
-- color comes from `lightcolor`, or white when that is absent. An ambient light
-- is created only when `ambientcolor` is present.
local function setup_lights(self, modelconfig)
  local dir = modelconfig.lightdirection
  if dir then
    self.direct = apollonode.LightNode(apolloengine.LightComponent.LT_DIRECTIONAL)
    self.direct:SetLocalDirection(mathfunction.vector3(dir[1], dir[2], dir[3]))
    local color = modelconfig.lightcolor
    if color then
      self.direct:SetColor(mathfunction.vector3(color[1], color[2], color[3]))
    else
      self.direct:SetColor(mathfunction.vector3(1, 1, 1))
    end
  end

  local ambient = modelconfig.ambientcolor
  if ambient then
    self.ambient = apollonode.LightNode(apolloengine.LightComponent.LT_AMBIENT)
    self.ambient:SetColor(mathfunction.vector3(ambient[1], ambient[2], ambient[3]))
  end
end

--- Build the cube-map skybox from exactly six image paths.
-- The skybox and the ENV_SCALE uniform are both published as global material
-- parameters. Does nothing unless `modelconfig.skybox` has six entries.
local function setup_skybox(self, modelconfig)
  local paths = modelconfig.skybox
  if not paths or #paths ~= 6 then
    return
  end

  local TE = apolloengine.TextureEntity
  local MS = apolloengine.IMaterialSystem
  local SE = apolloengine.ShaderEntity

  self.skybox = TE()
  for i, p in ipairs(paths) do
    -- Cube faces are assigned in list order, starting at TT_TEXTURECUBE_FRONT.
    self.skybox:PushMetadata(apolloengine.TextureFileMetadata(
      TE.TT_TEXTURECUBE_FRONT + i - 1,
      TE.TU_STATIC,
      TE.PF_AUTO,
      1, false,
      TE.TW_CLAMP_TO_EDGE, -- the default wrap mode would be repeat
      TE.TW_CLAMP_TO_EDGE,
      TE.TF_LINEAR,
      TE.TF_LINEAR,
      p))
  end
  self.skybox:CreateResource()
  MS:SetGlobalParameter(SE.SKY_BOX, self.skybox)

  -- NOTE(review): the slot is stored under the misspelled key `EVN_SCALE`
  -- (the uniform itself is "ENV_SCALE"). It is kept as-is because other code
  -- may read that key.
  SE.EVN_SCALE = MS:NewParameterSlot(SE.UNIFORM, "ENV_SCALE")
  MS:SetGlobalParameter(SE.EVN_SCALE, mathfunction.vector1(modelconfig.env_scale or 1))
end

--- Create the face mesh model and assign it the shared render sequence.
-- Its optional `blend` texture is bound to the TEXTURE_BLEND uniform.
local function setup_face_model(self, face_config, bindbox)
  self.face_model = meshmodel()
  local blendSlot = apolloengine.IMaterialSystem:NewParameterSlot(apolloengine.ShaderEntity.UNIFORM, "TEXTURE_BLEND")
  self:SetModel(self.face_model, face_config, nil, bindbox, self.renderbefore)
  local node = self.face_model:GetModel()
  -- LoadTexture returns nil when no `blend` path is configured; only bind a real texture.
  local blend_tex = self:LoadTexture(face_config.blend)
  if blend_tex ~= nil then
    node:SetParameter(blendSlot, blend_tex)
  end
  self.face_model.model:SetSequence(SEQUENCE)
end

--- Configure the tracking scene from a model config table.
-- This sets up the camera, render sequences, lights, skybox and face model.
-- Tracker initialization is not done here; see the note at the end.
-- @param modelconfig table. `tracking` is required: `tracking.use_persp`
--   selects the projection, and the whole table is kept for later tracker init.
--   Optional keys: zPosition, renderbefore, lightdirection, lightcolor,
--   ambientcolor, skybox (6 paths), env_scale, face.
-- @return true
function facetracking:ParseConfig(modelconfig)
  self.persp = modelconfig.tracking.use_persp
  self:_CreateCamera(self.size[1], self.size[2],
    mathfunction.vector3(0, 0, 0),
    mathfunction.vector3(0, 0, -1),
    mathfunction.vector3(0, 1, 0))

  -- The video render and all keypoint renders share the tracking camera's sequence.
  videodecet.render:SetSequence(SEQUENCE)
  local pointrenders = videodecet.pointrenders
  for i = 1, #pointrenders do
    pointrenders[i]:SetSequence(SEQUENCE)
  end

  -- Face-model bounding box: x/y in [-1, 1], z in [zPos - 1, zPos],
  -- where zPos = (zPosition or 0) - 2000.
  local zPos = (modelconfig.zPosition or 0) - 2000
  local bindboxV = mathfunction.Aabbox3d(mathfunction.vector3(-1, -1, zPos - 1),
                                         mathfunction.vector3(1, 1, zPos))

  self.renderbefore = modelconfig.renderbefore or false

  setup_lights(self, modelconfig)
  setup_skybox(self, modelconfig)
  if modelconfig.face ~= nil then
    setup_face_model(self, modelconfig.face, bindboxV)
  end

  self.root_node = apollonode.TransNode()
  -- `face` is optional, so there may be no model to attach.
  if self.face_model ~= nil then
    self.root_node:AttachNode(self.face_model:GetModel())
  end

  -- BVT and masquerade trackers are initialized lazily by ReceiveInputDatas
  -- once faces are detected. Reset the flags so that happens again.
  self.trackingConfig = modelconfig.tracking
  self.trackInited = false
  self.bvtInited = false
  return true
end

--- Return the render-initialization flag.
-- NOTE(review): `renderInited` is never assigned in this file, so this returns
-- nil unless external code sets it. Confirm this is the intended field.
function facetracking:GetInitStatus()
  local inited = self.renderInited
  return inited
end

--- Record the visibility flag and apply it to every registered model.
-- @param show visibility value forwarded to each model's SetShow
function facetracking:ShowModel(show)
  self.show = show
  for _, model in ipairs(self.models) do
    model:SetShow(show)
  end
end

--- Create a model from its config and register it in `self.models`.
-- A nil model is ignored.
-- @param model model object exposing CreateModel
-- @param config model configuration forwarded to CreateModel
-- @param attach_to_node node forwarded to CreateModel (may be nil)
-- @param bindbox bounding box forwarded to CreateModel
-- @param render_before flag forwarded to CreateModel
function facetracking:SetModel(model, config, attach_to_node, bindbox, render_before)
  if model == nil then
    return
  end
  model:CreateModel(config, attach_to_node, bindbox, render_before)
  self.models[#self.models + 1] = model
end

--- Load a static 2D texture from a file path.
-- Uses clamp-to-edge wrapping and linear filtering, and keeps the source data.
-- @param tex_path texture file path; nil yields nil
-- @return the created TextureEntity, or nil when tex_path is nil
function facetracking:LoadTexture(tex_path)
  if tex_path == nil then
    return
  end

  local TE = apolloengine.TextureEntity
  local texture = TE()
  local metadata = apolloengine.TextureFileMetadata(
    TE.TU_STATIC,
    TE.PF_AUTO,
    1, false,
    TE.TW_CLAMP_TO_EDGE, TE.TW_CLAMP_TO_EDGE,
    TE.TF_LINEAR, TE.TF_LINEAR,
    tex_path)
  texture:PushMetadata(metadata)
  texture:SetKeepSource(true)
  --texture:SetJobType(venuscore.IJob.JT_SYNCHRONOUS);
  texture:CreateResource()
  return texture
end

--- Model matrix from the bilinear tracker, or nothing if it is not initialized.
function facetracking:GetModelMat()
  local tracker = self.bilinearTracking
  if tracker == nil then
    return
  end
  return tracker:GetModelMat()
end

--- Translation from the bilinear tracker, or nothing if it is not initialized.
function facetracking:GetTranslation()
  local tracker = self.bilinearTracking
  if tracker == nil then
    return
  end
  return tracker:GetTranslation()
end

--- Rotation from the bilinear tracker, or nothing if it is not initialized.
function facetracking:GetRotation()
  local tracker = self.bilinearTracking
  if tracker == nil then
    return
  end
  return tracker:GetRotation()
end

--- Scale from the bilinear tracker, or nothing if it is not initialized.
function facetracking:GetScale()
  local tracker = self.bilinearTracking
  if tracker == nil then
    return
  end
  return tracker:GetScale()
end

--- Create the bilinear tracker and initialize it from the tracking config.
-- `self.bilinearTracking` is assigned even when Init fails.
-- @param config table with use_persp, config, bin, global and rigid_prior fields
-- @return true if Init succeeded (dynamic user is then enabled), false otherwise
function facetracking:_InitMasquerade(config)
  local tracker = libmasquerade.BilinearTracking()
  self.bilinearTracking = tracker

  local width, height = self.size[1], self.size[2]
  local ok = tracker:Init(width, height, config.use_persp,
                         config.config, config.bin, config.global,
                         config.rigid_prior)
  if not ok then
    return false
  end

  tracker:EnableDynamicUser(true, false)
  return true
end

--- Create and initialize the iris (type 6) and tongue (type 10) ML services.
-- Stops at the first service whose Init returns a non-zero code, after logging it.
-- @param w frame width passed to Init
-- @param h frame height passed to Init
-- @return true if both services initialized, false otherwise
function facetracking:_InitBvt(w, h)
  local IRIS_TYPE, TONGUE_TYPE = 6, 10

  local iris = mlbvt.MachinelearningService()
  self.cnnIris = iris
  local tongue = mlbvt.MachinelearningService()
  self.cnnTongue = tongue

  -- Only Windows needs explicit model paths.
  if _PLATFORM_WINDOWS then
    iris:SetModelAndParamsWin(IRIS_TYPE, vc.IFileSystem:PathAssembly(videodefined.bvt_model_path))
    tongue:SetModelAndParamsWin(TONGUE_TYPE, vc.IFileSystem:PathAssembly(videodefined.bvt_model_path))
  end

  -- Set the service type, run Init, and log the error code on failure.
  local function init_service(service, service_type, failure_msg)
    service:SetType(service_type)
    local res = service:Init(w, h)
    if res ~= 0 then
      LOG(failure_msg .. res)
      return false
    end
    return true
  end

  if not init_service(iris, IRIS_TYPE, "bvt iris init failed: ") then
    return false
  end
  return init_service(tongue, TONGUE_TYPE, "bvt tongue init failed: ")
end

--- Place and orient the root node that holds the face model.
-- With an orthographic camera, z is flattened to 0 (presumably because depth
-- does not affect on-screen size there — confirm).
-- @param position vector3 position
-- @param rotation rotation forwarded to SetLocalRotation
function facetracking:SetPosition(position, rotation)
  local root = self.root_node
  local target = position
  if not self.persp then
    target = mathfunction.vector3(position:x(), position:y(), 0)
  end
  root:SetLocalPosition(target)
  root:SetLocalRotation(rotation)
end

--- Replace the face model's index buffer and texture coordinates.
-- @param vertices unused here; positions are written by UpdateModel
-- @param indices flat 1-based array of triangle indices, stored as UINT16
--   (values above 65535 cannot be represented)
-- @param uvs flat array {u1, v1, u2, v2, ...}; each pair is written flipped
--   as (1 - u, 1 - v)
function facetracking:InitModel(vertices, indices, uvs)
  local node = self.face_model:GetModel()

  local idxstream = apolloengine.IndicesStream()
  local indicesNum = #indices
  idxstream:SetIndicesType(apolloengine.IndicesBufferEntity.IT_UINT16)
  idxstream:ReserveBuffer(indicesNum)
  for i = 1, indicesNum do
    idxstream:PushIndicesData(indices[i])
  end
  node.render:ChangeIndexBuffer(idxstream)

  local vtxstream = node:GetVertexStream()
  local uv = mathfunction.vector2(0, 0)
  local uvSlot = vtxstream:GetAttributeIndex(apolloengine.ShaderEntity.ATTRIBUTE_COORDNATE0)
  -- Floor so the count is an integer (`/` yields a float) and matches the
  -- number of vertices the loop writes, even for an odd-length uvs array.
  local vtxnum = math.floor(#uvs / 2)

  for i = 1, vtxnum do
    uv:Set(1 - uvs[i * 2 - 1], 1 - uvs[i * 2])
    vtxstream:ChangeVertexDataWithAttributeFast(uvSlot, i, uv)
  end

  vtxstream:SetReflushInterval(1, vtxnum)
  node.render:ChangeVertexBuffer(vtxstream)
end

--- Write tracked vertex positions into the face model, scaled by the tracker's scale.
-- Does nothing when there is no tracking data or the bilinear tracker is not
-- initialized (GetScale returns nothing).
-- @param tracking_data table with `vertices` as a flat {x1, y1, z1, x2, ...}
--   array, as produced by ReceiveInputDatas; may be nil
-- @param def unused; kept for interface compatibility
function facetracking:UpdateModel(tracking_data, def)
  local verts = tracking_data and tracking_data.vertices
  if verts == nil then
    return
  end
  local scale = self:GetScale()
  if scale == nil then
    return
  end

  local node = self.face_model:GetModel()
  local vtxstream = node:GetVertexStream()

  local pos = mathfunction.vector4(0, 0, 0, 1)
  local posSlot = vtxstream:GetAttributeIndex(apolloengine.ShaderEntity.ATTRIBUTE_POSITION)
  -- Floor so the flush range matches the number of vertices actually written.
  local vtxnum = math.floor(#verts / 3)
  for i = 1, vtxnum do
    local base = i * 3
    pos:Set(verts[base - 2] * scale, verts[base - 1] * scale, verts[base] * scale, 1)
    vtxstream:ChangeVertexDataWithAttributeFast(posSlot, i, pos)
  end

  vtxstream:SetReflushInterval(1, vtxnum)
  node.render:ChangeVertexBuffer(vtxstream)
end

--- Lazily initialize the BVT services, then the bilinear tracker.
-- On BVT failure, cnnIris/cnnTongue are cleared. On tracker failure, an
-- error is reported and bilinearTracking is cleared. In both cases the
-- matching *Inited flag stays false, so the next call retries.
-- @return true once both are initialized
local function ensure_trackers(self, size)
  if not self.bvtInited then
    if not self:_InitBvt(size[1], size[2]) then
      self.cnnIris = nil
      self.cnnTongue = nil
      return false
    end
    self.bvtInited = true
  end

  if not self.trackInited then
    if not self:_InitMasquerade(self.trackingConfig) then
      ERROR("Failed to init masquerade.")
      self.bilinearTracking = nil
      return false
    end
    self.trackInited = true
  end
  return true
end

--- Copy the first landmarkNum face keypoints into two forms.
-- One is a vector2array (input to the CNNs); the other is a flat
-- {x1, y1, x2, y2, ...} Lua array (input to the bilinear tracker).
local function collect_landmarks(faceKeyPoints)
  local keyPoints = mathfunction.vector2array()
  local landmarks = {}
  for i = 1, landmarkNum do
    local pt = faceKeyPoints:Get(i)
    keyPoints:PushBack(pt)
    landmarks[2 * i - 1] = pt:x()
    landmarks[2 * i] = pt:y()
  end
  return keyPoints, landmarks
end

--- Run the iris and tongue CNNs on the keypoints.
-- The eye centers are iris outputs 20 and 40. An eye reported as (0, 0)
-- falls back to keypoint 105 (left) or 106 (right).
-- @return leftEye, rightEye, tongue value
local function detect_eyes_and_tongue(self, keyPoints, ts)
  local eyes = self.cnnIris:Run(keyPoints, ts)
  local left = eyes:Get(20)
  local right = eyes:Get(40)
  local leftEye = mathfunction.vector2(left:x(), left:y())
  local rightEye = mathfunction.vector2(right:x(), right:y())
  local bvtTongue = self.cnnTongue:Run(keyPoints, ts):Get(1):x()

  if leftEye:x() == 0 and leftEye:y() == 0 then
    leftEye = keyPoints:Get(105)
  end
  if rightEye:x() == 0 and rightEye:y() == 0 then
    rightEye = keyPoints:Get(106)
  end
  return leftEye, rightEye, bvtTongue
end

-- Temporary: to be replaced once likeapp implements the input interface.
--- Gather the current frame's face data and run bilinear tracking on it.
-- Trackers are initialized on the first frame that contains a face.
-- @return {vertices = <tracked vertices>} on success; nil when there is no
--   face/frame, initialization fails, or tracking fails
function facetracking:ReceiveInputDatas()
  local size = videodecet:GetVideoSize()
  local ts = videodecet:GetVideoFrame()
  local faces = videodecet:GetFaces()
  -- Check `faces` before taking its length, so a nil list is not indexed.
  if not (faces and #faces > 0 and size and ts) then
    return nil
  end

  if not ensure_trackers(self, size) then
    return nil
  end

  local keyPoints, landmarks = collect_landmarks(videodecet:GetPixelFacekeypointArray())
  local leftEye, rightEye, bvtTongue = detect_eyes_and_tongue(self, keyPoints, ts)

  -- Build the euler vector from the rotation components in order 1, 3, 2.
  local pry = faces[1]:GetRotation()
  local euler = mathfunction.vector3(pry[1], pry[3], pry[2])
  local visibility = likeapp.AI:GetFaceVisibility()

  local tracker = self.bilinearTracking
  local proj = self.camera:GetProject()
  if proj == nil or tracker:SetProjMat(size[1], size[2], proj) ~= true then
    LOG("[Bilinear]: Use default projection matrix of Bilinear Model.")
  end
  if not tracker:Track(euler, landmarks, visibility, leftEye, rightEye, bvtTongue) then
    return nil
  end

  return { vertices = tracker:GetTrackingVertices() }
end

-- Module export: the facetracking class.
return facetracking