"""ListPackageData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path
import sqlite3
from typing import Iterable

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
from xms.components.display import display_options_io

# 4. Local modules
from xms.mf6.data import data_util
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.file_io import database
from xms.mf6.gui import gui_util


class ListPackageData(BaseFileData):
    """Data class to hold the info from a 'list' package (CHD, DRN, EVT, GHB, HFB, RCH, RIV, WEL) file."""
    def __init__(self, **kwargs):
        """Creates an empty list package data object.

        Args:
            **kwargs: Arbitrary keyword arguments, passed on to the base class.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF/GWT model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)

        # Max number of items in a stress period
        self.maxbound: int = 0
        # 1-based stress period -> path of that period's external file
        self.period_files: dict = {}
        # List block name -> filename in the package file
        self.list_blocks: dict = {}
        # Name of the block holding the cell ids
        self.block_with_cellids: str = 'PERIODS'

    # @overrides
    def get_column_delegate_info(self, block):
        """Returns a list of tuples of [0] column index and [1] list of strings."""
        return None

    def get_id_column_info(self):
        """Returns the column info for the cell id columns (lay, row, col etc), which depend on the dis package.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        # Build the id column dict from the grid, then convert it to the (names, types, defaults) tuple
        return gui_util.column_info_tuple_from_dict(data_util.get_id_column_dict(self.grid_info()))

    def add_aux_columns_to_dict(self, columns, use_aux=True):
        """If the AUXILIARY option is on, adds the auxiliary column to the dict.

        Args:
            columns (dict of str -> type): Column names -> column types.
            use_aux (bool): True to include aux.
        """
        if use_aux:
            auxiliary_list = self.options_block.get('AUXILIARY', [])
            if auxiliary_list:
                for aux in auxiliary_list:
                    columns[aux] = (object, 0.0)

    def add_aux_columns(self, names, types, defaults):
        """If the AUXILIARY option is on, adds the auxiliary columns.

        Args:
            names (list): Column names.
            types (dict of str -> type): Column names -> column types.
            defaults (dict of str -> value): Column names -> default values.
        """
        auxiliary_list = self.options_block.get('AUXILIARY', [])
        if auxiliary_list:
            for aux in auxiliary_list:
                names.append(aux)
                types[aux] = object
                defaults[aux] = 0.0

    def add_boundname_columns(self, names, types, defaults):
        """If the BOUNDNAMES option is on, adds the BOUNDNAME column.

        Args:
            names (list): Column names.
            types (dict of str -> type): Column names -> column types.
            defaults (dict of str -> value): Column names -> default values.
        """
        # Add boundname
        if self.options_block.has('BOUNDNAMES'):
            names.append('BOUNDNAME')
            types['BOUNDNAME'] = object
            defaults['BOUNDNAME'] = ''

    def _add_aux_and_boundname_info(self, info):
        # Add aux and BOUNDNAME columns
        auxiliary_list = self.options_block.get('AUXILIARY', [])
        info.update({aux: None for aux in auxiliary_list})
        if self.options_block.has('BOUNDNAMES'):
            info['Name'] = None

    # @overrides
    def get_column_info(self, block, use_aux=True):
        """Returns column names, types, and defaults.

        The columns depend on the DIS package in use and the AUX variables.
        The package specific and AUX columns are type object because they
        might contain time series strings.

        Args:
            block (str): Name of the list block.
            use_aux (bool): True to include AUXILIARY variables.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        names, types, default = self.get_id_column_info()

        # Add the package specific columns
        package_names, package_types, package_default = self.package_column_info()
        names = names + package_names
        types = {**types, **package_types}
        defaults = {**default, **package_default}

        if use_aux:
            self.add_aux_columns(names, types, defaults)
        self.add_boundname_columns(names, types, defaults)

        return names, types, defaults

    def get_column_tool_tips(self, block: str) -> dict[int, str]:
        """Returns a dict with column index and tool tip.

        Args:
            block (str): Name of the block.
        """
        return {}

    def package_column_info(self, block=''):
        """Returns the column info for just the columns that are unique to this package.

        This class does not implement it. You should override this method.

        Args:
            block (str): Name of the block.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.

        Raises:
            NotImplementedError: Always, in this class.
        """
        raise NotImplementedError

    def dialog_title(self):
        """Returns the title to show in the dialog.

        This class does not implement it. You should override this method.

        Returns:
            (str): The dialog title.

        Raises:
            NotImplementedError: Always, in this class.
        """
        raise NotImplementedError

    def get_time_series_columns(self) -> list[int]:
        """Returns a list of the column indices that can contain time series.

        Returns:
            List of indices of columns that can contain time series.
        """
        id_columns, _, _ = self.get_id_column_info()
        package_columns, _, _ = self.package_column_info()
        start = len(id_columns)
        auxiliary_list = self.options_block.get('AUXILIARY', [])
        naux = len(auxiliary_list) if auxiliary_list else 0
        end = len(id_columns) + len(package_columns) + naux
        time_series_columns = list(range(start, end))
        return time_series_columns

    def read_file_into_dataframe(self, filename, block):
        """Reads the file into a dataframe and returns the dataframe.

        Returns:
            (DataFrame): The dataframe.
        """
        column_names, column_types, _ = self.get_column_info(block)
        return self.read_csv_file_into_dataframe(block, filename, column_names, column_types)

    def block_with_aux(self):
        """Returns the name of the block that can have aux variables.

        Returns:
            (str): The name of the block that can have aux variables.
        """
        return ''

    def block_with_boundnames(self):
        """Returns the name of the block that can have boundnames.

        Returns:
            (str): The name of the block that can have boundnames.
        """
        return ''

    def _get_displayed_cell_indices_for_block(self, filename, block):
        """Returns the cell indices found in the filename / block.

        Args:
            filename (str): Filename of a csv file.
            block (str): Name of block.
        """
        cell_idxs = set()
        grid_info = self.grid_info()
        if grid_info:
            df = self.read_file_into_dataframe(filename, block)
            cell_idxs = grid_info.cell_indexes_from_dataframe(df)
        return cell_idxs

    def _write_displayed_cell_indices(self, cell_idxs: Iterable[int], filename: str) -> None:
        """Writes the cell indices to a file located next to this package's file.

        Args:
            cell_idxs: List or set of cell indices (0-based) where symbols should be displayed.
            filename: Name of the file to write. It is joined to the directory of self.filename.
        """
        package_dir = os.path.dirname(self.filename)
        filepath = os.path.join(package_dir, filename)

        # Negative cell indices may exist if the user specified a lay, row, col outside grid bounds. Skip them.
        valid_idxs = [idx for idx in cell_idxs if idx >= 0]
        display_options_io.write_display_option_ids(filepath, valid_idxs)

    def _update_displayed_cell_indices_in_block(self, block):
        """Updates the cell indices file used to display symbols.

        Args:
            block (str): Name of block containing cell ids.
        """
        if block in self.list_blocks:
            cell_idxs = self._get_displayed_cell_indices_for_block(self.list_blocks[block], block)
            self._write_displayed_cell_indices(cell_idxs, 'cell.display_indices')

    # @overrides
    def update_displayed_cell_indices(self):
        """Updates the cell indices file used to display symbols.

        Does nothing unless self.ftype is one of the displayable ftypes. Otherwise the distinct CELLIDX values
        of the 'data' table in the package database are written to 'cell.display_indices'.
        """
        if self.ftype not in BaseFileData.displayable_ftypes():
            return

        db_filename = database.database_filepath(self.filename)
        # Using a sqlite3 connection as a context manager only commits or rolls back, it does not close the
        # connection. Close it explicitly so the database file is released right away.
        conn = sqlite3.connect(db_filename)
        try:
            cell_idxs = {row[0] for row in conn.execute('SELECT DISTINCT CELLIDX FROM data')}
        finally:
            conn.close()

        self._write_displayed_cell_indices(cell_idxs, 'cell.display_indices')

    def get_period_df(self, sp: int, periods_db: str | Path = None) -> pd.DataFrame:
        """Returns the period data at the specified stress period as a dataframe.

        Args:
            sp: The 1-based stress period. If -1, all stress periods are returned.
            periods_db: File path of periods database file. If None, self.periods_db is used.

        Returns:
            See description.
        """
        df = None
        db_filename = periods_db if periods_db is not None else self.periods_db
        if db_filename and Path(db_filename).exists() and Path(db_filename).stat().st_size != 0:
            with sqlite3.connect(db_filename) as conn:
                if sp == -1:
                    df = pd.read_sql_query('SELECT * FROM data', conn)
                else:
                    df = pd.read_sql_query('SELECT * FROM data WHERE PERIOD = ?', conn, params=[sp])
        elif sp in self.period_files:
            df = self.read_file_into_dataframe(filename=self.period_files[sp], block='')
        return df

    def copy_period(self, from_sp: int, to_sp: int, cxn: sqlite3.Connection | None) -> None:
        """Copies a stress period from from_sp to to_sp (1-based).

        Args:
            from_sp: The source stress period (1-based).
            to_sp: The destination period (1-based).
            cxn: If provided, an existing connection, which is left open and is not committed here. Otherwise,
             a local one is opened on the package's database, committed, and closed. If the package has no
             database file path, nothing is done.
        """
        if cxn:
            _copy_period_sql(from_sp, to_sp, cxn)
        else:
            db_filename = database.filepath_from_package(self)
            if not db_filename:
                return
            local_cxn = sqlite3.connect(db_filename)
            try:
                # The context manager commits on success and rolls back on error, but it does not close the
                # connection, so that is done explicitly in the finally clause.
                with local_cxn:
                    _copy_period_sql(from_sp, to_sp, local_cxn)
            finally:
                local_cxn.close()
        # Record the destination period as defined, with no external file path
        self.period_files[to_sp] = ''

    def fill_missing_periods(self, sp_count: int, cxn: sqlite3.Connection | None) -> None:
        """Fills in missing stress periods (where an undefined period means use the previously defined one).

        Args:
            sp_count: Number of stress periods.
            cxn: If provided, an existing connection. Otherwise, a local one is used.
        """
        prev_period = None  # 1-based stress period number
        for sp_idx in range(sp_count):
            sp = sp_idx + 1
            if sp in self.period_files:
                if self.period_files[sp]:  # Stress period is defined and not empty
                    prev_period = sp
                else:  # Stress period is defined but empty. Reset prev_period
                    prev_period = None
            elif prev_period is not None:
                # Fill in undefined periods by copying from the last defined period
                for i in range(prev_period + 1, sp + 1):
                    self.copy_period(prev_period, i, cxn)
                prev_period = sp

    def stress_id_columns(self):
        """Returns the column name where the id exists that can be used to help identify this stress across periods.

        Typically is 'CELLIDX' which is added by GMS but is 'RNO' for SFR.

        Returns:
            See description.
        """
        return ['CELLIDX']

    def plottable_columns(self):
        """Returns a set of columns (0-based) that can be plotted with the XySeriesEditor.

        Returns:
            See description.
        """
        column_count = len(self.get_column_info('')[0])
        id_column_count = len(self.get_id_column_info()[0])
        return set(range(id_column_count, column_count))

    def map_info(self, feature_type):
        """Returns info needed for Map from Coverage.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'

        Returns:
            (dict): Dict describing how to get the MODFLOW variable from the shapefile or att table fields.
        """
        return {}

    def append_period_data(self, period, rows):
        """Appends the data to the existing period info.

        The work is done by database.append_period_data().

        Args:
            period (int): Stress period.
            rows (list): List of lists consisting of rows of data.
        """
        database.append_period_data(self, period, rows)

    def bcs_with_cellgrp_exist(self, cellgrp):
        """Returns True if there are boundary conditions with the given cellgrp.

        Args:
            cellgrp (int): CELLGRP.

        Returns:
            (bool): See description.
        """
        aux = self.options_block.get('AUXILIARY', [])
        if not aux or 'CELLGRP' not in aux:
            return False

        db_filename = database.database_filepath(self.filename)
        with sqlite3.connect(db_filename) as conn:
            cur = conn.cursor()
            stmt = 'SELECT COUNT(*) FROM data WHERE CELLGRP = ?'
            cur.execute(stmt, (cellgrp, ))
            rv = cur.fetchone()
            return rv[0] > 0


def _db_column_names(cxn: sqlite3.Connection) -> list[str]:
    """Returns a list of column names for the 'data' table.

    Args:
        cxn: An sqlite3 connection.

    Returns:
        See description.
    """
    # Only the column metadata is needed, so LIMIT 0 keeps the query from producing any rows
    cursor = cxn.execute('SELECT * FROM data LIMIT 0')
    try:
        return [description[0] for description in cursor.description]
    finally:
        cursor.close()


def _quote_identifier(name: str) -> str:
    """Returns the name as a double-quoted SQL identifier, with embedded double quotes escaped."""
    return '"' + name.replace('"', '""') + '"'


def _copy_period_sql(from_sp: int, to_sp: int, cxn: sqlite3.Connection):
    """SQL query to copy a stress period.

    Every row of from_sp in the 'data' table is inserted again, in PERIOD_ROW order, with PERIOD set to to_sp.
    Nothing is committed here.

    Args:
        cxn: An existing connection.
        from_sp: The source stress period (1-based).
        to_sp: The destination period (1-based).

    Raises:
        ValueError: If the 'data' table has no PERIOD column.
    """
    column_names = _db_column_names(cxn)
    column_names.pop()  # Remove the id column (assumed to be the last column) so the new rows get new ids
    period_idx = column_names.index('PERIOD')

    # Quote the column names so that names which are not plain identifiers (spaces, keywords) still work
    quoted_names = [_quote_identifier(name) for name in column_names]
    column_str = ', '.join(quoted_names)

    # The destination period is bound as a parameter. Splicing it in as "N" makes it a double-quoted
    # identifier, which SQLite only treats as a string literal through a legacy fallback.
    quoted_names[period_idx] = '?'
    new_column_str = ', '.join(quoted_names)
    stmt = (
        f'INSERT INTO data ({column_str}) SELECT {new_column_str} FROM data WHERE PERIOD = ?'
        f' ORDER BY PERIOD_ROW'
    )
    cxn.execute(stmt, (to_sp, from_sp))
