--@author      : Zhu Fei
--@date        : 2020-08-03
--@description : lua script for smart video enhancement, used in edit page
--@version     : 1.0

local venuscore    = require "venuscore"
local apolloengine = require "apolloengine"
local mathfunction = require "mathfunction"

-- Module table of the smart enhance effect. The FBO/material fields start as 0
-- placeholders; Initialize and _CreateFBOs assign the real engine objects on `self`.
local SmartEnhance = {
    Queue = 100, --sequence of smart enhance

    --FBOs: render targets of enhance substeps
    fbo_luma_histogram = 0,
    fbo_rgb_lut        = 0,

    --materials
    ro_histogram_gen_step_mat = 0,
    ro_lut_gen_step_mat       = 0,
    ro_lut_apply_step_mat     = 0,
}


----------------------------interfaces----------------------------
--- Configure the effect and build its three render passes
-- (luma histogram -> RGB LUT generation -> LUT application).
-- @param host  engine host: creates materials, render objects and render targets,
--              and registers the tunable script parameters
-- @param size  input frame size (vector2)
-- @return      self.Queue, the sequence value of this effect
function SmartEnhance:Initialize(host, size)
    local vec2, vec3 = mathfunction.vector2, mathfunction.vector3

    --public configs
    self.gamma_rgb                    = vec3(1.1, 1.1, 1.1)
    self.alpha_high_rgb               = vec3(0.93, 0.93, 0.93) -- s curve param related to absolute color value
    self.alpha_high_diff_gain_rgb     = vec3(1.0, 1.0, 1.0)    -- s curve param related to color value relative to median/medium
    self.alpha_high_diff_gain_rgb_min = vec3(1.0, 1.0, 1.0)    -- clamp (min) on alpha_high_diff_gain_rgb
    self.alpha_high_diff_gain_rgb_max = vec3(1.0, 1.0, 1.0)    -- clamp (max) on alpha_high_diff_gain_rgb
    self.alpha_low_rgb                = vec3(0.71, 0.71, 0.71) -- s curve param related to absolute color value
    self.alpha_low_diff_gain_rgb      = vec3(1.0, 1.0, 1.0)    -- s curve param related to color value relative to median/medium
    self.alpha_low_diff_gain_rgb_min  = vec3(1.0, 1.0, 1.0)    -- clamp (min) on alpha_low_diff_gain_rgb
    self.alpha_low_diff_gain_rgb_max  = vec3(1.0, 1.0, 1.0)    -- clamp (max) on alpha_low_diff_gain_rgb
    self.g0_gain_rgb                  = vec3(0.62, 0.62, 0.62) -- control s curve center
    local sigma = 88.3                                          -- linear curve cutoff
    if _PLATFORM_IOS then
        sigma = 90.3                                            -- iOS uses a different cutoff
    end
    self.sigma_rgb                    = vec3(sigma, sigma, sigma)
    self.beta_rgb                     = vec3(1.0, 1.0, 1.0)    -- linear curve slope
    self.saturation_gain              = 0.05                    -- adjust saturation
    self.model_level                  = 2                       -- 0: low-level, 1: mid-level, 2: high-level
    self.split_direction              = 0                       -- 0: horizontal, 1: vertical
    self.split_ratio                  = 1.0                     -- ratio of enhanced area in final output

    --private configs
    self.block_size      = vec2(16, 16)
    self.max_frame_width = 540
    self:_UpdateMaxFrameWidth() -- the actual cap is derived from model_level

    --processing frame size: the input size, scaled down (aspect kept) so its width is at most max_frame_width
    local proc = vec2(size:x(), size:y())
    local cap = self.max_frame_width
    if proc:x() > cap then
        proc = vec2(cap, proc:y() * cap / proc:x())
    end
    self.proc_frame_size = proc

    --snapshot the input size and configs so later calls can detect changes
    self.size = vec2(size:x(), size:y())
    self:_BackupStates()

    --host is kept for Resizeview, which recreates the render targets
    self.host = host

    --sizes derived from configs, then the intermediate render targets
    local block = self.block_size
    self.histogram_size = vec2(math.ceil(proc:x() / block:x()) * 16, math.ceil(proc:y() / block:y()) * 16)
    self:_CreateFBOs(host)

    -- FRAME_SIZE / BLOCK_SIZE / HISTOGRAM_SIZE are bound on all three pass materials
    local function bind_geometry(material)
        material:SetParameter("FRAME_SIZE", self.proc_frame_size)
        material:SetParameter("BLOCK_SIZE", self.block_size)
        material:SetParameter("HISTOGRAM_SIZE", self.histogram_size)
    end

    --pass 1: luminance histogram
    self.ro_histogram_gen_step_mat = host:CreateMaterial("comm:documents/shaders/posteffect/smart_enhance/histogram_gen_step.material")
    self.ro_histogram_gen_step_ro = host:CreateRenderObject()
    bind_geometry(self.ro_histogram_gen_step_mat)

    --pass 2: rgb lut generation; curve uniforms are fed from the vector3 configs above
    self.ro_lut_gen_step_mat = host:CreateMaterial("comm:documents/shaders/posteffect/smart_enhance/lut_gen_step.material")
    self.ro_lut_gen_step_ro = host:CreateRenderObject()
    local lut_gen_material = self.ro_lut_gen_step_mat
    bind_geometry(lut_gen_material)
    local curve_uniforms = {
        { "GAMMA_RGB",                    "gamma_rgb" },
        { "ALPHA_HIGH_RGB",               "alpha_high_rgb" },
        { "ALPHA_HIGH_DIFF_GAIN_RGB",     "alpha_high_diff_gain_rgb" },
        { "ALPHA_HIGH_DIFF_GAIN_RGB_MIN", "alpha_high_diff_gain_rgb_min" },
        { "ALPHA_HIGH_DIFF_GAIN_RGB_MAX", "alpha_high_diff_gain_rgb_max" },
        { "ALPHA_LOW_RGB",                "alpha_low_rgb" },
        { "ALPHA_LOW_DIFF_GAIN_RGB",      "alpha_low_diff_gain_rgb" },
        { "ALPHA_LOW_DIFF_GAIN_RGB_MIN",  "alpha_low_diff_gain_rgb_min" },
        { "ALPHA_LOW_DIFF_GAIN_RGB_MAX",  "alpha_low_diff_gain_rgb_max" },
        { "G0_GAIN_RGB",                  "g0_gain_rgb" },
        { "SIGMA_RGB",                    "sigma_rgb" },
        { "BETA_RGB",                     "beta_rgb" },
    }
    for _, uniform in ipairs(curve_uniforms) do
        lut_gen_material:SetParameter(uniform[1], self[uniform[2]])
    end

    --pass 3: lut application; scalar configs are wrapped in vector1
    self.ro_lut_apply_step_mat = host:CreateMaterial("comm:documents/shaders/posteffect/smart_enhance/lut_apply_step.material")
    self.ro_lut_apply_step_ro = host:CreateRenderObject()
    local lut_apply_material = self.ro_lut_apply_step_mat
    bind_geometry(lut_apply_material)
    lut_apply_material:SetParameter("SATURATION_GAIN", mathfunction.vector1(self.saturation_gain))
    lut_apply_material:SetParameter("SPLIT_DIRECTION", mathfunction.vector1(self.split_direction))
    lut_apply_material:SetParameter("SPLIT_RATIO", mathfunction.vector1(self.split_ratio))

    --register every public config with the host, in declaration order
    local registered = {
        "gamma_rgb",
        "alpha_high_rgb",
        "alpha_high_diff_gain_rgb",
        "alpha_high_diff_gain_rgb_min",
        "alpha_high_diff_gain_rgb_max",
        "alpha_low_rgb",
        "alpha_low_diff_gain_rgb",
        "alpha_low_diff_gain_rgb_min",
        "alpha_low_diff_gain_rgb_max",
        "g0_gain_rgb",
        "sigma_rgb",
        "beta_rgb",
        "saturation_gain",
        "model_level",
        "split_direction",
        "split_ratio",
    }
    for i = 1, #registered do
        host:RegisterScriptParameter(self, registered[i])
    end

    return self.Queue
end

--- Adapt the pipeline to a new input size.
-- Recomputes the processing and histogram sizes, pushes FRAME_SIZE / HISTOGRAM_SIZE
-- to all three pass materials, and recreates the intermediate render targets.
-- @param size  new input frame size (vector2)
function SmartEnhance:Resizeview(size)
    local vec2 = mathfunction.vector2
    self.size = vec2(size:x(), size:y())

    -- scale down (aspect kept) so the processing width never exceeds max_frame_width
    local proc = vec2(size:x(), size:y())
    local cap = self.max_frame_width
    if proc:x() > cap then
        proc = vec2(cap, proc:y() * cap / proc:x())
    end
    self.proc_frame_size = proc

    local block = self.block_size
    self.histogram_size = vec2(math.ceil(proc:x() / block:x()) * 16, math.ceil(proc:y() / block:y()) * 16)

    -- BLOCK_SIZE does not depend on the input size, so only these two uniforms change
    local pass_materials = {
        self.ro_histogram_gen_step_mat,
        self.ro_lut_gen_step_mat,
        self.ro_lut_apply_step_mat,
    }
    for i = 1, 3 do
        local material = pass_materials[i]
        material:SetParameter("FRAME_SIZE", self.proc_frame_size)
        material:SetParameter("HISTOGRAM_SIZE", self.histogram_size)
    end

    self:_CreateFBOs(self.host)

    -- LOG("FRAME SIZE "..self.proc_frame_size:x().." "..self.proc_frame_size:y())
    -- LOG("HISTOGRAM SIZE "..self.histogram_size:x().." "..self.histogram_size:y())
end

--- Render one frame of the effect in three passes.
-- Pending config changes are pushed to the materials first, then the configs are snapshotted.
-- @param context  render context used to begin/end passes and issue draws
-- @param Original render target whose colour attachment feeds the histogram pass
-- @param Scene    render target whose colour attachment is remapped by the LUT pass
-- @param Output   render target that receives the final result
function SmartEnhance:Process(context, Original, Scene, Output)
    self:_UpdateRenderParameters()
    self:_BackupStates()

    local RenderTargetEntity = apolloengine.RenderTargetEntity
    -- every pass clears its target to opaque black before drawing
    local clear_color = mathfunction.Color(0.0, 0.0, 0.0, 1.0)

    --generate luminance histogram
    context:BeginRenderPass(self.fbo_luma_histogram, RenderTargetEntity.CF_COLOR, clear_color)
    -- the input frame must be bound on the material that is actually drawn in this pass
    local histogram_gen_material = self.ro_histogram_gen_step_mat
    histogram_gen_material:SetParameter("CUR_FRAME", Original:GetAttachment(RenderTargetEntity.TA_COLOR_0))
    context:Draw(self.ro_histogram_gen_step_ro, histogram_gen_material)
    context:EndRenderPass()

    --generate rgb lut from the histogram
    context:BeginRenderPass(self.fbo_rgb_lut, RenderTargetEntity.CF_COLOR, clear_color)
    local lut_gen_material = self.ro_lut_gen_step_mat
    lut_gen_material:SetParameter("HISTOGRAM", self.fbo_luma_histogram:GetAttachment(RenderTargetEntity.TA_COLOR_0))
    context:Draw(self.ro_lut_gen_step_ro, lut_gen_material)
    context:EndRenderPass()

    --apply rgb lut to the scene frame
    context:BeginRenderPass(Output, RenderTargetEntity.CF_COLOR, clear_color)
    local lut_apply_material = self.ro_lut_apply_step_mat
    lut_apply_material:SetParameter("CUR_FRAME", Scene:GetAttachment(RenderTargetEntity.TA_COLOR_0))
    lut_apply_material:SetParameter("RGB_LUT", self.fbo_rgb_lut:GetAttachment(RenderTargetEntity.TA_COLOR_0))
    context:Draw(self.ro_lut_apply_step_ro, lut_apply_material)
    context:EndRenderPass()
end

----------------------------internal functions----------------------------

--- (Re)create the two intermediate render targets, both sized to self.histogram_size.
-- @param host  engine host that owns render-target creation
function SmartEnhance:_CreateFBOs(host)
    local swap_mode = apolloengine.RenderTargetEntity.ST_SWAP_UNIQUE
    local target_size = self.histogram_size
    self.fbo_luma_histogram = host:CreateRenderTarget(swap_mode, target_size)
    self.fbo_rgb_lut = host:CreateRenderTarget(swap_mode, target_size)
end

-- (config field on self, uniform of the LUT-generation material), in check order
local LUT_GEN_UNIFORM_SOURCES = {
    { "gamma_rgb",                    "GAMMA_RGB" },
    { "alpha_high_rgb",               "ALPHA_HIGH_RGB" },
    { "alpha_high_diff_gain_rgb",     "ALPHA_HIGH_DIFF_GAIN_RGB" },
    { "alpha_high_diff_gain_rgb_min", "ALPHA_HIGH_DIFF_GAIN_RGB_MIN" },
    { "alpha_high_diff_gain_rgb_max", "ALPHA_HIGH_DIFF_GAIN_RGB_MAX" },
    { "alpha_low_rgb",                "ALPHA_LOW_RGB" },
    { "alpha_low_diff_gain_rgb",      "ALPHA_LOW_DIFF_GAIN_RGB" },
    { "alpha_low_diff_gain_rgb_min",  "ALPHA_LOW_DIFF_GAIN_RGB_MIN" },
    { "alpha_low_diff_gain_rgb_max",  "ALPHA_LOW_DIFF_GAIN_RGB_MAX" },
    { "g0_gain_rgb",                  "G0_GAIN_RGB" },
    { "sigma_rgb",                    "SIGMA_RGB" },
    { "beta_rgb",                     "BETA_RGB" },
}

-- (config field on self, uniform of the LUT-apply material); values are wrapped in vector1
local LUT_APPLY_UNIFORM_SOURCES = {
    { "saturation_gain", "SATURATION_GAIN" },
    { "split_direction", "SPLIT_DIRECTION" },
    { "split_ratio",     "SPLIT_RATIO" },
}

--- Push every config whose value differs from its "<field>_bak" snapshot to its material.
-- A model_level change also recomputes max_frame_width and resizes the pipeline
-- through Resizeview.
-- NOTE(review): the comparison is `~=` on the stored values, so an in-place mutation
-- of a vector config (same object) would not be detected — confirm how the host
-- writes registered parameters.
function SmartEnhance:_UpdateRenderParameters()
    local lut_gen_material = self.ro_lut_gen_step_mat
    for _, source in ipairs(LUT_GEN_UNIFORM_SOURCES) do
        local field = source[1]
        if self[field .. "_bak"] ~= self[field] then
            lut_gen_material:SetParameter(source[2], self[field])
        end
    end

    local lut_apply_material = self.ro_lut_apply_step_mat
    for _, source in ipairs(LUT_APPLY_UNIFORM_SOURCES) do
        local field = source[1]
        if self[field .. "_bak"] ~= self[field] then
            lut_apply_material:SetParameter(source[2], mathfunction.vector1(self[field]))
        end
    end

    if self.model_level_bak ~= self.model_level then
        self:_UpdateMaxFrameWidth()
        self:Resizeview(self.size)
    end
end

-- configs snapshotted by _BackupStates; each is copied to "<field>_bak"
local SNAPSHOT_FIELDS = {
    "gamma_rgb",
    "alpha_high_rgb",
    "alpha_high_diff_gain_rgb",
    "alpha_high_diff_gain_rgb_min",
    "alpha_high_diff_gain_rgb_max",
    "alpha_low_rgb",
    "alpha_low_diff_gain_rgb",
    "alpha_low_diff_gain_rgb_min",
    "alpha_low_diff_gain_rgb_max",
    "g0_gain_rgb",
    "sigma_rgb",
    "beta_rgb",
    "saturation_gain",
    "model_level",
    "split_direction",
    "split_ratio",
}

--- Remember the current config values so a later _UpdateRenderParameters call can
-- detect which ones changed. Values are stored by plain assignment, so for vector
-- configs the snapshot refers to the same object as the live field.
function SmartEnhance:_BackupStates()
    for i = 1, #SNAPSHOT_FIELDS do
        local field = SNAPSHOT_FIELDS[i]
        self[field .. "_bak"] = self[field]
    end
end

--- Derive max_frame_width (the processing-width cap) from model_level:
-- below 1 -> 54, below 2 -> 108, anything else -> 540.
function SmartEnhance:_UpdateMaxFrameWidth()
    local level = self.model_level
    local width = 540   -- high-level
    if level < 1 then
        width = 54      -- low-level
    elseif level < 2 then
        width = 108     -- mid-level
    end
    self.max_frame_width = width
end

-- export the module table carrying the public methods (Initialize, Resizeview, Process)
-- and the private "_"-prefixed helpers
return SmartEnhance