"""ListPackageWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import contextlib
import csv
import os
import sqlite3

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs

# 4. Local modules
from xms.mf6.file_io import database, io_util
from xms.mf6.file_io.package_writer_base import PackageWriterBase


class ListPackageWriter(PackageWriterBase):
    """Writes a list package file.

    The file consists of the comments and options blocks, a DIMENSIONS block (MAXBOUND, plus NSEG
    for EVT6), and one PERIOD block per stress period. Stress period data is either written inline
    or referenced with OPEN/CLOSE external files, and can optionally be regenerated from the
    periods.db SQLite database.
    """

    COLUMN_0 = 0  # Avoids magic numbers

    def __init__(self):
        """Initializes the class."""
        super().__init__()
        self._modified_periods = []  # Stress periods that have been modified
        self._modified_period_files = []  # Files that have been modified (temp files)

    def _write_dimensions(self, fp):
        """Writes the dimensions block.

        Args:
            fp (_io.TextIOWrapper): The file.
        """
        self._update_maxbound()
        fp.write('\n')
        fp.write('BEGIN DIMENSIONS\n')
        # MODFLOW complains if MAXBOUND is less than 1, even if there are no boundary conditions
        maxbound = max(1, self._data.maxbound)
        fp.write(f'{io_util.mftab}MAXBOUND {maxbound}\n')
        if self._data.ftype == 'EVT6':
            fp.write(f'{io_util.mftab}NSEG {self._data.nseg}\n')
        fp.write('END DIMENSIONS\n')

    @staticmethod
    def _same_path(path_a, path_b):
        """Returns True if the two paths are identical once made absolute and case-normalized.

        Args:
            path_a (str): First path.
            path_b (str): Second path.

        Returns:
            (bool): True if both paths normalize to the same string.
        """
        return os.path.normcase(os.path.abspath(path_a)) == os.path.normcase(os.path.abspath(path_b))

    def _periods_db_exists(self):
        """Returns True if self._data.periods_db is set and names an existing file.

        Returns:
            (bool): True if the periods database can be used.
        """
        return bool(self._data.periods_db) and os.path.isfile(self._data.periods_db)

    def _copy_temp_file(self, stress_period):
        """Copies the file of a stress period to its final location and updates period_files.

        Temporary files (per io_util.is_temporary_file) are given the package's external file name
        and the temporary source is removed after copying. Other files keep their basename. Does
        nothing if the stress period has no file.

        Args:
            stress_period (int): Stress period number (1-based).
        """
        cur_period_file = self._data.period_files[stress_period]
        if not cur_period_file:
            return

        # Get the new directory. With DMI, external files are always stored next to the package;
        # otherwise an explicit OPEN/CLOSE directory wins over the package directory.
        new_period_file_dir = os.path.dirname(self._data.filename)
        if self._writer_options.open_close_dir and not self._writer_options.dmi_sim_dir:
            new_period_file_dir = self._writer_options.open_close_dir

        # Get the new basename
        is_temp = io_util.is_temporary_file(cur_period_file)
        if is_temp:
            new_name = os.path.basename(self._compute_external_file_name(stress_period))
        else:
            new_name = os.path.basename(cur_period_file)

        # Store the new path
        new_period_file = os.path.join(new_period_file_dir, new_name)
        self._data.period_files[stress_period] = new_period_file

        # Copy the file only if the new and current paths are different. Copying a file onto
        # itself is pointless (and may fail), and removing the "temp" source afterwards would
        # delete the only copy.
        if self._same_path(cur_period_file, new_period_file):
            return
        fs.copyfile(cur_period_file, new_period_file)
        if is_temp:
            fs.removefile(cur_period_file)

    def _get_modified_files(self):
        """Get the list of modified files.

        When the periods database exists, every stress period is treated as modified; otherwise
        only stress periods whose file is a temporary file are.

        Returns:
            (tuple): tuple containing:
                - modified_stress_periods (list of int): Stress periods that have been modified.
                - modified_period_files (list of str): Names of files that have been modified.
        """
        # TODO: Keep track of which periods have been modified in the database so we can do less work
        using_db = self._periods_db_exists()

        modified_stress_periods = []
        modified_period_files = []
        for sp, period_file in self._data.period_files.items():
            if using_db or (period_file and io_util.is_temporary_file(period_file)):
                modified_period_files.append(period_file)
                modified_stress_periods.append(sp)
        return modified_stress_periods, modified_period_files

    def max_db_period_rows(self, modified_periods):
        """Returns the largest number of rows stored for any single stress period in periods.db.

        Args:
            modified_periods (list of int): Unused; every period in the database is counted.

        Returns:
            (int): Maximum row count over all periods, or 0 if the data table has no rows.
        """
        stmt = 'SELECT COUNT(id) FROM data GROUP BY PERIOD'
        # The sqlite3 connection's own context manager only commits/rolls back; closing()
        # actually releases the connection and its handle on the database file.
        with contextlib.closing(sqlite3.connect(self._data.periods_db)) as conn:
            rows = conn.execute(stmt).fetchall()
        # fetchall returns a list of tuples, but we only want the count column
        counts = [row[self.COLUMN_0] for row in rows]
        return max(counts, default=0)

    def _update_maxbound(self):
        """Updates self._data.maxbound if necessary.

        Returns:
            (bool): True if self._data.maxbound was updated, otherwise False.
        """
        # MAXBOUND only needs updating if there are modified stress periods, and it can only be
        # recomputed when the periods database is available.
        if self._modified_period_files and self._periods_db_exists():
            self._data.maxbound = self.max_db_period_rows(self._modified_periods)
            return True
        return False

    def _write_period_files_from_db(self):
        """Writes the external period files using the periods.db.

        Each stress period's rows (original columns only, ordered by PERIOD_ROW) are written,
        space-delimited, to a new temporary file, which replaces that period's entry in
        self._data.period_files. Does nothing if the periods database is not set or missing.
        """
        _, column_names, orig_column_count = database.get_create_strings(self._data)
        column_str = ', '.join(column_names[:orig_column_count])
        if not self._periods_db_exists():
            return

        # The period number is bound as a parameter rather than formatted into the SQL text.
        stmt = f'SELECT {column_str} FROM data WHERE PERIOD = ? ORDER BY PERIOD_ROW'
        with contextlib.closing(sqlite3.connect(self._data.periods_db)) as conn:
            cur = conn.cursor()
            # Only values are replaced (no keys added/removed); iterate over a snapshot anyway.
            for period in list(self._data.period_files):
                temp_filename = io_util.get_temp_filename(suffix='.mf6_tmp')
                self._data.period_files[period] = temp_filename
                with open(temp_filename, 'w', newline='') as file:
                    cur.execute(stmt, (period,))
                    csv.writer(file, delimiter=' ').writerows(cur.fetchall())

    def _write_stress_periods(self, fp):
        """Writes all the stress period blocks.

        If fp is falsy, nothing is written to the package file, but external files are still
        regenerated/copied and period_files is still updated.

        Args:
            fp (_io.TextIOWrapper): The file.
        """
        if self._writer_options.use_periods_db:
            self._write_period_files_from_db()

        for sp in sorted(self._data.period_files.keys()):
            if fp:
                fp.write('\n')
                fp.write('BEGIN PERIOD {}\n'.format(sp))
            if self._writer_options.use_open_close:
                # _copy_temp_file may change period_files[sp], so read it afterwards
                self._copy_temp_file(sp)
                external_file_name = self._data.period_files[sp]
                if external_file_name:
                    relative_filename = fs.compute_relative_path(self._writer_options.mfsim_dir, external_file_name)
                    if fp:
                        fp.write(f'{io_util.mftab}OPEN/CLOSE {io_util.quote(relative_filename)}\n')
            else:
                if self._data.period_files[sp] and fp:
                    io_util.write_file_internal(fp, self._data.period_files[sp])
            if fp:
                fp.write('END PERIOD\n')

        # Build the periods.db database
        if self._writer_options.dmi_sim_dir:
            database.build(self._data)

    def _write_package(self, data):
        """Writes the package file.

        Saves stress period data in external files using the OPEN/CLOSE option.

        Args:
            data: The package data to write; stored as self._data.
        """
        self._data = data
        self._modified_periods, self._modified_period_files = self._get_modified_files()
        with open(self._data.filename, 'w') as fp:
            self._write_comments(fp)
            self._write_options(fp)
            self._write_dimensions(fp)
            self._write_stress_periods(fp)
