"""MawPackageBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import pandas as pd
import shapefile  # From pyshp

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.file_io import io_util
from xms.mf6.gui import gui_util
from xms.mf6.mapping.cell_adder_wel import CellAdderWel
from xms.mf6.mapping.package_builder import PackageBuilder
from xms.mf6.mapping.package_builder_base import (
    add_list_package_period_data, att_type_from_ftype, cell_active_and_no_chd, PackageBuilderInputs, should_skip_record
)


class MawPackageBuilder(PackageBuilder):
    """Builds a MAW (multi-aquifer well) package during map from coverage.

    Point features are read from a shapefile. Each point that yields at least one usable cell connection becomes
    one well: one row in PACKAGEDATA, one row per connection in CONNECTIONDATA, and one row per period in the
    PERIOD blocks.
    """
    def __init__(self, inputs: PackageBuilderInputs):
        """Initializes the class.

        Args:
            inputs: Everything needed to build the package.
        """
        super().__init__(inputs)
        self._cxn_df = None  # CONNECTIONDATA table
        self._pkg_df = None  # PACKAGEDATA table
        self._reader = None  # Shapefile reader
        self._records = None  # Shapefile records
        self._period_rows = {}  # dict of period -> mawsetting rows
        self._feature_type = 'points'
        self._att_type = att_type_from_ftype('MAW6')
        self._tops = []  # cell top elevations
        self._bottoms = []  # cell bottom elevations
        self._well_records: dict[int, int] = {}  # wellno -> record index (shape index)

    def _build_package(self, feature_type: str) -> None:  # overrides
        """Builds the package.

        Only point features are handled; any other feature type is ignored.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'.
        """
        if feature_type == 'points' and self._ix_info.points:
            shapefile_name, ix_recs = self._more_set_up()
            # The reader is kept on self because _build_pkg_and_cxn_data reads shapes from it.
            with shapefile.Reader(shapefile_name) as self._reader:
                self._records = self._reader.records()
                self._build_pkg_and_cxn_data(ix_recs)
                self._build_periods()

    def _more_set_up(self):
        """Reads transient data and caches the cell top and bottom elevations.

        Returns:
            Tuple of (shapefile name, intersection records) for the point features.
        """
        feature_type = 'points'
        shapefile_name, _att_type, ix_recs, map_info = self._get_initial_stuff(feature_type)
        self.trans_data = self._read_transient_data(list(map_info.keys()), shapefile_name)
        self._tops = self.cogrid.get_cell_tops()
        self._bottoms = self.cogrid.get_cell_bottoms()
        return shapefile_name, ix_recs

    def _build_pkg_and_cxn_data(self, ix_recs) -> None:
        """Builds the PACKAGEDATA table and the CONNECTIONDATA table.

        Also fills self._well_records (wellno -> shape index) for every well that was added.

        Args:
            ix_recs: Intersection info, indexed by shape index.
        """
        cxn_rows = []
        pkg_rows = []
        cell_adder = CellAdderWel(self, self._package, self._tops, self._bottoms)
        wellno = self._starting_wellno()
        for shape_idx, record in enumerate(self._records):
            if should_skip_record(ix_recs[shape_idx], record, self._att_type):
                continue  # Shape intersected no cells or the feature att type doesn't match. Skip it.

            if not isinstance(record, dict):
                record = record.as_dict()
            self._shape = self._reader.shape(shape_idx)  # self._shape is used by self.group() in base class

            icon = 0  # GWF connection number, 1 to ngwfnodes (becomes same as ngwfnodes after adding all cxns)
            # Indexes of the top and bottom cells for this well. These are initialized once per well, not once per
            # intersected cell, so that a later intersected cell with no usable connections cannot reset them to -1
            # while icon is already > 0 (which would silently index the last cell of the grid).
            top_cellidx, bottom_cellidx = -1, -1

            # Iterate through intersected cells
            for cell_idx in ix_recs[shape_idx].cellids.tolist():
                # Get screened intervals for cells intersected by well screen or layer range
                intervals = cell_adder.compute_cell_intervals(cell_idx, record)
                for cellidx in intervals.keys():
                    if not cell_active_and_no_chd(self._idomain, self.period, cellidx, self._chd_cells):
                        continue
                    icon += 1
                    cxn_rows.append(self._make_cxn_data_row(wellno, icon, cellidx, record))
                    # Keep track of top (first connected) and bottom (last connected) cells for this well
                    if top_cellidx == -1:
                        top_cellidx = cellidx
                    bottom_cellidx = cellidx

            if icon > 0:  # Only wells with at least one connection are added
                pkg_rows.append(self._make_pkg_data_row(wellno, icon, top_cellidx, bottom_cellidx, record))
                self._well_records[wellno] = shape_idx
                wellno += 1

        # Create dataframes
        self._pkg_df = self._create_pkg_data_dataframe(pkg_rows)
        self._cxn_df = self._create_cxn_data_dataframe(cxn_rows)

    def _starting_wellno(self) -> int:
        """Return the wellno that should be used to start out with.

        Returns:
            1 when replacing or when there is no existing PACKAGEDATA block, otherwise the value from
            io_util.count_max_line() on the existing PACKAGEDATA block plus 1, so new wells are numbered after
            the existing ones.
        """
        if self._replace or not self._package.list_blocks.get('PACKAGEDATA'):
            return 1
        return io_util.count_max_line([self._package.list_blocks['PACKAGEDATA']]) + 1

    def _build_periods(self):
        """Builds the PERIODS blocks.

        <wellno> <mawsetting>

        If there is no transient data, a single period (1) is written using the values from the shapefile records.
        """
        trans_data = self.trans_data or {}  # Tolerate None as well as an empty dict
        period_list = list(trans_data.keys()) if trans_data else [1]
        for period in period_list:
            rows = []

            for wellno, shape_idx in self._well_records.items():
                record = self._records[shape_idx]
                setting = record['MAWSETTING']
                # Only RATE uses Q (possibly transient); every other setting takes VALUE1 from the record
                value1 = _get_q(trans_data, period, record) if setting == 'RATE' else record['VALUE1']
                # <wellno> <mawsetting>
                rows.append([wellno, setting, value1, record['VALUE2'], record['VALUE3']])

            if rows:
                add_list_package_period_data(self._package, period, rows)

    def _create_pkg_data_dataframe(self, rows: list[list]) -> pd.DataFrame:
        """Create and return the dataframe for the PACKAGEDATA block.

        Args:
            rows: The table rows, one list per well.

        Returns:
            The dataframe, with column names taken from the package's PACKAGEDATA column info.
        """
        names, _types, _defaults = self._package.get_column_info('PACKAGEDATA')
        return pd.DataFrame(rows, columns=names)

    def _create_cxn_data_dataframe(self, rows: list[list]) -> pd.DataFrame:
        """Create and return the dataframe for the CONNECTIONDATA block.

        Args:
            rows: The table rows, one list per connection.

        Returns:
            The dataframe, with column names taken from the package's CONNECTIONDATA column info.
        """
        names, _types, _defaults = self._package.get_column_info('CONNECTIONDATA')
        return pd.DataFrame(rows, columns=names)

    def _make_cxn_data_row(self, wellno: int, icon: int, cellidx: int, record) -> list:
        """Creates and returns a row for the CONNECTIONDATA table.

        <wellno> <icon> <cellid(ncelldim)> <scrn_top> <scrn_bot> <hk_skin> <radius_skin>

        Args:
            wellno: The well number <wellno>.
            icon: GWF connection number <icon>.
            cellidx: Cell index, which is converted to a cell ID for the table.
            record: A record from the shapefile.

        Returns:
            The row as a list.
        """
        cellid = self.grid_info.modflow_cellid_from_cell_index(cellidx)
        # A cell ID is either a single int or a sequence of ints; flatten it into the row either way
        id_list = [cellid] if isinstance(cellid, int) else list(cellid)
        condeqn = record['CONDEQN']
        # "If CONDEQN is SPECIFIED, THIEM, SKIN, or COMPOSITE, SCRN_TOP can be any value and is set to the top of the
        # cell. If CONDEQN is MEAN, SCRN_TOP is set to the multi-aquifer well connection cell top if the specified
        # value is greater than the cell top."
        cell_top, cell_bottom = self._tops[cellidx], self._bottoms[cellidx]
        if condeqn == 'MEAN' and 'USE_SCRN' in record and record['USE_SCRN']:
            # Clip the well screen to the cell
            scrn_top = min(cell_top, record['TOP_SCRN'])
            scrn_bot = max(cell_bottom, record['BOT_SCRN'])
        else:
            scrn_top, scrn_bot = cell_top, cell_bottom
        return [wellno, icon, *id_list, scrn_top, scrn_bot, record['HK_SKIN'], record['RADIUS_SKN']]

    def _make_pkg_data_row(self, wellno: int, ngwfnodes: int, top_cellidx: int, bottom_cellidx: int, record) -> list:
        """Creates and returns a row for the PACKAGEDATA table.

        <wellno> <radius> <bottom> <strt> <condeqn> <ngwfnodes> [<aux(naux)>] [<boundname>]

        Args:
            wellno: The well number <wellno>.
            ngwfnodes: Number of GWF nodes connected to this well.
            top_cellidx: Index of the top cell of this well.
            bottom_cellidx: Index of the bottom cell of this well.
            record: A record from the shapefile.

        Returns:
            The row as a list.
        """
        condeqn = record['CONDEQN']
        # "If CONDEQN is SPECIFIED, THIEM, SKIN, or COMPOSITE, BOTTOM is set to the cell bottom in the lowermost GWF
        # cell connection in cases where the specified well bottom is above the bottom of this GWF cell. If CONDEQN is
        # MEAN, BOTTOM is set to the lowermost GWF cell connection screen bottom in cases where the specified well
        # bottom is above this value."
        # NOTE(review): for MEAN this uses BOT_SCRN even when USE_SCRN is off, unlike _make_cxn_data_row - confirm
        # that is intended.
        bottom = record['BOT_SCRN'] if condeqn == 'MEAN' else self._bottoms[bottom_cellidx]
        strt = self._tops[top_cellidx]  # Set strt to the top elevation of the top cell for the well
        group = self.group(record['ID'])
        # <wellno> <radius> <bottom> <strt> <condeqn> <ngwfnodes> [<aux(naux)>] [<boundname>]
        return [wellno, record['RADIUS'], bottom, strt, condeqn, ngwfnodes, group, record['Name']]

    def _write_package(self) -> None:  # overrides
        """Writes the package to disk."""
        self._write_temporarily()
        super()._write_package()

    def _write_temporarily(self):
        """Writes the PACKAGEDATA and CONNECTIONDATA tables to temporary files."""
        self._write_block('PACKAGEDATA', self._pkg_df)
        self._write_block('CONNECTIONDATA', self._cxn_df)

    def _write_block(self, block: str, df: pd.DataFrame) -> None:
        """Write the dataframe to disk, possibly appending to existing data.

        When replacing, or when the package has no file for the block yet, the dataframe is written to a new temp
        file which becomes the block's file. Otherwise the rows are appended to the block's existing file.

        Args:
            block: The block.
            df: The dataframe.
        """
        if self._replace or not self._package.list_blocks.get(block):
            temp_filename = self._package.dataframe_to_temp_file(block, df)
            self._package.list_blocks[block] = temp_filename
        else:
            filename = self._package.list_blocks[block]
            gui_util.dataframe_to_csv(df, filename, io_util.mfsep, append=True)


def _get_q(trans_data, period: int, record) -> float:
    """Return the Q for the period.

    Args:
        trans_data: Transient data dict.
        period: The period.
        record: The record.

    Returns:
        See description.
    """
    # See if it's in the transient data
    point_id = record['ID']
    q = trans_data.get(period, {}).get(point_id, {}).get('Q')
    if q is None:
        q = record['Q']  # Get it from the shapefile record
    return q
