"""Provide progress and plot feedback when running the SRH-2D model in SMS."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import sys

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.data_objects.parameters import Component

# 4. Local modules
from xms.srh.components.parameters.parameters_manager import ParametersManager
from xms.srh.components.sim_component import SimComponent
from xms.srh.model.progress_plots import PlotManager


class SRHProgressTracker:
    """Class that tracks an SRH-2D model running in SMS.

    All state lives in class attributes and every method is static, because
    ``progress_function`` is handed to the XMS progress loop as a plain callable.
    """
    query = None  # Query opened in start_tracking()
    prog = None  # Progress loop from query.xms_agent.session.progress_loop
    plots_manager = None  # PlotManager that owns the INF file reader and the run plots
    log_file = ''  # Progress log file name, set in start_tracking()
    new_dip_time_step = -1.0  # Secondary time step; negative means the option is off
    dip_activation_time = -1.0  # Simulation time after which the secondary time step is written
    dip_updated = False  # True once the secondary time step has been written to the DIP file
    dip_termination = False  # True once the termination time has been written to the DIP file
    update_post_progress = False  # True once the INF progress reached 100% (post-processing phase)
    post_echo_pos = 0  # Position in the command line output file already consumed
    post_num_times = 0.0  # Total number of post-processing timesteps (0.0 until known)
    netq_termination_value = -1  # Net Q termination threshold; negative means the option is off

    @staticmethod
    def progress_function():
        """Calculates the progress and sends it to SMS.

        While the model is running, progress comes from the INF file and is scaled into
        the first 90%. Once the INF file reports 100%, the remaining 10% is driven by the
        "Processed timestep" lines echoed to the command line output file.
        """
        inf_file = SRHProgressTracker.plots_manager.inf_file
        inf_file.update_progress()
        SRHProgressTracker._update_dip_time_step()
        SRHProgressTracker._update_dip_termination()

        if SRHProgressTracker.update_post_progress:
            SRHProgressTracker._update_post_progress()
            return

        # Update the run plots
        SRHProgressTracker.plots_manager.update_plots()
        SRHProgressTracker.query.update_progress_percent(0.9 * inf_file.progress_percent)
        if inf_file.progress_percent >= 100.0:
            SRHProgressTracker.update_post_progress = True

    @staticmethod
    def _update_post_progress():
        """Report post-processing progress (90% to 100%) from the command line output file."""
        cmd_file = SRHProgressTracker.prog.command_line_output_file
        if not os.path.isfile(cmd_file):
            return
        current_ts = None  # Latest timestep parsed during this call, if any
        with open(cmd_file, "r") as f:
            # Resume right after the last "Processed timestep" line consumed by a previous call.
            f.seek(SRHProgressTracker.post_echo_pos)
            echo_line = f.readline()
            while echo_line:
                # Only consume terminated lines; an unterminated line may still be being written.
                if echo_line.endswith(('\n', '\r')) and echo_line.strip().startswith("Processed timestep: "):
                    echo_vals = echo_line.split()
                    # Token 2 is read as the current timestep and token 5 as the total count.
                    try:
                        current_ts = float(echo_vals[2])
                        if SRHProgressTracker.post_num_times == 0.0:
                            SRHProgressTracker.post_num_times = float(echo_vals[5])
                    except (IndexError, ValueError):
                        pass  # Malformed line: skip it but still move past it below.
                    SRHProgressTracker.post_echo_pos = f.tell()
                echo_line = f.readline()
        # Without a parsed timestep or a known total there is nothing meaningful to report.
        if current_ts is not None and SRHProgressTracker.post_num_times > 0.0:
            percent_done = 90.0 + (10.0 * (current_ts / SRHProgressTracker.post_num_times))
            SRHProgressTracker.query.update_progress_percent(percent_done)

    @staticmethod
    def _get_sim(sim_uuid: str, model_name: str, unique_name: str) -> Component:
        """Find the simulation component linked below the tree item with ``sim_uuid``.

        Args:
            sim_uuid (str): UUID of the project tree item to search under.
            model_name (str): Model name used to filter descendants and look up the component.
            unique_name (str): Unique name of the component to look up.

        Returns:
            The component for the last 'TI_DYN_SIM_PTR' descendant found, or None if the tree
            item or any such descendant is not found.
        """
        query = SRHProgressTracker.query
        sim_item = tree_util.find_tree_node_by_uuid(query.project_tree, sim_uuid)
        if not sim_item:
            return None
        model_sim = None
        sim_ptrs = tree_util.descendants_of_type(
            tree_root=sim_item, xms_types=['TI_DYN_SIM_PTR'], allow_pointers=True, model_name=model_name
        )
        for sim_ptr in sim_ptrs:
            # Every pointer is queried; the last one found wins.
            model_sim = query.item_with_uuid(sim_ptr.uuid, model_name=model_name, unique_name=unique_name)
        return model_sim

    @staticmethod
    def start_tracking():
        """Entry point for the SRH-2D progress script.

        Reads the simulation options, resolves the case name (optionally for a parameter
        run given as ``sys.argv[1]``), then starts the XMS progress loop.
        """
        SRHProgressTracker.query = Query(progress_script=True)
        sim_uuid = SRHProgressTracker.query.current_item_uuid()
        unique_name = 'Sim_Manager'
        model_name = 'SRH-2D'
        sim_do_comp = SRHProgressTracker._get_sim(sim_uuid, model_name, unique_name)
        if not sim_do_comp:
            # Fall back to treating the current item itself as the simulation component.
            sim_do_comp = SRHProgressTracker.query.item_with_uuid(
                sim_uuid, model_name=model_name, unique_name=unique_name
            )
        sim_comp = SimComponent(sim_do_comp.main_file)
        advanced = sim_comp.data.advanced
        if advanced.use_secondary_time_step:
            SRHProgressTracker.new_dip_time_step = advanced.secondary_time_step
            SRHProgressTracker.dip_activation_time = advanced.activation_time
        if advanced.use_termination_criteria:
            SRHProgressTracker.netq_termination_value = advanced.netq_value

        # Case name from the hydro data; a parameter run lives in a sub-folder named after the run.
        case_name = sim_comp.data.hydro.case_name
        if len(sys.argv) > 1:
            run_num = int(sys.argv[1])  # 1-based run number
            param_data = ParametersManager.read_parameter_file(sim_comp.main_file)
            if param_data and param_data['use_parameters']:
                run_name = param_data['runs']['Run'][run_num - 1]
                case_name = os.path.join(run_name, run_name)

        SRHProgressTracker.prog = SRHProgressTracker.query.xms_agent.session.progress_loop
        SRHProgressTracker.log_file = case_name + '_srh_prog.log'
        SRHProgressTracker.plots_manager = PlotManager(sim_comp, case_name, SRHProgressTracker.prog)
        SRHProgressTracker.prog.set_progress_function(SRHProgressTracker.progress_function)
        SRHProgressTracker.prog.start_loop()

    @staticmethod
    def _update_dip_time_step():
        """Write the secondary time step to the DIP file once the activation time has passed."""
        if SRHProgressTracker.dip_updated or SRHProgressTracker.new_dip_time_step < 0.0:
            return

        inf_file = SRHProgressTracker.plots_manager.inf_file
        if inf_file.current_time > SRHProgressTracker.dip_activation_time:
            new_line = f' dtnew = {SRHProgressTracker.new_dip_time_step}\n'
            # Only mark as done when the DIP file was actually written, so a missing file is retried.
            if SRHProgressTracker._update_dip_file(new_line):
                SRHProgressTracker.dip_updated = True

    @staticmethod
    def _update_dip_termination():
        """Update the simulation end time if we have met the termination criteria.

        The criteria are met when the absolute value of the last two INF records' column 1
        (net Q) is below ``netq_termination_value``; column 0 of the last record is then
        written as the new total simulation time.
        """
        if SRHProgressTracker.dip_termination or SRHProgressTracker.netq_termination_value < 0.0:
            return

        val = SRHProgressTracker.netq_termination_value
        data = SRHProgressTracker.plots_manager.inf_file.data
        # NOTE(review): requires three records though only the last two are compared,
        # presumably to skip the initial record; confirm.
        if len(data) > 2:
            end_time = data[-1][0]
            nq_0 = data[-1][1]
            nq_1 = data[-2][1]
            if abs(nq_0) < val and abs(nq_1) < val:
                # Only mark as done when the DIP file was actually written, so a missing file is retried.
                if SRHProgressTracker._update_dip_file(f' total_simulation_time = {end_time}\n'):
                    SRHProgressTracker.dip_termination = True

    @staticmethod
    def _update_dip_file(new_line):
        """Append a line to the $DATAC section of the DIP file.

        The DIP path is the INF path with '_INF.dat' replaced by '_DIP.dat'. The file is
        rewritten as a single $DATAC ... $ENDC section holding the previously collected
        section lines followed by ``new_line``; anything outside that section is dropped.

        Args:
            new_line (str): The line to append, including its trailing newline.

        Returns:
            bool: True if the DIP file was rewritten, False if it does not exist or its path
            could not be derived from the INF path.
        """
        inf_path = SRHProgressTracker.plots_manager.inf_file.inf_file
        dip_file = inf_path.replace('_INF.dat', '_DIP.dat')
        # If the INF name doesn't follow the pattern, replace() is a no-op: never overwrite the INF file.
        if dip_file == inf_path or not os.path.isfile(dip_file):
            return False

        with open(dip_file, 'r') as f:
            lines = f.readlines()
        lines_to_write = []
        collect_lines = False
        for line in lines:
            upper_line = line.upper()
            if '$DATAC' in upper_line:
                collect_lines = True
            elif '$ENDC' in upper_line:
                collect_lines = False
            elif collect_lines:
                lines_to_write.append(line)
        lines_to_write.append(new_line)

        with open(dip_file, 'w') as f:
            f.write('$DATAC\n')
            f.writelines(lines_to_write)
            f.write('$ENDC\n')
        return True
