"""Module for managing the plot data of SRH PEST model run progress scripts."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import json
import logging
import os
import sqlite3

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules

# Monitor point database table indices
# NOTE(review): neither index is referenced in the visible part of this module — confirm they are still needed.
PT_GROUND_TABLE = 0
PT_WSE_TABLE = 1

# Line and character offsets; presumably they describe the layout of the PEST rec file — TODO confirm.
# NOTE(review): none of these is referenced by the parsing code visible in this module, which locates values
# by searching for marker text rather than by fixed offsets.
LINES_BEFORE_INITIAL_PHI = 1
LINES_BEFORE_PHI = 2
LINES_BEFORE_INITIAL_PARAMETERS = 4
LINES_BEFORE_PARAMETERS = 8
CHARACTERS_BEFORE_PHI = 37


class PestPlotManager:
    """Class to manage plots for an SRH-2D run.

    Each call to :meth:`update_plots` re-reads the PEST rec file, extracts the phi (error) value and the
    parameter values for the initial conditions and for every optimisation iteration found, and inserts the
    points that were not present on the previous read into the plot database of the progress loop.
    """
    def __init__(self, progress_looper):
        """Constructor.

        Reads ``srh_pest.json`` from the current working directory and starts a fresh ``srh_pest_plots.log``.

        Args:
            progress_looper (:obj:`ProgressLoop`): The api progress loop project.

        Raises:
            OSError: If ``srh_pest.json`` cannot be opened.
            KeyError: If ``PARAMETER_NAMES`` or ``PROJECT_NAME`` is missing from the json file.
        """
        self.looper = progress_looper
        self._last_iteration = -1
        self._plotted_initial = False
        with open('srh_pest.json', 'r') as f:  # Current working directory should be the PEST model run location.
            json_dict = json.load(f)
        self._parameter_names = json_dict['PARAMETER_NAMES']  # In order we wrote for PEST
        self._rec_file = f'{json_dict["PROJECT_NAME"]}.rec'
        # Plot name (or parameter name) -> list of (iteration, value) tuples read from the rec file.
        self._plot_data = {'pest_error_plot': []}
        for p in self._parameter_names:
            self._plot_data[p] = []

        log_file = 'srh_pest_plots.log'
        self.logger = logging.getLogger('pest_plots')
        # logging.getLogger returns the same logger object for the whole process, so a handler added by an
        # earlier instance is still attached. Detach and close it first; otherwise every message is written
        # once per instance ever created and the old handler keeps the log file open while we delete it.
        for old_handler in list(self.logger.handlers):
            if isinstance(old_handler, logging.FileHandler) and \
                    os.path.basename(old_handler.baseFilename) == log_file:
                self.logger.removeHandler(old_handler)
                old_handler.close()
        if os.path.isfile(log_file):
            os.remove(log_file)
        handler = logging.FileHandler(log_file)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.INFO)
        self.logger.info('Begin logging for plots')
        self.logger.info(f'REC file: {self._rec_file}')

    def update_plots(self):
        """Update the plots.

        Does nothing if the rec file does not exist yet. Otherwise the rec file is read again and every point
        that was not there on the previous read is inserted into the matching plot database table.
        """
        if not os.path.isfile(self._rec_file):
            return

        # _read_rec_file rebinds every entry to a new list, so this shallow copy keeps the previous lists intact.
        old_plot_data = self._plot_data.copy()
        self._read_rec_file()
        self.logger.info(f'old_plot_data: {old_plot_data}')
        self.logger.info(f'plot_data: {self._plot_data}')

        connection = sqlite3.connect(self.looper.plot_db_file)
        try:
            cursor = connection.cursor()
            error_table = self.looper.get_plot_data_db_table('pest_error_plot', 'Error')[0]
            already_plotted = len(old_plot_data['pest_error_plot'])
            self._insert_rows(cursor, error_table, self._plot_data['pest_error_plot'][already_plotted:])

            for pname in self._parameter_names:
                new_rows = self._plot_data[pname][len(old_plot_data[pname]):]
                if not new_rows:
                    continue  # Only look up the table when there is something to insert.
                table_name = self.looper.get_plot_data_db_table('pest_parameters', pname)[0]
                self._insert_rows(cursor, table_name, new_rows)

            connection.commit()
        finally:
            # Always release the database, including when a lookup or an insert raises.
            connection.close()

    def _insert_rows(self, cursor, table_name, rows):
        """Insert (iteration, value) rows into a plot database table.

        Args:
            cursor (:obj:`sqlite3.Cursor`): Cursor of the open plot database connection.
            table_name (:obj:`str`): Name of the table, as returned by the progress loop.
            rows (:obj:`list[tuple]`): The (iteration, value) tuples to insert.
        """
        # The table name cannot be bound as a parameter; it comes from the progress loop, not from the rec file.
        stmt = f'INSERT INTO {table_name} VALUES (?, ?)'
        for row in rows:
            values = (row[0], row[1])
            self.logger.info(f'sql: {stmt} {values}')
            cursor.execute(stmt, values)

    def _read_rec_file(self):
        """Reads the data from the PEST rec file."""
        for key in self._plot_data:
            self._plot_data[key] = []
        with open(self._rec_file, 'r') as f:
            lines = f.readlines()
        self._read_initial_values(lines)
        self._read_iteration_values(lines)

    def _read_initial_values(self, lines):
        """Read the initial values from the rec file.

        The values are stored as iteration 0. Nothing is read if the file has no ``INITIAL CONDITIONS:`` line.

        Args:
            lines (:obj:`list[str]`): lines from the file
        """
        # Find the first line after the 'INITIAL CONDITIONS:' header.
        i = None
        for index, line in enumerate(lines):
            if line.startswith('INITIAL CONDITIONS:'):
                i = index + 1
                break
        if i is None:
            return

        num_parameters = len(self._parameter_names)
        while i < len(lines):
            line = lines[i]
            if 'Sum of squared weighted residuals (ie phi) =' in line:
                self._plot_data['pest_error_plot'].append((0, float(line.split()[-1])))
            elif 'Current parameter values' in line:
                i = i + 1
                # check if all parameter values are in the file
                if i + num_parameters + 1 < len(lines):
                    for pname in self._parameter_names:
                        # The initial value is the last token on the parameter line.
                        self._plot_data[pname].append((0, float(lines[i].split()[-1])))
                        i = i + 1
                    break  # The initial conditions have been read completely.
            i = i + 1

    def _read_iteration_values(self, lines):
        """Read the iteration values from the rec file.

        Args:
            lines (:obj:`list[str]`): lines from the file
        """
        num_parameters = len(self._parameter_names)
        i = 0
        iteration = -1  # -1 means we are not inside an optimisation iteration section
        while i < len(lines):
            line = lines[i]
            i = i + 1
            if line.startswith('OPTIMISATION ITERATION NO.'):
                iteration = int(line.split()[-1])
            elif iteration != -1:
                if 'Starting phi for this iteration:' in line:
                    self._plot_data['pest_error_plot'].append((iteration, float(line.split()[-1])))
                elif 'Current parameter values' in line:
                    # check if all parameter values are in the file
                    if i + num_parameters + 1 < len(lines):
                        for pname in self._parameter_names:
                            line = lines[i].strip()
                            self.logger.info(f'line: {line}')
                            # The second token is read as the value for this iteration.
                            self._plot_data[pname].append((iteration, float(line.split()[1])))
                            i = i + 1
                    # Ignore everything else until the next iteration header.
                    iteration = -1