local venuscore = require "venuscore"
local mathfunction = require "mathfunction"
local apolloengine = require "apolloengine"
local libmasquerade = require "libmasque"
local cv = require "computervisionfunction"


-- Behavior class for the face mesh; created by extending venuscore.VenusBehavior
-- under the name "FacemeshBehavior". All methods below are attached to this table.
local Facemesh = venuscore.VenusBehavior:extend("FacemeshBehavior");

-- Static settings handed to BilinearTracking:Init (see Facemesh:Init).
local config = {
    -- Passed as the third Init argument (after the resolution width/height).
    use_persp = true,
    -- File names forwarded to Init; Init receives them in the order
    -- config, bin, global.
    bin = "bigo_low_cereal.bin",
    config = "bilinear_config.json",
    global = "bilinear_global.json",
}

--- Constructor: puts every piece of per-instance state into a known value.
-- `recognition` is the field read and written by Get/SetRecongnition; the
-- previous spelling `recogition` initialised a field nothing else uses.
function Facemesh:new()
    self.renderInited = false   -- set to true by InitMesh
    self.isInited = false       -- set to true once Init() succeeds
    self.recognition = nil;     -- assigned through SetRecongnition
    self.isshow = nil;          -- nil = visibility not applied yet (see SetShow)
    self.needReset = false;     -- raised by _OnUpdate when face data is missing
end

--- Accessor for the recognition reference stored by SetRecongnition.
-- @return the stored value, or nil when none has been set
function Facemesh:GetRecongnition()
    local rec = self.recognition;
    return rec;
end

--- Stores the recognition reference later queried by GetFaceData.
-- Logs a marker line every time it is invoked.
-- @param rec value to store (may be nil)
function Facemesh:SetRecongnition(rec)
    LOG("[Bilinear]: SET REC");
    self.recognition = rec
end

--- Accessor for the camera reference stored by SetCamera.
-- @return the stored value, or nil when none has been set
function Facemesh:GetCamera()
    local cam = self.camera;
    return cam;
end

--- Stores the camera reference used for resolution, projection and position.
-- Logs a marker line every time it is invoked.
-- @param camera value to store (may be nil)
function Facemesh:SetCamera(camera)
    LOG("[Bilinear]: SET CAMERA");
    self.camera = camera
end

--- Returns the node's transform component, looking it up on first use.
-- The lookup result is cached in self.transform; while it is still nil every
-- call repeats the lookup and logs that the component is missing.
-- @return the transform component, or nil when the node has none
function Facemesh:GetTransform()
    if self.transform ~= nil then
        return self.transform;
    end

    self.transform = self.Node:GetComponent(apolloengine.Node.CT_TRANSFORM);
    if self.transform == nil then
        LOG("[Bilinear]: WE HAVE NO TRANSFORM COMPOENENT!")
    end
    return self.transform;
end

--- Returns the node's render component, looking it up on first use.
-- The lookup result is cached in self.render; while it is still nil every
-- call repeats the lookup and logs that the component is missing.
-- @return the render component, or nil when the node has none
function Facemesh:GetRender()
    if self.render ~= nil then
        return self.render;
    end

    self.render = self.Node:GetComponent(apolloengine.Node.CT_RENDER);
    if self.render == nil then
        LOG("[Bilinear]: WE HAVE NO RENDER COMPOENENT!")
    end
    return self.render;
end

--- Reads the CameraResolution field of the configured camera.
-- @return the camera's CameraResolution value, or nil when no camera is set
function Facemesh:GetResolution()
    local cam = self:GetCamera();
    if cam == nil then
        return nil
    end
    return cam.CameraResolution
end

--- Shows or hides the mesh by toggling RP_SHOW on the render component.
-- The last applied state is cached in self.isshow so repeated calls with the
-- same value do not touch the render component again.
-- @param isshow truthy to show the mesh, falsy to hide it
function Facemesh:SetShow(isshow)
  if self.isshow == isshow then
    return
  end

  local render = self:GetRender()
  if render == nil then
    -- No render component (GetRender already logged it). Leave the cached
    -- state untouched so the request is retried on a later call instead of
    -- being silently dropped.
    return
  end

  self.isshow = isshow;
  if isshow then
    render:SetRenderProperty(apolloengine.RenderComponent.RP_SHOW);
  else
    render:EraseRenderProperty(apolloengine.RenderComponent.RP_SHOW);
  end
end

-- Intentionally empty. NOTE(review): presumably a callback invoked by the
-- VenusBehavior base class — confirm against the engine before removing.
function Facemesh:_Play()

end

-- Intentionally empty. NOTE(review): presumably a callback invoked by the
-- VenusBehavior base class — confirm against the engine before removing.
function Facemesh:_OnAwake()

end

-- Intentionally empty; initialisation is instead attempted from _OnUpdate via Init().
-- NOTE(review): presumably a callback invoked by the VenusBehavior base class — confirm.
function Facemesh:_OnStart()
  
end

--- Creates and initialises the bilinear tracker.
-- Requires a camera whose resolution is available. Called from _OnUpdate on
-- every frame until it succeeds, so the "not ready yet" paths return quietly.
-- @return true when the tracker was initialised (self.isInited is then true),
--         false when the camera/resolution is missing or tracker Init failed
function Facemesh:Init()
    if self:GetCamera() == nil then
        return false
    end

    local resolution = self:GetResolution();
    if resolution == nil then
        -- The camera exists but exposes no resolution; calling resolution:x()
        -- would raise, so report "not initialised" and let the caller retry.
        return false
    end

    self.bilinearTracking = libmasquerade.BilinearTracking();
    --self.bilinearTracking:EnableLog();
    local ok = self.bilinearTracking:Init(
      resolution:x(), resolution:y(), config.use_persp,
      config.config,
      config.bin,
      config.global)
    if not ok then
        self.isInited = false
        ERROR("[Bilinear]: fail to init bilinear tracking.")
        return false
    end

    --self.bilinearTracking:SetSmoothing(false)
    --self.bilinearTracking:SetPosit(false)
    self.bilinearTracking:EnableDynamicUser(true, false)
    LOG("[Bilinear]: INIT DONE");
    self.isInited = true
    return true
end

--- Gathers the per-frame inputs that _OnUpdate forwards to BilinearTracking:Track.
--
-- Reads the first face of the recognition result and builds a table with:
--   useAdv     : true when advanced landmarks were copied into `landmarks`
--   landmarks  : flat array of 240 (x, y) pairs; pairs 1..106 come from the
--                face result, pairs 107..240 from the advanced-landmark result
--                or -1 when that result is unavailable
--   visibility : 106 values taken from the odd slots of face entries 112..217
--   euler      : vector3 built from the odd slots of face entries 109..111
--                (NOTE(review): presumably pitch/yaw/roll — confirm)
--   tongue     : first value of the tongue-detection result, 0 when absent
--   leftEye / rightEye : vector2 taken from the iris result (entries 20 and 40),
--                falling back to face landmarks 105 / 106 when the iris value
--                is missing or exactly (0, 0)
--
-- @return the table described above, or nil when there is no recognition
--         component, no result, or no face in the result
function Facemesh:GetFaceData()
    local recognition = self:GetRecongnition();
    if recognition == nil then
        LOG("[Bilinear]: NO REC");
        return nil;
    end

    local results = recognition:GetResult();
    if results == nil or results[cv.RecognitionComponent.cvFace] == nil then
        LOG("[Bilinear]: NO RESULT");
        return nil;
    end

    local facerets = results[cv.RecognitionComponent.cvFace][1];
    if facerets == nil then
        return nil;
    end

    local faceData = {}

    -- Basic landmarks: pairs 1..106, stored interleaved as x, y.
    local lmk = {}
    for i = 1, 106 do
      lmk[i * 2 - 1] = facerets[i * 2 - 1]
      lmk[i * 2] = facerets[i * 2]
    end

    -- Advanced landmarks: pairs 107..240. The first x coordinate doubles as a
    -- validity marker; it must exist and be non-negative. Checking for nil
    -- first avoids the "attempt to compare nil with number" error that a
    -- short or empty advanced result would otherwise raise.
    local advFacerets = results[cv.RecognitionComponent.cvAdvancedLandmark];
    local adv = nil
    if advFacerets ~= nil then
      adv = advFacerets[1]
    end
    local advMarker = nil
    if adv ~= nil then
      advMarker = adv[107 * 2 - 1]
    end
    local useAdv = advMarker ~= nil and advMarker >= 0

    if useAdv then
      for i = 107, 240 do
        lmk[i * 2 - 1] = adv[i * 2 - 1]
        lmk[i * 2] = adv[i * 2]
      end
    else
      -- Mark the advanced range as unavailable.
      for i = 107, 240 do
        lmk[i * 2 - 1] = -1
        lmk[i * 2] = -1
      end
    end

    faceData.useAdv = useAdv
    faceData.landmarks = lmk

    -- One value per basic landmark, read from the x slot of entries 112..217.
    local visibility = {}
    for i = 112, 217 do
      visibility[i - 111] = facerets[i * 2 - 1]
    end
    faceData.visibility = visibility

    faceData.euler = mathfunction.vector3(
      facerets[109 * 2 - 1],
      facerets[110 * 2 - 1],
      facerets[111 * 2 - 1])

    local tongue = results[cv.RecognitionComponent.cvTongueDetection]
    faceData.tongue = 0
    if tongue ~= nil and tongue[1] ~= nil then
      faceData.tongue = tongue[1][1]
    end

    local iris = results[cv.RecognitionComponent.cvIrisDetection]
    local leftEye = mathfunction.vector2(0, 0)
    local rightEye = mathfunction.vector2(0, 0)
    if iris ~= nil and iris[1] ~= nil then
      local irisPoints = iris[1]
      leftEye = mathfunction.vector2(irisPoints[20 * 2 - 1], irisPoints[20 * 2])
      rightEye = mathfunction.vector2(irisPoints[40 * 2 - 1], irisPoints[40 * 2])
    end
    -- An eye position of exactly (0, 0) is treated as "not detected".
    if leftEye:x() == 0 and leftEye:y() == 0 then
      leftEye = mathfunction.vector2(facerets[105 * 2 - 1], facerets[105 * 2])
    end
    if rightEye:x() == 0 and rightEye:y() == 0 then
      rightEye = mathfunction.vector2(facerets[106 * 2 - 1], facerets[106 * 2])
    end
    faceData.leftEye = leftEye
    faceData.rightEye = rightEye
    return faceData
end

--- Places the node using the tracker's current pose.
-- Sets the local position to the tracker translation plus the camera position
-- and the local rotation to the tracker rotation converted to a quaternion.
-- @return false when the tracker has no translation/rotation or no camera is
--         set; true otherwise (also when the node has no transform component,
--         in which case nothing is applied)
function Facemesh:Positioning()
  local trans = self.bilinearTracking:GetTranslation()
  local rot = self.bilinearTracking:GetRotation()
  if trans == nil or rot == nil then
    return false
  end

  local camera = self:GetCamera()
  if camera == nil then
    -- Without a camera the offset cannot be computed; report failure rather
    -- than raising on camera:GetPosition().
    return false
  end

  local rotation = rot:ToQuaternion()
  local cameraPosition = camera:GetPosition()
  local transform = self:GetTransform();
  if transform then
    transform:SetLocalPosition(trans + cameraPosition);
    transform:SetLocalRotation(rotation);
  end
  return true
end

--- Uploads the static mesh data (triangle indices and texture coordinates).
-- Sets self.renderInited to true on success.
-- @param tracking_data table with `indices` (flat index array) and `uvs`
--        (flat array of u, v pairs)
function Facemesh:InitMesh(tracking_data)
  local render = self:GetRender()
  if render == nil then
    -- GetRender already logged the missing component; nothing to upload to.
    return
  end

  local indices = tracking_data.indices
  local indicesNum = #indices;

  local idxstream = apolloengine.IndicesStream();
  idxstream:SetIndicesType(apolloengine.IndicesBufferEntity.IT_UINT16);
  idxstream:ReserveBuffer(indicesNum);
  for i = 1, indicesNum do
    idxstream:PushIndicesData(indices[i])
  end
  render:ChangeIndexBuffer(idxstream)

  local vtxstream = render:GetVertexStream()
  if vtxstream == nil then
    ERROR("fail to get vtxstream")
    return
  end

  local uvs = tracking_data.uvs
  local vtxnum = #uvs / 2
  local offset = vtxstream:GetAttributeIndex(apolloengine.ShaderEntity.ATTRIBUTE_COORDNATE0);

  -- One scratch vector is reused for every vertex; both coordinates are
  -- written as (1 - value).
  local newpoint = mathfunction.vector2(0, 0)
  for i = 1, vtxnum do
    newpoint:Set(1 - uvs[i * 2 - 1], 1 - uvs[i * 2])
    vtxstream:ChangeVertexDataWithAttributeFast(offset, i, newpoint);
  end

  vtxstream:SetReflushInterval(1, vtxnum);
  render:ChangeVertexBuffer(vtxstream);

  self.renderInited = true
end

--- Writes the tracked vertex positions into the render component's vertex stream.
-- Each position is multiplied by the tracker's scale and written with w = 1.
-- @param tracking_data table whose `vertices` field is a flat array of
--        x, y, z triples
function Facemesh:UpdateMesh(tracking_data)
  local render = self:GetRender()
  if render == nil then
    -- GetRender already logged the missing component; nothing to update.
    return
  end

  local vertices = tracking_data.vertices
  if vertices == nil then
    ERROR("[Bilinear]: no tracking vertices.")
    return
  end

  local vtxstream = render:GetVertexStream()
  if vtxstream == nil then
    ERROR("fail to get vtxstream")
    return
  end

  local offset = vtxstream:GetAttributeIndex(apolloengine.ShaderEntity.ATTRIBUTE_POSITION);
  local vtxnum = #vertices / 3
  local scale = self.bilinearTracking:GetScale()

  -- One scratch vector is reused for every vertex.
  local newpoint = mathfunction.vector4(0, 0, 0, 1)
  for i = 1, vtxnum do
    local last = i * 3
    newpoint:Set(vertices[last - 2] * scale, vertices[last - 1] * scale, vertices[last] * scale, 1)
    vtxstream:ChangeVertexDataWithAttributeFast(offset, i, newpoint);
  end

  vtxstream:SetReflushInterval(1, vtxnum);
  render:ChangeVertexBuffer(vtxstream);
end

--- Per-update step: track the face and refresh the mesh.
-- Order of work: lazy Init -> collect face data -> optional tracker reset ->
-- push projection matrix -> Track -> position the node -> update vertices ->
-- show the mesh. Any failure hides the mesh and returns early.
-- @param timespan unused by this implementation
function Facemesh:_OnUpdate(timespan)
    if self.isInited ~= true then
      if not self:Init() then
        return
      end
    end

    local facedata = self:GetFaceData()
    if facedata == nil then
      ERROR("[Bilinear]: fail to get face data.");
      -- The face was lost; the tracker is reset once data is available again.
      self.needReset = true
      self:SetShow(false)
      return
    end

    local camera = self:GetCamera()
    if camera == nil then
      self:SetShow(false)
      return
    end

    if self.needReset then
      self.bilinearTracking:Reset()
      self.needReset = false
    end

    local proj = camera:GetProject()
    local resolution = self:GetResolution()
    -- A missing resolution is handled like a missing projection matrix
    -- instead of raising on resolution:x().
    if proj == nil or resolution == nil
       or self.bilinearTracking:SetProjMat(resolution:x(), resolution:y(), proj) ~= true then
      LOG("[Bilinear]: Use default projection matrix of Bilinear Model.")
    end

    if not self.bilinearTracking:Track(facedata.useAdv, facedata.euler, facedata.landmarks, facedata.visibility, facedata.leftEye, facedata.rightEye, facedata.tongue) then
      self:SetShow(false)
      ERROR("[Bilinear]: fail to track.");
      return
    end

    if not self:Positioning() then
      self:SetShow(false)
      ERROR("[Bilinear]: fail to set mesh location.");
      return
    end

    local trackingresult = {}
    trackingresult.vertices = self.bilinearTracking:GetTrackingVertices()
    -- Disabled: normals and one-time index/uv upload through InitMesh.
    --trackingresult.normals = self.bilinearTracking:GetNormals()
    --if self.renderInited ~= true then
      --trackingresult.indices = self.bilinearTracking:GetTriangles()
      --trackingresult.uvs = self.bilinearTracking:GetUvs()
      --self:InitMesh(trackingresult)
    --end
    self:UpdateMesh(trackingresult)
    self:SetShow(true)
    --LOG("Update mesh")
end

-- Registers the "camera" member as a reference to an apolloengine.CameraComponent,
-- read and written through GetCamera / SetCamera.
Facemesh:MemberRegister("camera",
  venuscore.ScriptTypes.ReferenceType(
    apolloengine.CameraComponent:RTTI(),    
    Facemesh.GetCamera,
    Facemesh.SetCamera
)); 

-- Registers the "recognition" member as a reference to a cv.RecognitionComponent,
-- read and written through GetRecongnition / SetRecongnition.
Facemesh:MemberRegister("recognition",
    venuscore.ScriptTypes.ReferenceType(
        cv.RecognitionComponent:RTTI(),    
        Facemesh.GetRecongnition,
        Facemesh.SetRecongnition
)); 

-- The module's value is the behavior class itself.
return Facemesh;