"""SfrData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path

# 2. Third party modules
import numpy as np
import pandas as pd
from typing_extensions import override

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs

# 4. Local modules
from xms.mf6.data import data_util
from xms.mf6.data.list_package_data import ListPackageData
from xms.mf6.data.options_block import OptionsBlock
from xms.mf6.file_io import io_util
from xms.mf6.gui import gui_util, units_util
from xms.mf6.gui.options_defs import Checkbox, CheckboxButton, CheckboxField


def _file_has_a_comma(filepath: str | Path) -> bool:
    """Returns true if the file includes a ','.

    Args:
        filepath (str|Path): The filepath.

    Returns:
        (bool): See description.
    """
    if not filepath:
        return False
    path = Path(filepath)
    with path.open('r') as file:
        if ',' in file.read():
            return True
    return False


class SfrData(ListPackageData):
    """Data class to hold the info from an SFR6 package file."""
    def __init__(self, **kwargs):
        """Initializes the class.

        All keyword arguments are forwarded unchanged to the base class; the SFR specific attributes are set
        afterwards.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF/GWT model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)
        self.ftype = 'SFR6'
        self.block_with_cellids = 'PACKAGEDATA'
        # Maps each list block name to the file holding its data. No files are assigned yet.
        list_block_names = ('PACKAGEDATA', 'CONNECTIONDATA', 'DIVERSIONS')
        self.list_blocks = dict.fromkeys(list_block_names, '')

    # @overrides
    def get_column_delegate_info(self, block):
        """Returns the columns of a block that are restricted to a fixed list of strings.

        Args:
            block (str): Name of the block. An empty string or 'PERIODS' means the stress period data.

        Returns:
            (list of tuple | None): One tuple per restricted column, or None if the block has no such column.
            Each tuple contains:
                - index (int): Column index.
                - strings (list of str): List of strings allowed in the column.
        """
        column_names, _types, _defaults = self.get_column_info(block)

        if block == 'DIVERSIONS':
            priorities = ['FRACTION', 'EXCESS', 'THRESHOLD', 'UPTO']
            return [(column_names.index('CPRIOR'), priorities)]

        if not block or block == 'PERIODS':
            settings = [
                'STATUS', 'MANNING', 'STAGE', 'INFLOW', 'RAINFALL', 'EVAPORATION', 'RUNOFF', 'DIVERSION',
                'UPSTREAM_FRACTION', 'AUXILIARY'
            ]
            return [(column_names.index('SFRSETTING'), settings)]

        return None

    def get_column_max(self, column_name, cellid_column_count):
        """Returns the maximum integer value found in a column of the PACKAGEDATA list file.

        The file is parsed line by line instead of being read into a dataframe because there may be a 'NONE' in
        the cellid column. Lines are split on ',' so the column index from get_column_info() (where CELLID is a
        single field) lines up with the fields in the file.

        Args:
            column_name (str): Name of the column to get the maximum value from e.g. 'NCON'
            cellid_column_count (int): Number of columns used for the cell_id. Currently unused; kept for callers.

        Returns:
            (int): The maximum value in the column, or 0 if there is no PACKAGEDATA file or it has no data lines.

        Raises:
            ValueError: If column_name is not a PACKAGEDATA column or a value in the column is not an integer.
            IndexError: If a non-blank line has fewer comma separated fields than expected.
        """
        max_value = 0
        filename = self.list_blocks.get('PACKAGEDATA')
        if not filename or not os.path.isfile(filename):
            return max_value

        column_names, _, _ = self.get_column_info('PACKAGEDATA')
        column_index = column_names.index(column_name)
        # TODO: files that are not comma delimited (CELLID spread over several fields) are not supported yet.
        with open(filename, 'r') as file:
            for line in file:
                if not line.strip():
                    continue  # Blank lines (e.g. at the end of the file) carry no data
                words = line.split(',')
                max_value = max(max_value, int(words[column_index]))
        return max_value

    # @overrides
    def get_column_info(self, block, use_aux=True):
        """Returns column names, types, and defaults for a block.

        PACKAGEDATA also gets the AUX and boundname columns added by the helper functions. Columns that are
        typed as object hold free-form text (e.g. CELLID, which is a single field because it can be NONE).
        Any block other than PACKAGEDATA, CONNECTIONDATA and DIVERSIONS is treated as stress period data.

        Args:
            block (str): Name of the list block (case-insensitive).
            use_aux (bool): True to include AUXILIARY variables.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        block_name = block.upper()
        if block_name not in ('PACKAGEDATA', 'CONNECTIONDATA', 'DIVERSIONS'):
            return self.package_column_info()  # This would be the stress periods

        # Each entry is column name -> (type, default value)
        if block_name == 'PACKAGEDATA':
            column_defs = {
                'RNO': (np.int32, 1),
                'CELLID': (object, ''),  # CELLID is 1 field here because it can be NONE
                'RLEN': (np.float64, 0.0),
                'RWID': (np.float64, 0.0),
                'RGRD': (np.float64, 0.0),
                'RTP': (np.float64, 0.0),
                'RBTH': (np.float64, 0.0),
                'RHK': (np.float64, 0.0),
                'MAN': (object, '.03'),
                'NCON': (np.int32, 1),
                'USTRF': (object, '0.0'),
                'NDV': (np.int32, 1),
            }
            self.add_aux_columns_to_dict(column_defs, use_aux=use_aux)
            data_util.add_boundname_columns_to_dict(self.options_block, column_defs)
        elif block_name == 'CONNECTIONDATA':
            column_defs = {
                'RNO': (np.int32, 1),
                'IC': (object, ''),
            }
        else:  # DIVERSIONS
            column_defs = {
                'RNO': (np.int32, 1),
                'IDV': (np.int32, 1),
                'ICONR': (np.int32, 1),
                'CPRIOR': (object, ''),
            }

        column_names, column_types, column_defaults = gui_util.column_info_tuple_from_dict(column_defs)
        return column_names, column_types, column_defaults

    def get_column_tool_tips(self, block: str) -> dict[int, str]:
        """Returns the tool tips for the columns of a block.

        Args:
            block (str): Name of the block (case-insensitive). Anything other than PACKAGEDATA, CROSSSECTIONS,
             CONNECTIONDATA and DIVERSIONS is treated as stress period data.

        Returns:
            (dict[int, str]): Column index -> tool tip.
        """
        column_names, _types, _defaults = self.get_column_info(block)
        column = column_names.index  # Looks up a column index by name
        block_name = block.upper()

        if block_name == 'PACKAGEDATA':
            length_units = units_util.string_from_units(self, units_util.UNITS_LENGTH)
            cond_units = units_util.string_from_units(self, units_util.UNITS_COND)
            cellid_parts = (
                'Cell identifier. Depends on the type of grid that is used for the simulation. For a structured grid'
                ' that uses the DIS input file, CELLID is the layer, row, and column. For a grid that uses the DISV'
                ' input file, CELLID is the layer and CELL2D number. If the model uses the unstructured discretization'
                ' (DISU) input file, CELLID is the node number for the cell.',
                'For reaches that are not connected to an underlying GWF cell, a zero should be specified for each grid'
                ' dimension. For example, for a DIS grid a CELLID of 0 0 0 should be specified.',
            )
            ustrf_tip = (
                'Fraction of upstream flow from each upstream reach that is applied as'
                ' upstream inflow to the reach'
            )
            return {
                column('RNO'): 'Reach number',
                column('CELLID'): ' '.join(cellid_parts),
                column('RLEN'): f'Reach length {length_units}',
                column('RWID'): f'Reach width {length_units}',
                column('RGRD'): 'Stream gradient (slope) across the reach',
                column('RTP'): f'Bottom elevation of the reach {length_units}',
                column('RBTH'): f'Thickness of the reach streambed {length_units}',
                column('RHK'): f'Hydraulic conductivity of the reach streambed {cond_units}',
                column('MAN'): 'Manning’s roughness coefficient for the reach',
                column('NCON'): 'Number of reaches connected to the reach',
                column('USTRF'): ustrf_tip,
                column('NDV'): 'Number of downstream diversions for the reach',
            }

        if block_name == 'CROSSSECTIONS':
            return {column('RNO'): 'Reach number'}

        if block_name == 'CONNECTIONDATA':
            ic_tip = (
                'Reach number of the reach connected to the current reach and whether it is'
                ' connected to the upstream or downstream end of the reach'
            )
            return {
                column('RNO'): 'Reach number',
                column('IC'): ic_tip,
            }

        if block_name == 'DIVERSIONS':
            return {
                column('RNO'): 'Reach number',
                column('IDV'): 'Downstream diversion number for the diversion for reach RNO',
                column('ICONR'): 'Downstream reach that will receive the diverted water',
                column('CPRIOR'): 'Prioritization system for the diversion',
            }

        # Stress periods
        return {
            column('RNO'): 'Reach number',
            column('SFRSETTING'): 'Keyword to start SFR setting line',
        }

    def package_column_info(self, block=''):
        """Returns the column info for the stress period columns of this package.

        Args:
            block (str): Name of the block. Not used by this implementation.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        # Column name -> (type, default value)
        period_columns = {
            'RNO': (np.int32, 1),
            'SFRSETTING': (object, ''),
            'VALUE1': (object, ''),
            'VALUE2': (object, ''),
        }
        column_names, column_types, column_defaults = gui_util.column_info_tuple_from_dict(period_columns)
        return column_names, column_types, column_defaults

    def dialog_title(self) -> str:
        """Returns the title to show in the dialog.

        Returns:
            (str): The dialog title, which is a fixed string for the SFR package.
        """
        return 'Streamflow Routing (SFR) Package'

    @override
    def get_time_series_columns(self) -> list[int]:
        """Returns the indices of the stress period columns that can contain time series.

        Returns:
            List of indices of the 'VALUE1' and 'VALUE2' columns, in column order.
        """
        period_names, _types, _defaults = self.get_column_info('PERIODS')
        return [index for index, name in enumerate(period_names) if name in ('VALUE1', 'VALUE2')]

    def block_with_aux(self) -> str:
        """Returns the name of the block that can have aux variables.

        Returns:
            (str): The name of the block that can have aux variables, which is 'PACKAGEDATA' for SFR.
        """
        return 'PACKAGEDATA'

    def block_with_boundnames(self) -> str:
        """Returns the name of the block that can have boundnames.

        Returns:
            (str): The name of the block that can have boundnames, which is 'PACKAGEDATA' for SFR.
        """
        return 'PACKAGEDATA'

    # @overrides
    def update_displayed_cell_indices(self) -> None:
        """Updates the cell indices file used to display symbols.

        The work is delegated to _update_displayed_cell_indices_in_block() for the PACKAGEDATA block.
        """
        self._update_displayed_cell_indices_in_block('PACKAGEDATA')

    def stress_id_columns(self) -> list[str]:
        """Returns the column name where the id exists that can be used to help identify this stress across periods.

        Typically is 'CELLIDX' which is added by GMS but is 'RNO' for SFR.

        Returns:
            (list[str]): A single-item list containing 'RNO'.
        """
        return ['RNO']

    def plottable_columns(self):
        """Returns a set of columns (0-based) that can be plotted with the XySeriesEditor.

        Returns:
            (set of int): Indices of every stress period column that follows the SFRSETTING column.
        """
        period_names, _types, _defaults = self.get_column_info('')
        first_plottable = 2  # Start after the SFRSETTING column
        return set(range(first_plottable, len(period_names)))

    def _rewrite_packagedata(self, filename):
        """Rewrite the packagedata file to use commas instead of spaces to handle NONE in CELLID.

        The lines of the existing file are read, handed to an SfrReader that writes them to a new external file,
        and the new file is then copied over the original.

        Args:
            filename (str): Filepath of the PACKAGEDATA list file. The file is overwritten.
        """
        from xms.mf6.file_io import package_reader_base  # Do this here or you get circular imports
        list_lines = package_reader_base.read_external_list(filename)
        from xms.mf6.file_io.gwf.sfr_reader import SfrReader
        reader = SfrReader()
        # The reader is not used to read a package here, so the state it needs is assigned directly.
        # NOTE(review): presumably _list_to_external_file() uses _data and _curr_block_name to pick the columns
        # and delimiter - confirm against SfrReader.
        reader._data = self
        reader._curr_block_name = 'PACKAGEDATA'
        new_filename = reader._list_to_external_file(list_lines)
        fs.copyfile(new_filename, filename)  # Replace the original with the rewritten file

    # @overrides
    def read_csv_file_into_dataframe(self, block, filename, column_names, column_types):
        """Reads a list block file and returns a dataframe.

        Overridden to handle CONNECTIONDATA which is challenging. The IC numbers are turned into a space delimited list
        in one IC column. PACKAGEDATA is read with ',' as the separator and the file is first rewritten if it
        does not contain any commas.

        Args:
            block (str): Name of the block (case-insensitive).
            filename (str): Filepath.
            column_names (list): Column names.
            column_types (dict of str -> type): Column names -> column types.

        Returns:
            (pandas.DataFrame | None): The dataframe, or None if block is not PACKAGEDATA, DIVERSIONS or
            CONNECTIONDATA.
        """
        block_name = block.upper()

        if block_name == 'PACKAGEDATA':
            # Use ',' for separator because CELLID is one field because it can be NONE
            if filename and not _file_has_a_comma(filename):
                self._rewrite_packagedata(filename)
            usecols = list(range(len(column_names)))
            return gui_util.read_csv_file_into_dataframe(filename, column_names, column_types, usecols, separator=',')

        if block_name == 'DIVERSIONS':
            return gui_util.read_csv_file_into_dataframe(filename, column_names, column_types)

        if block_name != 'CONNECTIONDATA':
            return None

        # Get max ncon from PACKAGEDATA so we know how many IC numbers are on a line.
        cell_id_column_count = self.grid_info().cell_id_column_count()
        max_ncon = self.get_column_max('NCON', cell_id_column_count)

        if not filename or not os.path.isfile(filename):
            return gui_util.empty_dataframe(column_names, list(column_types.values()), index=None)

        reach_numbers = []
        connections = []
        with open(filename, 'r') as file:
            for line in file:
                tokens = line.split()
                if not tokens:
                    continue
                reach_numbers.append(tokens[0])
                # Ignore any word after max_ncon. May be a comment or who knows, but isn't an IC number.
                # The slice stops at the end of the line if there are fewer words than that.
                connections.append(' '.join(tokens[1:max_ncon + 1]))

        frame = pd.DataFrame({'RNO': reach_numbers, 'IC': connections})
        frame = frame.astype(dtype={'RNO': np.int32, 'IC': object})
        frame.index += 1  # Rows are numbered starting at 1
        return frame

    # @overrides
    def dataframe_to_temp_file(self, block, dataframe):
        """Writes the dataframe to a temporary file.

        Overridden because CONNECTIONDATA is more complicated due to IC column being a space separated list.

        Args:
            block (str): Name of the block (case-insensitive).
            dataframe (pandas.DataFrame): The dataframe

        Returns:
            (str | None): Filepath of file created, or None if block is not PACKAGEDATA, DIVERSIONS or
            CONNECTIONDATA.
        """
        block_name = block.upper()

        if block_name == 'PACKAGEDATA':
            # Use ',' for separator because CELLID is one field because it can be NONE
            return gui_util.dataframe_to_temp_file(dataframe, separator=',')

        if block_name == 'DIVERSIONS':
            return gui_util.dataframe_to_temp_file(dataframe)

        if block_name != 'CONNECTIONDATA':
            return None

        # Write to a string so we can strip the quotes before dumping to file
        unquoted_csv = gui_util.dataframe_to_csv(dataframe, '').replace("'", '')

        # Can't just dump the string - we get too many newlines. Write it line by line instead.
        return io_util.write_lines_to_temp_file(unquoted_csv.splitlines())

    def map_info(self, feature_type):
        """Returns info needed for Map from Coverage.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'

        Returns:
            (dict): Dict describing how to get the MODFLOW variable from the shapefile or att table fields. Only
            'arcs' has entries (every value is None); any other feature type gives an empty dict.
        """
        if feature_type != 'arcs':
            return {}

        arc_fields = (
            'Name', 'FLOW', 'RUNOFF', 'ETSW', 'PPTSW', 'ROUGHCH', 'ROUGHBK', 'CDPTH', 'FDPTH', 'AWDTH', 'BWDTH',
            'HCOND1', 'THICKM1', 'ELEVUP', 'WIDTH1', 'DEPTH1', 'HCOND2', 'THICKM2', 'ELEVDN', 'WIDTH2', 'DEPTH2',
        )
        return dict.fromkeys(arc_fields)  # Every field maps to None

    def map_import_info(self, feature_type):
        """Returns info needed for mapping shapefile or transient data file to package data.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'. Not used by this implementation.

        Returns:
            None: This implementation always returns None, whatever the feature type.
        """
        return None

    @override
    def _setup_options(self) -> OptionsBlock:
        """Builds the definition of all the options available for this package.

        Returns:
            (OptionsBlock): The options block holding one definition per option, in the order listed here.
        """
        option_defs = [
            Checkbox('STORAGE', brief='Activates storage contributions to the continuity equation'),
            CheckboxButton(
                'AUXILIARY',
                button_text='Auxiliary Variables...',
                check_box_method='on_chk_auxiliary',
                button_method='on_btn_auxiliary'
            ),
            Checkbox('BOUNDNAMES', brief='Allow boundary names', check_box_method='on_chk_boundnames'),
            # Printing and saving output
            Checkbox('PRINT_INPUT', brief='Print input to listing file'),
            Checkbox('PRINT_STAGE', brief='Print stage to listing file'),
            Checkbox('PRINT_FLOWS', brief='Print flows to listing file'),
            Checkbox('SAVE_FLOWS', brief='Save flows to budget file'),
            CheckboxField('STAGE FILEOUT', brief='Save stage to file', type_='str'),
            CheckboxField('BUDGET FILEOUT', brief='Save budget to file', type_='str'),
            CheckboxField('BUDGETCSV FILEOUT', brief='Save budget to CSV file', type_='str'),
            CheckboxField('PACKAGE_CONVERGENCE FILEOUT', brief='Save convergence info to file', type_='str'),
            # Input files
            CheckboxButton(
                'TS6 FILEIN', brief='Time-series files', button_text='Files...', button_method='on_btn_ts6_filein'
            ),
            CheckboxButton(
                'OBS6 FILEIN',
                brief='Observation files',
                button_text='Files...',
                button_method='on_btn_obs6_filein'
            ),
            Checkbox('MOVER', brief='Can be used with the Water Mover (MVR) Package'),
            # Solver limits and unit conversions
            CheckboxField(
                'MAXIMUM_PICARD_ITERATIONS',
                brief='Max picard iterations allowed when solving for reach stages and flows',
                type_='int',
                value=100
            ),
            CheckboxField(
                'MAXIMUM_ITERATIONS',
                brief='Max Newton-Raphson iterations allowed for a reach',
                type_='int',
                value=100
            ),
            CheckboxField('MAXIMUM_DEPTH_CHANGE', brief='The depth closure tolerance', type_='float', value=1e-5),
            CheckboxField(
                'LENGTH_CONVERSION',
                brief='Value to convert user-specified Manning’s roughness'
                ' coefficients from meters to model length units',
                type_='float',
                value=1.0
            ),
            CheckboxField(
                'TIME_CONVERSION',
                brief='Value to convert user-specified Manning’s roughness'
                ' coefficients from seconds to model time units',
                type_='float',
                value=1.0
            ),
        ]
        return OptionsBlock(option_defs)
