"""HfbData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import sqlite3

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.components.display import display_options_io
from xms.constraint import Grid

# 4. Local modules
from xms.mf6.data import data_util
from xms.mf6.data.list_package_data import ListPackageData
from xms.mf6.data.options_block import OptionsBlock
from xms.mf6.file_io import database
from xms.mf6.gui import gui_util, units_util
from xms.mf6.gui.options_defs import Checkbox


class HfbData(ListPackageData):
    """Data class to hold the info from a hfb package file."""

    def __init__(self, **kwargs):
        """Initializes the class.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF/GWT model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)
        # Overwrites whatever the base class set: this class always represents an HFB6 package.
        self.ftype = 'HFB6'

    def get_column_tool_tips(self, block: str) -> dict[int, str]:
        """Returns a dict with column index and tool tip.

        Args:
            block (str): Name of the block.

        Returns:
            (dict[int, str]): Column index -> tool tip text. Only the HYDCHR column has a tool tip.
        """
        names, _types, _defaults = self.get_column_info(block)
        units_str = units_util.string_from_units(self, '[1/T or (L/T)/L]')
        return {
            names.index('HYDCHR'):
                f'Hydraulic characteristic of the horizontal-flow barrier. The hydraulic'
                f' characteristic is the barrier hydraulic conductivity divided by the width'
                f' of the horizontal-flow barrier {units_str}.',
        }

    def package_column_info(self, block=''):
        """Returns the column info just for the columns unique to this package.

        This override returns empty results; HFB builds its columns in get_column_info() from
        package_column_dict() instead.

        Args:
            block (str): Name of the list block (unused).

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names (empty).
                - types (dict of str -> type): Column names -> column types (empty).
                - default (dict of str -> value): Column names -> default values (empty).
        """
        return [], {}, {}

    def package_column_dict(self):
        """Returns the column info just for the columns unique to this package.

        Returns:
            (dict): Column name -> (column type, default value). For HFB this is just HYDCHR.
        """
        return {
            'HYDCHR': (np.float64, 0.0),
        }

    # @overrides
    def get_column_info(self, block, use_aux=True):
        """Returns column names, types, and defaults.

        The columns are two sets of cell id columns (which depend on the DIS package in use), suffixed '1' and '2',
        followed by the package specific columns.

        Args:
            block (str): Name of the list block (unused).
            use_aux (bool): True to include AUXILIARY variables (unused; see comment below).

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        id_columns = data_util.get_id_column_dict(self.grid_info())
        # A barrier sits between two cells, so each row needs two cell ids: all the '1' columns first, then all
        # the '2' columns (outer loop is the suffix).
        new_id_columns = {
            f'{key}{suffix}': value
            for suffix in ('1', '2')
            for key, value in id_columns.items()
        }

        # Add the package specific columns
        columns = {**new_id_columns, **self.package_column_dict()}

        # No AUXILIARY or BOUNDNAME columns are added: the HFB options block (see _setup_options) defines no
        # AUXILIARY or BOUNDNAMES options.
        return gui_util.column_info_tuple_from_dict(columns)

    def dialog_title(self):
        """Returns the title to show in the dialog.

        Returns:
            (str): The dialog title.
        """
        return 'Horizontal Flow Barrier (HFB) Package'

    def _read_cell_idx_pairs(self) -> list[tuple]:
        """Returns the distinct (CELLIDX1, CELLIDX2) pairs stored in this package's database.

        Returns:
            (list[tuple]): One (CELLIDX1, CELLIDX2) tuple per distinct pair, in the order SQLite returns them.
        """
        db_filename = database.database_filepath(self.filename)
        # sqlite3's connection context manager only commits/rolls back the transaction; it does NOT close the
        # connection. Close it explicitly so the database file handle isn't held until garbage collection.
        conn = sqlite3.connect(db_filename)
        try:
            rows = conn.execute('SELECT DISTINCT CELLIDX1, CELLIDX2 FROM data').fetchall()
        finally:
            conn.close()
        return [(row[0], row[1]) for row in rows]

    def _update_displayed_faces_cell_indices(self):
        """Updates the cell indices file used to display symbols.

        Writes both cell indices of every distinct (CELLIDX1, CELLIDX2) pair as one flat list to
        'hfb6_faces.display_indices'.
        """
        cell_idx_pairs = set(self._read_cell_idx_pairs())
        # Flatten the pairs into [a1, b1, a2, b2, ...]
        cell_idxs = [cell_idx for pair in cell_idx_pairs for cell_idx in pair]
        self._write_displayed_cell_indices(cell_idxs, 'hfb6_faces.display_indices')

    def _update_displayed_lines_cell_indices(self, cogrid: Grid):
        """Updates the line display file with the face points between each barrier's two cells.

        For each distinct (CELLIDX1, CELLIDX2) pair, every face of CELLIDX1 whose adjacent cell is CELLIDX2
        contributes its face points. The result is written to 'hfb6_lines.display_indices' next to the package file.

        Args:
            cogrid (Grid): The grid whose ugrid supplies the cell face topology.
        """
        ugrid = cogrid.ugrid
        lines = set()  # Identical face point sequences are written only once
        for index1, index2 in self._read_cell_idx_pairs():
            for face in range(ugrid.get_cell_3d_face_count(index1)):
                if ugrid.get_cell_3d_face_adjacent_cell(index1, face) == index2:
                    lines.add(ugrid.get_cell_3d_face_points(index1, face))

        filename = os.path.join(os.path.dirname(self.filename), 'hfb6_lines.display_indices')
        display_options_io.write_display_option_line_ids(filename, list(lines))

    # @overrides
    def update_displayed_cell_indices(self) -> None:
        """Updates the files used to display symbols.

        The faces file is always written; the lines file is written only when the model has a cogrid.
        """
        self._update_displayed_faces_cell_indices()
        cogrid = self.model.get_cogrid()
        if cogrid is not None:
            self._update_displayed_lines_cell_indices(cogrid)

    def stress_id_columns(self):
        """Returns the column names where the id exists that can be used to help identify this stress across periods.

        Typically is 'CELLIDX' which is added by GMS but is 'RNO' for SFR. For HFB a barrier is identified by both
        of its cells.

        Returns:
            (list[str]): ['CELLIDX1', 'CELLIDX2'].
        """
        return ['CELLIDX1', 'CELLIDX2']

    # @overrides
    def _setup_options(self):
        """Returns the definition of all the available options.

        Returns:
            (OptionsBlock): See description.
        """
        return OptionsBlock([
            Checkbox('PRINT_INPUT', brief='Print input to listing file'),
        ])

    def map_info(self, feature_type):
        """Returns info needed for Map from Coverage.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'

        Returns:
            (dict): Dict describing how to get the MODFLOW variable from the shapefile or att table fields. Only arcs
            map to HFB data; other feature types return an empty dict.
        """
        if feature_type == 'arcs':
            return {'Name': None, 'HYDCHR': None}
        return {}

    def map_import_info(self, feature_type):
        """Returns info needed for mapping shapefile or transient data file to package data.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'

        Returns:
            (dict): See description
        """
        info = {}
        if feature_type == 'arcs':
            info = {'HYDCHR': None}
        self._add_aux_and_boundname_info(info)
        return info
