"""Module for managing the plot data of SRH model run progress scripts."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import math
import os
import shlex
import sqlite3

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.srh.model import progress_files as prf

# Keys into a monitor point's ``db_table``: which plot table receives the ground (Z) series and which
# receives the water surface elevation (WSE) series.
PT_GROUND_TABLE, PT_WSE_TABLE = 0, 1


class PlotManager:
    """Class to manage plots for an SRH-2D run.

    Tracks the model's mass balance file (``<case>_INF.dat``) and its monitor point/line output files
    (``Output_MISC/<case>_PT<n>.dat`` and ``Output_MISC/<case>_LN<id>.dat``), and appends newly written
    values to the plot tables of the progress loop's SQLite database.
    """
    logger = None  # NOTE(review): not referenced by this class; instances log through self._logger

    # Value treated as "no data" in the monitor files (presumably SRH's missing-value flag); not plotted.
    _NO_DATA_VALUE = -999.0
    _NO_DATA_TOLERANCE = 0.00001
    # Minimum number of columns a complete monitor point / monitor line data line must have.
    _PT_MIN_COLUMNS = 11
    _ARC_MIN_COLUMNS = 3

    def __init__(self, sim_comp, case_name, progress_looper):
        """Constructor.

        Args:
            sim_comp (:obj:`SimComponent`): The simulation being run
            case_name (:obj:`str`): case name of the model (may be a subdirectory)
            progress_looper (:obj:`ProgressLoop`): The api progress loop project.
        """
        self.comp = sim_comp
        self.looper = progress_looper
        self.inf_file = None
        # case_name may carry a directory (relative to the current directory) holding the run's files.
        self._case_name = os.path.basename(case_name)
        case_dir = os.path.dirname(case_name)
        self._sim_dir = os.path.join(os.getcwd(), case_dir) if case_dir else os.getcwd()
        self._inf_file_name = os.path.join(self._sim_dir, f'{self._case_name}_INF.dat')
        self._points_files = {}
        self._arc_q_files = {}
        self._arc_qs_files = {}
        self._enable_sediment = self.comp.data.enable_sediment
        self._logger = None
        self._log_file = os.path.join(self._sim_dir, f'{self._case_name}_progress_plots.log')
        # Start a fresh log for this run. A handler left on the shared logger by an earlier run of the same
        # case still has the old file open (which can block the delete on Windows), so release it first.
        self._detach_log_handlers()
        if os.path.isfile(self._log_file):
            os.remove(self._log_file)
        self._set_logger()
        self._logger.info('Begin logging')
        self._logger.info(f'cur dir: {os.getcwd()}')
        self._logger.info(f'Case name: {self._case_name}')
        self._setup_plots()

    def _detach_log_handlers(self):
        """Remove and close any file handler on the 'xms.srh' logger that writes to this run's log file."""
        logger = logging.getLogger('xms.srh')
        log_path = os.path.abspath(self._log_file)
        for handler in list(logger.handlers):  # copy: handlers are removed while iterating
            if isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path:
                logger.removeHandler(handler)
                handler.close()

    def _set_logger(self):
        """Sets up the logger, writing to this run's progress log file."""
        self._logger = logging.getLogger('xms.srh')
        log_path = os.path.abspath(self._log_file)
        # The named logger is shared process-wide; attaching a second handler for the same file would
        # write every message twice.
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in self._logger.handlers
        )
        if not already_attached:
            self._logger.addHandler(logging.FileHandler(self._log_file))
        self._logger.setLevel(logging.INFO)

    def update_plots(self):
        """Append any newly written model output to the plot tables of the progress loop database."""
        if self._logger is None:
            self._set_logger()
        self._logger.info('update_plots')
        connection = sqlite3.connect(self.looper.plot_db_file)
        try:
            self._update_residual_plot(connection)
            self._update_points_plots(connection)
            self._update_arc_q_plots(connection)
            if self._enable_sediment:
                self._update_arc_qs_plots(connection)
        finally:
            # Release the database handle every update instead of leaking one connection per call.
            connection.close()

    def _setup_plots(self):
        """Set up the plot files and database references."""
        self._setup_mass_balance_plot()
        self._setup_points_plots()
        self._setup_arc_q_plots()
        # Only add the QS_tot plot if this is a sediment run.
        if self._enable_sediment:
            self._setup_arc_qs_plots()

    def _setup_mass_balance_plot(self):
        """Setup the residual plot from the simulation's hydro start/end times."""
        self.inf_file = prf.InfFile()
        self.inf_file.inf_file = self._inf_file_name
        self.inf_file.using_restart = self.comp.data.hydro.initial_condition == 'Restart File'
        self.inf_file.start = self.comp.data.hydro.start_time
        self.inf_file.end = self.comp.data.hydro.end_time
        self.inf_file.run_duration = self.inf_file.end - self.inf_file.start
        self._logger.info(f'INF_FILE: {self.inf_file.inf_file}')
        self._logger.info(
            f'Start: {self.inf_file.start}; End: {self.inf_file.end}; '
            f'Duration: {self.inf_file.run_duration}'
        )

    @staticmethod
    def _sort_by_id(data_names, *tokens):
        """Pair each plot data name with its integer id and sort by (id, name).

        Args:
            data_names: Plot data (curve) names, e.g. 'Pt3 WSE' or 'Arc12'.
            *tokens: Substrings removed, in order, from a name to leave only its integer id.

        Returns:
            list: ``(id, name)`` tuples in ascending order.

        Raises:
            ValueError: If a name does not reduce to an integer.
        """
        def plot_id(name):
            for token in tokens:
                name = name.replace(token, '')
            return int(name)

        return sorted((plot_id(name), name) for name in data_names)

    def _setup_points_plots(self):
        """Map each monitor point output file to the plot tables for its WSE and ground (Z) curves."""
        self._logger.info('_setup_points_plots')
        sorted_names = self._sort_by_id(self.looper.get_plot_data_names("plot2"), 'Pt', ' WSE', ' Z')
        data_names = [name for _, name in sorted_names]
        self._logger.info(f'data_names: {data_names}')
        # Point files are numbered by ordinal position (1-based) among the sorted curves of each kind,
        # not by the id in the curve name.
        pt_counts = {PT_WSE_TABLE: 0, PT_GROUND_TABLE: 0}
        for data in data_names:
            data_str = str(data)
            tables = self.looper.get_plot_data_db_table("plot2", data_str)
            if "WSE" in data_str:
                set_table = PT_WSE_TABLE
            elif "Z" in data_str:
                set_table = PT_GROUND_TABLE
            else:
                continue
            pt_counts[set_table] += 1
            tabs = [str(t) for t in tables]
            self._logger.info(f'data: {data_str}; tables: {tabs}')

            pt_file = f'{self._case_name}_PT{pt_counts[set_table]}.dat'
            pt_file = os.path.join(self._sim_dir, 'Output_MISC', pt_file)
            for tbl in tables:
                self._logger.info(f'pt_file: {pt_file}')
                if pt_file not in self._points_files:
                    self._points_files[pt_file] = prf.PointFile(0)
                self._points_files[pt_file].db_table[set_table] = str(tbl)

    def _setup_arc_q_plots(self):
        """Setup the arc q plots."""
        self._setup_arc_plots('plot3', self._arc_q_files)

    def _setup_arc_qs_plots(self):
        """Setup the arc qs plots."""
        self._setup_arc_plots('plot4', self._arc_qs_files)

    def _setup_arc_plots(self, plot_name: str, arc_files: dict):
        """Map each monitor line output file to the plot table of its curve in a monitor line plot.

        Args:
            plot_name: Progress plot whose curves are named 'Arc<id>' ('plot3' for Q, 'plot4' for QS).
            arc_files: Dictionary filled with output file path -> ``prf.ArcFile``.
        """
        for arc_id, data in self._sort_by_id(self.looper.get_plot_data_names(plot_name), 'Arc'):
            data_str = str(data)
            tables = self.looper.get_plot_data_db_table(plot_name, data_str)
            arc_file = f'{self._case_name}_LN{arc_id}.dat'
            arc_file = os.path.join(self._sim_dir, 'Output_MISC', arc_file)
            for tbl in tables:
                arc_files[arc_file] = prf.ArcFile(0, str(tbl))

    @classmethod
    def _is_plottable(cls, value):
        """Return True if value is finite and not the no-data flag."""
        return math.isfinite(value) and not math.isclose(
            value, cls._NO_DATA_VALUE, abs_tol=cls._NO_DATA_TOLERANCE
        )

    @staticmethod
    def _parse_time(text):
        """Parse a time column.

        Raises:
            ValueError: If it is not a finite number, so the whole line is skipped.
        """
        time = float(text)
        if not math.isfinite(time):
            raise ValueError(f'non-finite time: {text}')
        return time

    @staticmethod
    def _insert_value(cursor, table_name, time, value):
        """Insert one (time, value) row.

        The values are bound as parameters; the table name cannot be bound and is interpolated.
        """
        cursor.execute(f'INSERT INTO {table_name} VALUES (?, ?)', (time, value))

    def _update_residual_plot(self, connection):
        """Update the residual plot.

        Inserts every INF data row not yet stored: column 0 is the time and columns 1-4 go to the net Q,
        mass error, cumulative mass error and wet element tables. Non-finite values are skipped.

        Args:
            connection: The sqlite database connection
        """
        if self.inf_file.num_inserted_to_db >= len(self.inf_file.data):
            return

        netq = self.looper.get_plot_data_db_table("netq_plot", "Net_Q/INLET_Q")
        mass_error = self.looper.get_plot_data_db_table("mb_plot", "Mass_Error")
        cum_mass_error = self.looper.get_plot_data_db_table("mb_plot", "Cumu_Mass_Error")
        wet_elems = self.looper.get_plot_data_db_table('wet_plot', 'Wet_Elements')
        if not (netq and mass_error and cum_mass_error and wet_elems):
            return  # The plot tables are not available yet; try again on the next update

        cursor = connection.cursor()
        tables = [netq[0], mass_error[0], cum_mass_error[0], wet_elems[0]]
        while self.inf_file.num_inserted_to_db < len(self.inf_file.data):
            row = self.inf_file.data[self.inf_file.num_inserted_to_db]
            time = float(row[0])
            if math.isfinite(time):
                for table, raw_value in zip(tables, row[1:]):
                    value = float(raw_value)
                    if math.isfinite(value):
                        self._insert_value(cursor, table, time, value)
            self.inf_file.num_inserted_to_db += 1
        connection.commit()

    @staticmethod
    def _read_monitor_file(file_name, tracked_file, min_columns, insert_row):
        """Pass each complete line appended to a monitor output file since the last read to insert_row.

        Args:
            file_name (:obj:`str`): Path of the monitor point/line output file.
            tracked_file: The progress-file record for the file. Its ``file_pos`` is where reading starts,
                and it is advanced past each line that insert_row handles without raising.
            min_columns (:obj:`int`): Minimum number of columns a data line must have.
            insert_row: Callable ``insert_row(tracked_file, values)`` taking the line's column strings.
                It raises ``ValueError`` for lines that are not data (e.g. headers).

        Raises:
            OSError: If the file cannot be opened (``FileNotFoundError`` before the model creates it).
        """
        with open(file_name, 'r') as output_file:
            output_file.seek(tracked_file.file_pos)
            # Use readline() rather than "for line in output_file": tell() is disabled while a text
            # file is being iterated.
            line = output_file.readline()
            while line:
                # A line without a terminator may still be being written; leave it (and file_pos) alone so
                # it is read again once complete.
                if line.endswith(('\n', '\r')):
                    try:
                        values = shlex.split(line)
                        if len(values) >= min_columns:
                            insert_row(tracked_file, values)
                            tracked_file.file_pos = output_file.tell()
                    except (ValueError, sqlite3.OperationalError):
                        pass  # Non-data line (or unbalanced quote), or the table could not be written
                line = output_file.readline()

    def _update_from_monitor_file(self, connection, file_name, tracked_file, min_columns, insert_row):
        """Best-effort: insert the new rows of one monitor output file and commit them.

        A missing file is skipped silently. Any other error is logged rather than raised, so a problem with
        one file does not stop the remaining plot updates.
        """
        try:
            try:
                self._read_monitor_file(file_name, tracked_file, min_columns, insert_row)
            finally:
                # Commit rows inserted before any error, so the database matches the advanced file_pos.
                connection.commit()
        except FileNotFoundError:
            pass  # File might not exist yet
        except Exception:
            self._logger.exception(f'Unable to update plot data from {file_name}')

    def _update_points_plots(self, connection):
        """Update the monitor point plots.

        Args:
            connection: The sqlite database connection
        """
        cursor = connection.cursor()
        self._logger.info('_update_points_plots')
        self._logger.info(f'len pt files: {len(self._points_files)}')

        def insert_row(point_file, values):
            time = self._parse_time(values[prf.PT_COLUMN_TIME])
            wse = float(values[prf.PT_COLUMN_WSE])
            ground = float(values[prf.PT_COLUMN_GROUND])
            # NOTE(review): if the WSE insert succeeds and the ground insert fails, the line is not consumed
            # and its WSE value is inserted again when the line is re-read.
            if self._is_plottable(wse):
                self._insert_value(cursor, str(point_file.db_table[PT_WSE_TABLE]), time, wse)
            if self._is_plottable(ground):
                self._insert_value(cursor, str(point_file.db_table[PT_GROUND_TABLE]), time, ground)

        for pt_file_name, point_file in self._points_files.items():
            self._logger.info(f'pt_file_name: {pt_file_name}')
            self._update_from_monitor_file(
                connection, pt_file_name, point_file, self._PT_MIN_COLUMNS, insert_row
            )

    def _update_arc_q_plots(self, connection):
        """Update the monitor line q plots.

        Args:
            connection: The sqlite database connection
        """
        self._update_arc_plots(connection, self._arc_q_files, prf.ARC_COLUMN_Q)

    def _update_arc_qs_plots(self, connection):
        """Update the monitor line qs plots.

        Args:
            connection: The sqlite database connection
        """
        self._update_arc_plots(connection, self._arc_qs_files, prf.ARC_COLUMN_QS)

    def _update_arc_plots(self, connection, arc_files: dict, column: int):
        """Update the monitor line plots.

        Args:
            connection: The sqlite database connection
            arc_files: dictionary of arc files
            column: column index to read from the arc file
        """
        cursor = connection.cursor()

        def insert_row(arc_file, values):
            time = self._parse_time(values[prf.ARC_COLUMN_TIME])
            value = float(values[column])
            if self._is_plottable(value):
                self._insert_value(cursor, str(arc_file.db_table), time, value)

        for arc_file_name, arc_file in arc_files.items():
            self._update_from_monitor_file(
                connection, arc_file_name, arc_file, self._ARC_MIN_COLUMNS, insert_row
            )