--@author      : Jin Shiyu
--@date        : 2020-11-02
--@description : lua script for face beauty
--@version     : 1.0

local apolloengine  = require "apolloengine"
local apollonode    = require "apolloutility.apollonode"
local renderqueue   = require "apolloutility.renderqueue"
local mathfunction  = require "mathfunction"
local venuscore     = require "venuscore"
local defined       = require "apolloutility.defiend"
local quad_render   = require "facebeautyutils.face_beauty_quad_render"
local beauty_render = require "facebeautyutils.face_beauty_render"
local face_defined  = require "facebeautyutils.defined"
local cv            = require "computervisionfunction"

--[[
    Applies face beautification to a single frame: eye brightening, nasolabial
    fold removal, eye-bag removal and teeth whitening.
    NOTE:
    It can be used from the testshell main script or from an editor behavior;
    either host must provide one main camera.
]]

-- Module table: methods are attached below and the table is returned at the end of the file.
local FaceBeauty = {}

-- 1-based slots of self.beauty_strengths, one per beauty effect.
local BRIGHT_EYE, NASOLABIAL, EYEBAG, WHITEN_TEETH = 1, 2, 3, 4

-- Classify ActionID that PreUpdate treats as "mouth open" (turns teeth whitening on for that face).
local MOUTH_OPEN_ACTION = 256

--[[
    Construct the effect: set the resource paths and default parameters, then
    build all render/CV resources through Initialize().
    @param layer_sequence  render sequence of the face renders; the downsample
                           and blur passes are placed just below
                           layer_sequence - max_face_num (see Initialize)
]]
function FaceBeauty:Create(layer_sequence)
    self.mask_texture_path         = "docs:facebeauty/resource/standard_face_v4.png"
    self.teeth_lookup_texture_path = "docs:facebeauty/resource/whiten_teeth_lut.png"
    self.eye_mask_texture_path     = "docs:facebeauty/resource/eye_face_mask.png"

    --[[
        Per-effect strengths in [0, 1], indexed by the effect constants:
        1: brighten eye, 2: nasolabial removal, 3: eyebag removal, 4: whiten teeth
    ]]
    self.beauty_strengths = { 0.0, 0.0, 0.0, 0.0 };

    self.scaled_factor        = 0.5;            -- downsample factor applied to the input size
    self.offset_factor        = 2.8958829;      -- box-blur sampling step, in texels
    self.render_vertex_num    = face_defined.NUM_FACE_LAMDMARK;
    self.render_triangle_num  = face_defined.NUM_DELAUNAY_POINTS;
    self.max_face_num         = 3;

    self.new_frame_call_back = nil;
    self.face_renders        = {};
    -- Per-face teeth visibility (0.0 or 1.0), written by Reset/PreUpdate.
    -- Uses the same key as the rest of the module (was misspelled
    -- `teeth_show_ratio`, a field nothing read).
    self.teeth_show_ratios   = {};
    self.virtual_node        = nil;

    self.is_show     = true;
    self.initialized = false;
    self.need_update = false;
    self:Initialize(layer_sequence);
end

--[[
    Resize the downsample and blur passes for an input frame of the given size.
    All three passes are always resized. Only when every one of them reports an
    update are their inputs and texel offsets re-bound and the face renders
    re-pointed at the new result images.
    @param frame_width   input frame width
    @param frame_height  input frame height
]]
function FaceBeauty:UpdateTextureSize(frame_width, frame_height)
    local target_w = frame_width * self.scaled_factor;
    local target_h = frame_height * self.scaled_factor;

    local scaler = self.image_scaler;
    local v_blur = self.vertical_box_blur;
    local h_blur = self.horizontal_box_blur;

    local scaler_ok = scaler:Update(target_w, target_h);
    local v_blur_ok = v_blur:Update(target_w, target_h);
    local h_blur_ok = h_blur:Update(target_w, target_h);
    if not (scaler_ok and v_blur_ok and h_blur_ok) then
        return;
    end

    -- Vertical pass: reads the downsampled image, steps along y only.
    v_blur:SetParameter("UNIFORM_INPUT_TEXTURE", scaler:GetResultImage());
    v_blur:SetParameter("UNIFORM_TEXEL_WIDTH_OFFSET", mathfunction.vector1(0));
    v_blur:SetParameter("UNIFORM_TEXEL_HEIGHT_OFFSET", mathfunction.vector1(self.offset_factor / target_h));

    -- Horizontal pass: reads the vertical result, steps along x only.
    h_blur:SetParameter("UNIFORM_INPUT_TEXTURE", v_blur:GetResultImage());
    h_blur:SetParameter("UNIFORM_TEXEL_WIDTH_OFFSET", mathfunction.vector1(self.offset_factor / target_w));
    h_blur:SetParameter("UNIFORM_TEXEL_HEIGHT_OFFSET", mathfunction.vector1(0));

    for face_index = 1, self.max_face_num do
        local face_render = self.face_renders[face_index];
        face_render:SetParameter("UNIFORM_BLUR_TEXTURE",  scaler:GetResultImage());
        face_render:SetParameter("UNIFORM_BLUR_TEXTURE2", h_blur:GetResultImage());
    end
end

--[[
    Synchronously load a static texture from a virtual path such as "docs:...".
    Both wrap modes are TW_REPEAT and both filter modes are TF_LINEAR.
    @param texture_path  path resolved through venuscore.IFileSystem:PathAssembly
    @return the created apolloengine.TextureEntity
]]
function FaceBeauty:LoadTexture(texture_path)
    local Texture = apolloengine.TextureEntity;
    local entity  = Texture();

    local metadata = apolloengine.TextureFileMetadata(
        Texture.TU_STATIC,
        Texture.PF_AUTO,
        1, false,
        Texture.TW_REPEAT, Texture.TW_REPEAT,
        Texture.TF_LINEAR, Texture.TF_LINEAR,
        venuscore.IFileSystem:PathAssembly(texture_path));
    entity:PushMetadata(metadata);

    -- Block until the resource exists so callers can bind it immediately.
    entity:SetJobType(venuscore.IJob.JT_SYNCHRONOUS);
    entity:CreateResource();
    return entity;
end

--[[
    Build every resource the effect needs (called from Create()):
      1. load the mask, teeth-whitening lookup and eye mask textures;
      2. create the downsample, vertical-blur and horizontal-blur quad renders
         at sequences (layer_sequence - max_face_num) - 3 .. - 1 and chain
         downsample -> vertical blur -> horizontal blur;
      3. create the vertex/index streams and fill their static data once;
      4. create max_face_num face renders drawing at layer_sequence;
      5. create a virtual node with classify (face expression) and
         recognition (face landmarks) components reading the device capture
         texture.
    Sets self.initialized = true when done.
    @param layer_sequence  render sequence used by the face renders
]]
function FaceBeauty:Initialize(layer_sequence)

    -- load texture
    self.mask_texture         = self:LoadTexture(self.mask_texture_path);
    self.teeth_lookup_texture = self:LoadTexture(self.teeth_lookup_texture_path);
    self.eye_mask_texture     = self:LoadTexture(self.eye_mask_texture_path);

    -- setup quad render
    -- Placeholder target size; UpdateTextureSize() resizes the passes to the
    -- scaled input size during Update().
    local initial_width  = 360;
    local initial_height = 640;
    -- The three preprocessing passes take the sequences just below this value
    -- (same layout as UpdateSequence()).
    local max_sequence   = layer_sequence - self.max_face_num;

    self.image_scaler        = quad_render("face_beauty_identity", max_sequence - 3, "identity_layer", "docs:facebeauty/material/identity.material", initial_width, initial_height);
    self.vertical_box_blur   = quad_render("face_beauty_vertical", max_sequence - 2, "vertical_blur_layer", "docs:facebeauty/material/box_blur.material", initial_width, initial_height);
    self.horizontal_box_blur = quad_render("face_beauty_horizontal", max_sequence - 1, "horizontal_blur_layer", "docs:facebeauty/material/box_blur.material", initial_width, initial_height);

    -- Vertical pass: reads the downsampled image, steps along y only.
    self.vertical_box_blur:SetParameter("UNIFORM_INPUT_TEXTURE", self.image_scaler:GetResultImage());
    self.vertical_box_blur:SetParameter("UNIFORM_TEXEL_WIDTH_OFFSET", mathfunction.vector1(0));
    self.vertical_box_blur:SetParameter("UNIFORM_TEXEL_HEIGHT_OFFSET", mathfunction.vector1(self.offset_factor / initial_height));

    -- Horizontal pass: reads the vertical result, steps along x only.
    self.horizontal_box_blur:SetParameter("UNIFORM_INPUT_TEXTURE", self.vertical_box_blur:GetResultImage());
    self.horizontal_box_blur:SetParameter("UNIFORM_TEXEL_WIDTH_OFFSET", mathfunction.vector1(self.offset_factor / initial_width));
    self.horizontal_box_blur:SetParameter("UNIFORM_TEXEL_HEIGHT_OFFSET", mathfunction.vector1(0));

    -- setup face render
    -- Vertex layout: 3-float position (PreUpdate fills x, y and a visibility
    -- value as z) followed by a 2-float texture coordinate.
    self.vetex_stream = apolloengine.VertexStream();
    self.vetex_stream:SetVertexType(apolloengine.ShaderEntity.ATTRIBUTE_POSITION,
                                    apolloengine.VertexBufferEntity.DT_FLOAT,
                                    apolloengine.VertexBufferEntity.DT_FLOAT,
                                    3);
    self.vetex_stream:SetVertexType(apolloengine.ShaderEntity.ATTRIBUTE_COORDNATE0,
                                    apolloengine.VertexBufferEntity.DT_FLOAT,
                                    apolloengine.VertexBufferEntity.DT_FLOAT,
                                    2);
    self.vetex_stream:ReserveBuffer(self.render_vertex_num);

    -- 16-bit indices, three per triangle.
    self.index_stream = apolloengine.IndicesStream();
    self.index_stream:SetIndicesType(apolloengine.IndicesBufferEntity.IT_UINT16);
    self.index_stream:ReserveBuffer(self.render_triangle_num * 3);

    -- Push the static indices and texture coordinates; positions start at zero.
    self:UpdateRenderDataOnce();

    self.face_renders      = {};
    self.teeth_show_ratios = {};
    for i = 1, self.max_face_num do
        -- NOTE(review): every face render is built on the same vertex/index
        -- streams; confirm beauty_render copies the data if faces must differ.
        local face_render = beauty_render("docs:facebeauty/material/face_beauty.material",
                                          self.vetex_stream,
                                          self.index_stream,
                                          layer_sequence);

        face_render:SetParameter("UNIFORM_MASK_TEXTURE", self.mask_texture);
        face_render:SetParameter("UNIFORM_TEETH_LOOKUP_TEXTURE", self.teeth_lookup_texture);
        face_render:SetParameter("UNIFORM_EYE_MASK_TEXTURE", self.eye_mask_texture);

        -- Downsampled image and fully blurred image; UpdateTextureSize()
        -- re-binds both when the pass sizes change.
        face_render:SetParameter("UNIFORM_BLUR_TEXTURE",  self.image_scaler:GetResultImage());
        face_render:SetParameter("UNIFORM_BLUR_TEXTURE2", self.horizontal_box_blur:GetResultImage());
        -- Initial strengths; UpdateFaceRender() refreshes them every frame.
        face_render:SetParameter("UNIFORM_BRIGHT_EYE_STRENGTH", mathfunction.vector1(self.beauty_strengths[BRIGHT_EYE]));
        face_render:SetParameter("UNIFORM_NAOLABIAL_STRENGTH", mathfunction.vector1(self.beauty_strengths[NASOLABIAL]));
        face_render:SetParameter("UNIFORM_EYEBAG_STRENGTH", mathfunction.vector1(self.beauty_strengths[EYEBAG]));
        face_render:SetParameter("UNIFORM_WHITEN_TEETH_STRENGTH", mathfunction.vector1(self.beauty_strengths[WHITEN_TEETH]));

        face_render:SetDrawCount(self.render_triangle_num * 3);
        table.insert(self.face_renders, face_render);
        table.insert(self.teeth_show_ratios, 0.0);
    end

    -- Texture referencing the device capture; both CV components read it.
    self.CaptureInput = apolloengine.TextureEntity();
    self.CaptureInput:PushMetadata(apolloengine.TextureReferenceMetadata(apolloengine.DeviceResource.DEVICE_CAPTURE));
    self.CaptureInput:CreateResource();

    -- Classify provides face-expression actions; PreUpdate uses them to spot
    -- an open mouth. Reset() re-evaluates Enable every frame.
    self.virtual_node       = apollonode.VirtualNode();
    local classify          = self.virtual_node:CreateComponent(apolloengine.Node.CT_CV_CLASSIFY);
    classify.Enable         = true;
    classify.Mode           = cv.IVisionComponent.VIDEO;
    classify.FaceExpression = true;
    classify:SetTexture(self.CaptureInput);
    self.classify = classify;

    -- Recognition provides the face landmarks read in PreUpdate.
    -- NOTE(review): ShareCV presumably lets it share detection with other CV
    -- components — confirm.
    local recognition   = self.virtual_node:CreateComponent(apolloengine.Node.CT_CV_RECOGNITION);
    recognition.Enable  = true;
    recognition.Mode    = cv.IVisionComponent.VIDEO;
    recognition.Type    = cv.RecognitionComponent.cvFace;
    recognition.ShareCV = true;
    recognition:SetTexture(self.CaptureInput);
    self.recognition   = recognition;

    self.initialized = true;
end

--- Enable or disable the whole effect; PreUpdate returns false early while this is `false`.
-- @param is_show boolean visibility flag
function FaceBeauty:SetShow(is_show)
    self["is_show"] = is_show
end

--[[
    Report whether any effect is strong enough to be worth rendering.
    The threshold is 0.01 rather than 0 so that a minimum strength which never
    reaches exactly zero still counts as "off".
    @return true if at least one of the four strengths is >= 0.01
]]
function FaceBeauty:StrengthValid()
    local strengths = self.beauty_strengths;
    local any_valid = false;
    for slot = BRIGHT_EYE, WHITEN_TEETH do
        local strong_enough = strengths[slot] >= 0.01;
        any_valid = any_valid or strong_enough;
    end
    return any_valid;
end

--[[
    Per-frame preparation; call this before Update(). It
    1. resets per-frame state and hides every render (via Reset),
    2. returns false early when the effect is hidden, no strength is at least
       0.01, or the CV components have no usable result,
    3. otherwise uploads the landmarks of up to max_face_num faces to the
       vertex stream and updates the matching face render.
    @param def  forwarded to the virtual node's Update() outside the editor
    @return true if at least one face render was updated (self.need_update)
]]
function FaceBeauty:PreUpdate(def)
    if not self.initialized then
        ERROR("[FaceBeauty] not initialized");
        return false;
    end

    -- reset state and disable all render first
    self:Reset();

    -- check if enabled
    if self.is_show == false or self:StrengthValid() == false then
        --LOG("[FaceBeauty] don't show");
        return false;
    end

    -- Outside the Kratos editor the CV node is stepped manually here;
    -- presumably the editor updates it itself — confirm.
    if _KRATOSEDITOR then
    else
        self.virtual_node.node:Update(def);
    end

    local facemesh_type = cv.RecognitionComponent.cvFace;
    local results       = self.recognition:GetResultByTypeFast(facemesh_type);
    if results == nil or results[facemesh_type] == nil then
        --ERROR("[FaceBeauty] recognition NO RESULT");
        return false;
    end

    -- Classify is enabled by Reset() only while teeth whitening is on. A
    -- missing result aborts the whole frame, not just whitening.
    -- NOTE(review): consider treating a missing result as "mouth closed"
    -- instead of dropping every effect — confirm intent.
    local classify_results = nil;
    if self.classify.Enable == true then
        classify_results = self.classify:GetResult();

        if classify_results == nil then
            ERROR("[FaceBeauty] classify is nil");
            return false;
        end
    end

    local detect_size = self.recognition:GetDetectSize();
    if detect_size == nil then
        ERROR("[FaceBeauty]: detect_size is nil");
        return false;
    end

    -- x * 2/w - 1 and 1 - y * 2/h map coordinates in [0, size] to [-1, 1]
    -- with y flipped (assumes landmarks are in detect-size pixels — confirm).
    --local sizeInv      = {1.0 / detect_size:x(), 1.0 / detect_size:y()};
    local sizeInv             = {2.0 / detect_size:x(), 2.0 / detect_size:y()};
    local normalize_keypoints = {}
    local face_num            = math.min(self.max_face_num, #results[facemesh_type]);

    -- update teeth show ratio
    -- A face whose actions include MOUTH_OPEN_ACTION gets ratio 1.0 (Reset()
    -- set every ratio to 0.0). Assumes classify results are ordered like the
    -- recognition results — TODO confirm.
    if classify_results ~= nil then
        local classify_num = math.min(face_num, #classify_results);
        for i = 1, classify_num do
            local face_actions = classify_results[i];
            if face_actions ~= nil then
                for j = 1, #face_actions do
                    if face_actions[j].ActionID == MOUTH_OPEN_ACTION then
                        self.teeth_show_ratios[i] = 1.0;
                    end
                end
            end
        end
        --ERROR("[FaceBeauty] classify is not nil");
    end

    for i = 1, face_num do
        local keypoints = results[facemesh_type][i];
        if keypoints ~= nil then
            -- The last NUM_FACE_LAMDMARK entries are read as per-landmark
            -- visibility (their .mx); layout inferred from the indexing — confirm.
            local keypoints_count  = keypoints:Size();
            local visibility_start = keypoints_count - face_defined.NUM_FACE_LAMDMARK;
            for j = 1, face_defined.NUM_FACE_LAMDMARK do
                local point      = keypoints:Get(j);
                local visibility = keypoints:Get(j + visibility_start);
                normalize_keypoints[j] = {point.mx * sizeInv[1] - 1.0, 1.0 - point.my * sizeInv[2], visibility.mx};
            end
            -- NOTE(review): all face renders share self.vetex_stream; this is
            -- only correct per face if face_render:Update copies the data.
            self:UpdateVertex(normalize_keypoints);
            self.face_renders[i]:Update(self.vetex_stream);
            self.need_update = true;
        end
    end

    return self.need_update;
end

--[[
    Assign render sequences. Every face render uses `sequence`. The
    downsample, vertical-blur and horizontal-blur passes take the three slots
    immediately below `sequence - max_face_num`, presumably so they render
    before the face passes.
    @param sequence  render sequence of the face renders
]]
function FaceBeauty:UpdateSequence(sequence)
    local pass_base = sequence - self.max_face_num;

    -- Preprocessing passes, in pipeline order.
    self.image_scaler:SetSequence(pass_base - 3);
    self.vertical_box_blur:SetSequence(pass_base - 2);
    self.horizontal_box_blur:SetSequence(pass_base - 1);

    for face_index = 1, self.max_face_num do
        local face_render = self.face_renders[face_index];
        face_render:SetSequence(sequence);
    end
end

--[[
    Per-frame render update; only does work when PreUpdate() flagged it.
    Updates the render sequences and per-face uniforms, then resizes and feeds
    the preprocessing passes.
    @param sequence   render sequence of the face renders (see UpdateSequence)
    @param input_tex  frame texture to beautify
]]
function FaceBeauty:Update(sequence, input_tex)
    if self.need_update then
        self:UpdateSequence(sequence);

        for face_index = 1, self.max_face_num do
            local ratio = self.teeth_show_ratios[face_index];
            self:UpdateFaceRender(face_index, ratio, input_tex);
        end

        self:UpdateQuadRender(input_tex);
        return;
    end

    LOG("[FaceBeauty] no need to update");
end

--[[
    Return to the "nothing to draw" state at the start of a frame:
      - clear need_update;
      - hide the three preprocessing passes and every face render;
      - zero each face's teeth ratio;
      - keep the classify component enabled only while teeth whitening has a
        positive strength (PreUpdate reads classify results only to detect an
        open mouth).
]]
function FaceBeauty:Reset()
    self.need_update = false;

    self.image_scaler:SetShow(false);
    self.vertical_box_blur:SetShow(false);
    self.horizontal_box_blur:SetShow(false);

    for face_index = 1, self.max_face_num do
        self.face_renders[face_index]:SetShow(false);
        self.teeth_show_ratios[face_index] = 0.0;
    end

    self.classify.Enable = self.beauty_strengths[WHITEN_TEETH] > 0;
end

--[[
    Push this frame's uniforms to one face render; hidden renders are skipped.
    @param render_id         index into self.face_renders
    @param mouth_show_ratio  multiplier applied to the teeth-whitening strength
    @param input_tex         frame texture to beautify
]]
function FaceBeauty:UpdateFaceRender(render_id, mouth_show_ratio, input_tex)
    local render    = self.face_renders[render_id];
    local strengths = self.beauty_strengths;

    if render.is_show then
        render:SetParameter("UNIFORM_INPUT_TEXTURE", input_tex);
        render:SetParameter("UNIFORM_BRIGHT_EYE_STRENGTH", mathfunction.vector1(strengths[BRIGHT_EYE]));
        render:SetParameter("UNIFORM_NAOLABIAL_STRENGTH", mathfunction.vector1(strengths[NASOLABIAL]));
        render:SetParameter("UNIFORM_EYEBAG_STRENGTH", mathfunction.vector1(strengths[EYEBAG]));

        -- Whitening is scaled to zero for faces whose mouth is not open.
        local teeth_strength = strengths[WHITEN_TEETH] * mouth_show_ratio;
        render:SetParameter("UNIFORM_WHITEN_TEETH_STRENGTH", mathfunction.vector1(teeth_strength));
    end
end

--- Resize the preprocessing passes to the input's size, then feed the input to the downsample pass.
-- @param input_tex frame texture to beautify
function FaceBeauty:UpdateQuadRender(input_tex)
    local size          = input_tex:GetSize();
    local width, height = size:x(), size:y();
    self:UpdateTextureSize(width, height);

    self.image_scaler:SetParameter("UNIFORM_INPUT_TEXTURE", input_tex);
end

--[[
    Fill the static mesh data once:
      - the triangle indices from face_defined.FACE_DELAUNAY_INDICE;
      - per vertex, a texture coordinate from the interleaved (u, v) list
        face_defined.STANDARD_FACE_MASK_TEX_COORD plus a zero position that
        UpdateVertex() overwrites later.
]]
function FaceBeauty:UpdateRenderDataOnce()
    local indices = face_defined.FACE_DELAUNAY_INDICE;
    for k = 1, self.render_triangle_num * 3 do
        self.index_stream:PushIndicesData(indices[k]);
    end

    local uv         = face_defined.STANDARD_FACE_MASK_TEX_COORD;
    local coord_attr = apolloengine.ShaderEntity.ATTRIBUTE_COORDNATE0;
    local pos_attr   = apolloengine.ShaderEntity.ATTRIBUTE_POSITION;
    for v = 1, self.render_vertex_num do
        local u_slot    = 2 * v - 1;
        local tex_coord = mathfunction.vector2(uv[u_slot], uv[u_slot + 1]);
        self.vetex_stream:PushVertexData(coord_attr, tex_coord);
        self.vetex_stream:PushVertexData(pos_attr, mathfunction.vector3(0, 0, 0));
    end
end

--[[
    Overwrite the position attribute of every vertex with the given keypoints
    and mark vertices 1..render_vertex_num for re-upload.
    @param keypoints  array of {x, y, z} triples, one per vertex
]]
function FaceBeauty:UpdateVertex(keypoints)
    local stream   = self.vetex_stream;
    -- One scratch vector is reused; presumably the stream copies its value — confirm.
    local scratch  = mathfunction.vector3(0, 0, 0);
    local position = stream:GetAttributeIndex(apolloengine.ShaderEntity.ATTRIBUTE_POSITION);

    for idx = 1, self.render_vertex_num do
        local kp = keypoints[idx];
        scratch:Set(kp[1], kp[2], kp[3]);
        stream:ChangeVertexDataWithAttributeFast(position, idx, scratch);
    end
    stream:SetReflushInterval(1, self.render_vertex_num);
end

--[[
    Set the strength of one beauty effect.
    @param beauty_type  integer effect id: 1 brighten eye, 2 nasolabial removal,
                        3 eyebag removal, 4 whiten teeth
    @param strength     number in [0, 1]
    @return true if stored. Returns false, leaving state unchanged, for an
            unknown or non-integer id, or for a strength that is non-numeric,
            NaN or out of range.
]]
function FaceBeauty:SetBeautyStrength(beauty_type, strength)
    -- Non-integer ids (e.g. 1.5) used to pass the range check and be stored
    -- under a key that nothing reads.
    if type(beauty_type) ~= "number" or beauty_type % 1 ~= 0
       or beauty_type < BRIGHT_EYE or beauty_type > WHITEN_TEETH then
        return false;
    end

    -- `strength ~= strength` rejects NaN, which fails both range comparisons
    -- and would otherwise reach the shaders.
    if type(strength) ~= "number" or strength ~= strength
       or strength < 0.0 or strength > 1.0 then
        return false;
    end

    self.beauty_strengths[beauty_type] = strength;
    return true;
end

--[[
    Get the strength of one beauty effect.
    @param beauty_type  integer effect id (1 brighten eye .. 4 whiten teeth)
    @return the stored strength, or 0.0 for an unknown, non-integer or
            non-numeric id (non-integers previously returned nil)
]]
function FaceBeauty:GetBeautyStrength(beauty_type)
    if type(beauty_type) ~= "number" or beauty_type % 1 ~= 0
       or beauty_type < BRIGHT_EYE or beauty_type > WHITEN_TEETH then
        return 0.0;
    end
    return self.beauty_strengths[beauty_type];
end

--- Virtual path of the mask texture (the source of self.mask_texture).
-- @return string path
function FaceBeauty:GetMaskTexturePath()
    local path = self.mask_texture_path
    return path
end

--[[
    Load a new mask texture and bind it on every existing face render.
    The face renders receive the texture through SetParameter at creation, so
    replacing self.mask_texture alone left them sampling the old texture.
    @param value  virtual path passed to LoadTexture
]]
function FaceBeauty:SetMaskTexturePath(value)
    self.mask_texture_path = value;
    self.mask_texture      = self:LoadTexture(self.mask_texture_path);

    -- face_renders may be absent before Create() or empty after Destroy().
    for _, face_render in ipairs(self.face_renders or {}) do
        face_render:SetParameter("UNIFORM_MASK_TEXTURE", self.mask_texture);
    end
end

--- Virtual path of the teeth-whitening lookup texture (the source of self.teeth_lookup_texture).
-- @return string path
function FaceBeauty:GetTeethLookupTexturePath()
    local path = self.teeth_lookup_texture_path
    return path
end

--[[
    Load a new teeth-whitening lookup texture and bind it on every existing
    face render. Without the re-bind the renders kept the texture they were
    created with.
    @param value  virtual path passed to LoadTexture
]]
function FaceBeauty:SetTeethLookupTexturePath(value)
    self.teeth_lookup_texture_path = value;
    self.teeth_lookup_texture      = self:LoadTexture(self.teeth_lookup_texture_path);

    -- face_renders may be absent before Create() or empty after Destroy().
    for _, face_render in ipairs(self.face_renders or {}) do
        face_render:SetParameter("UNIFORM_TEETH_LOOKUP_TEXTURE", self.teeth_lookup_texture);
    end
end

--- Virtual path of the eye mask texture (the source of self.eye_mask_texture).
-- @return string path
function FaceBeauty:GetEyeMaskTexturePath()
    local path = self.eye_mask_texture_path
    return path
end

--[[
    Load a new eye mask texture and bind it on every existing face render.
    The loaded texture used to be stored under the misspelled field
    `eye_mask_texture_texture`, so self.eye_mask_texture was never replaced
    and the renders were never updated.
    @param value  virtual path passed to LoadTexture
]]
function FaceBeauty:SetEyeMaskTexturePath(value)
    self.eye_mask_texture_path = value;
    self.eye_mask_texture      = self:LoadTexture(self.eye_mask_texture_path);

    -- face_renders may be absent before Create() or empty after Destroy().
    for _, face_render in ipairs(self.face_renders or {}) do
        face_render:SetParameter("UNIFORM_EYE_MASK_TEXTURE", self.eye_mask_texture);
    end
end

--[[
    Clear the render passes and face renders, destroy the CV virtual node and
    return to the uninitialized state. Safe to call more than once, and on a
    table that was never created.
    NOTE(review): the loaded textures and CaptureInput are not explicitly
    released here — confirm whether the engine frees them on its own.
]]
function FaceBeauty:Destroy()
    if self.image_scaler ~= nil then
        self.image_scaler:Clear();
        self.image_scaler = nil;
    end
    if self.vertical_box_blur ~= nil then
        self.vertical_box_blur:Clear();
        self.vertical_box_blur = nil;
    end
    if self.horizontal_box_blur ~= nil then
        self.horizontal_box_blur:Clear();
        self.horizontal_box_blur = nil;
    end

    local face_renders = self.face_renders or {};
    for i = 1, self.max_face_num or 0 do
        if face_renders[i] ~= nil then
            face_renders[i]:Clear();
        end
    end

    if self.virtual_node ~= nil then
        self.virtual_node:Destroy();
        self.virtual_node = nil;
    end
    -- Both components were created on virtual_node; drop the stale handles.
    self.classify    = nil;
    self.recognition = nil;

    self.face_renders      = {};
    self.teeth_show_ratios = {};
    -- A true flag left over from the last PreUpdate would let Update() run
    -- against the passes cleared above.
    self.need_update       = false;
    self.initialized       = false;
end

return FaceBeauty;