"""GridInfo class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from enum import Enum
from pathlib import Path

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data.mfcellid import MfCellId
from xms.mf6.file_io.grb_reader import GrbReader
from xms.mf6.misc import util


class DisEnum(Enum):
    """Enumeration of the MODFLOW discretization (DIS) package types a grid can use."""
    DIS = 1  # Cells are addressed by (layer, row, column)
    DISV = 2  # Cells are addressed by (layer, cell2d)
    DISU = 3  # Cells are addressed by a single node number
    END = 4  # NOTE(review): not referenced in the code visible here; presumably an end-of-enum sentinel - confirm


class GridInfo:
    """Event filter object. Used to prevent user from leaving number fields blank."""
    def __init__(self, nrow=-1, ncol=-1, nlay=-1, ncpl=-1, nvert=-1, nodes=-1, nja=-1, dis: DisEnum | None = None):
        """Initializes the class."""
        self._dis_enum: DisEnum | None = dis  # Which dis package
        self._nrow = nrow  # Number of rows (DIS)
        self._ncol = ncol  # Number of columns (DIS)
        self._nlay = nlay  # Number of layers (DIS/DISV)
        self._ncpl = ncpl  # Number of cells per layer (DISV)
        self._nvert = nvert  # Total number of (x, y) vertex pairs (DISV/DISU)
        self._nodes = nodes  # Total number of cells (DISU)
        self._nja = nja  # Total number of connections (DISU)

        # If dis_enum wasn't set, set it based on what info was given
        if self._dis_enum is None:
            if ncpl != -1:
                self._dis_enum = DisEnum.DISV
            elif nodes != -1 or nja != -1:
                self._dis_enum = DisEnum.DISU
            else:
                self._dis_enum = DisEnum.DIS

    @property
    def dis_enum(self) -> DisEnum | None:
        """Returns which DIS package type (DIS, DISV, or DISU) this grid uses."""
        return self._dis_enum

    @dis_enum.setter
    def dis_enum(self, value: DisEnum | None) -> None:
        """Sets which DIS package type this grid uses.

        The stored dimensions are left untouched; only the type is replaced.

        Args:
            value: The type of dis package.
        """
        self._dis_enum = value

    @property
    def nrow(self) -> int:
        """Returns the number of rows (DIS). The constructor default is -1."""
        return self._nrow

    @nrow.setter
    def nrow(self, value: int) -> None:
        """Sets the number of rows (DIS).

        Args:
            value: The new number of rows. Stored as-is, without validation.
        """
        self._nrow = value

    @property
    def ncol(self) -> int:
        """Returns the number of columns (DIS). The constructor default is -1."""
        return self._ncol

    @ncol.setter
    def ncol(self, value: int) -> None:
        """Sets the number of columns (DIS).

        Args:
            value: The new number of columns. Stored as-is, without validation.
        """
        self._ncol = value

    @property
    def nlay(self) -> int:
        """Returns the number of layers (DIS/DISV). The constructor default is -1."""
        return self._nlay

    @nlay.setter
    def nlay(self, value: int) -> None:
        """Sets the number of layers (DIS/DISV).

        Args:
            value: The new number of layers. Stored as-is, without validation.
        """
        self._nlay = value

    @property
    def ncpl(self) -> int:
        """Returns the number of cells per layer (DISV). The constructor default is -1."""
        return self._ncpl

    @ncpl.setter
    def ncpl(self, value: int) -> None:
        """Sets the number of cells per layer (DISV).

        Args:
            value: The new number of cells per layer. Stored as-is; the DIS type is not re-inferred.
        """
        self._ncpl = value

    @property
    def nvert(self) -> int:
        """Returns the total number of (x, y) vertex pairs (DISV/DISU). The constructor default is -1."""
        return self._nvert

    @nvert.setter
    def nvert(self, value: int) -> None:
        """Sets the total number of (x, y) vertex pairs (DISV/DISU).

        Args:
            value: The new number of vertex pairs. Stored as-is, without validation.
        """
        self._nvert = value

    @property
    def nodes(self) -> int:
        """Returns the total number of cells (DISU). The constructor default is -1."""
        return self._nodes

    @nodes.setter
    def nodes(self, value: int) -> None:
        """Sets the total number of cells (DISU).

        Args:
            value: The new number of cells. Stored as-is; the DIS type is not re-inferred.
        """
        self._nodes = value

    @property
    def nja(self) -> int:
        """Returns the total number of connections (DISU). The constructor default is -1."""
        return self._nja

    @nja.setter
    def nja(self, value: int) -> None:
        """Sets the total number of connections (DISU).

        Args:
            value: The new number of connections. Stored as-is; the DIS type is not re-inferred.
        """
        self._nja = value

    def modflow_cellid_from_cell_index(self, cell_index):
        """For DisEnum.DIS returns the layer, row, and column (1-based) given a cell index.

        For DisEnum.DISV returns layer and cell2d.
        For DisEnum.DISU returns the cell_index incremented by one.

        Args:
            cell_index (int): The 0-based cell index.

        Returns:
            (tuple): See description.
        """
        if self._dis_enum == DisEnum.DIS:
            row = (cell_index // self._ncol) % self._nrow + 1
            col = cell_index % self._ncol + 1
            lay = cell_index // (self._nrow * self._ncol) + 1
            return lay, row, col
        elif self._dis_enum == DisEnum.DISV:
            lay = cell_index // self._ncpl + 1
            cell2d = cell_index - ((lay - 1) * self._ncpl) + 1
            return lay, cell2d
        else:
            return cell_index + 1

    def cell_index_from_modflow_cellid(self, cellid, one_based=True):
        """Returns a 0-based cell index given the 0-based MODFLOW cellid which may be an int or a tuple.

        Args:
            cellid: Can be an int (DISU), a 2-tuple (DISV), or a 3-tuple(DIS).
            one_based (bool): True if lay, row, col are 1-based. False if they are 0-based.

        Returns:
            (int): See description.
        """
        # DISU
        if isinstance(cellid, list) and len(cellid) == 1:
            cellid = cellid[0]

        if isinstance(cellid, int) or isinstance(cellid, np.int64):
            if one_based:
                return cellid - 1
            else:
                return cellid
        elif len(cellid) == 3:  # DIS
            return self.cell_index_from_lay_row_col(*cellid, one_based)
        elif len(cellid) == 2:  # DISV
            return self.cell_index_from_lay_cell2d(*cellid, one_based)
        else:  # Error
            raise ValueError('Bad cellid in GridInfo.cell_index_from_modflow_cellid')

    def check_mfcellid(self, mfcellid: MfCellId) -> tuple[bool, list[bool] | None]:
        """Return True if mfcellid is valid, else False and a list containing False where the data is bad.

        Args:
            mfcellid: The MfCellId to check.

        Returns:
            See description.
        """
        if mfcellid is None:
            return False, None
        elif self._dis_enum == DisEnum.DIS:
            return _in_ranges(mfcellid, [self._nlay, self._nrow, self.ncol])
        elif self._dis_enum == DisEnum.DISV:
            return _in_ranges(mfcellid, [self._nlay, self.cells_per_layer()])
        elif self._dis_enum == DisEnum.DISU:
            return _in_ranges(mfcellid, [self._nodes])

    def cell_index_from_lay_row_col(self, lay, row, col, one_based=True):
        """Returns a 0-based cell index given the 1-based layer, row and column.

        Args:
            lay (int): Layer.
            row (int): Row.
            col (int): Column
            one_based (bool): True if lay, row, col are 1-based. False if they are 0-based.

        Returns:
            See description.
        """
        if one_based:
            if not 1 <= lay <= self._nlay or not 1 <= row <= self._nrow or not 1 <= col <= self._ncol:
                raise ValueError('Layer, row, or column is outside grid range.')
            return ((lay - 1) * self._nrow * self._ncol) + ((row - 1) * self._ncol) + (col - 1)
        else:
            if not 0 <= lay <= self._nlay - 1 or not 0 <= row <= self._nrow - 1 or not 0 <= col <= self._ncol - 1:
                raise ValueError('Layer, row, or column is outside grid range.')
            return (lay * self._nrow * self._ncol) + (row * self._ncol) + col

    def cell_index_from_lay_cell2d(self, lay, cell2d, one_based=True):
        """Returns a 0-based cell index given the 1-based layer, and cell2d.

        This can be used with DIS as well as DISV.

        Args:
            lay (int): Layer.
            cell2d (int): ID of cell in the layer.
            one_based (bool): True if lay, cell2d are 1-based. False if they are 0-based.

        Returns:
            See description.
        """
        ncpl = self.cells_per_layer()
        if one_based:
            if not 1 <= lay <= self._nlay or not 1 <= cell2d <= ncpl:
                raise ValueError('Layer or Cell2D is outside grid range.')
        else:
            if not 0 <= lay <= self._nlay - 1 or not 0 <= cell2d <= ncpl - 1:
                raise ValueError('Layer or Cell2D is outside grid range.')
        return cell_index_from_lay_cell2d(lay, cell2d, ncpl, one_based)

    def cell_count(self):
        """Returns the total number of cells in the grid.

        Returns:
            See description.
        """
        if self._dis_enum == DisEnum.DIS:
            return self._nlay * self._nrow * self._ncol
        elif self._dis_enum == DisEnum.DISV:
            return self._nlay * self._ncpl
        else:  # self._dis_enum == DisEnum.DISU
            return self._nodes

    def cell_id_column_count(self):
        """Returns the number of columns needed for the cellid.

        Returns:
            (int) See description.
        """
        if self._dis_enum == DisEnum.DIS:
            cellid_column_count = 3
        elif self._dis_enum == DisEnum.DISV:
            cellid_column_count = 2
        else:
            cellid_column_count = 1

        return cellid_column_count

    def cells_per_layer(self) -> int:
        """Returns the number of cells per layer.

        For DISU, returns total number of cells (self._nodes).

        Returns:
            (int): See description
        """
        if self._dis_enum == DisEnum.DIS:
            return self._nrow * self._ncol
        elif self._dis_enum == DisEnum.DISV:
            return self._ncpl
        else:
            return self._nodes

    def fix_cellid(self, cellid, layer=0):
        """Makes the cellid complete by adding missing info, and also makes it 1-based for writing to the file.

        flopy.GridIntersect returns 0-based (row, col) for DIS, and single 0-based int for both DISV and DISU.

        Args:
            cellid (int or tuple): The cellid from flopy.GridIntersect
            layer (int): the layer

        Returns:
            (int or tuple): The completed 1-based MODFLOW cellid.
        """
        if self._dis_enum == DisEnum.DIS:
            if util.is_number(cellid):  # (We make these cell indexes ourselves with DIS and polygons)
                mfcellid = self.modflow_cellid_from_cell_index(cellid)
                lay = 1 if layer == 0 else layer
                return lay, mfcellid[1], mfcellid[2]
            elif len(cellid) == 2:
                lay = 1 if layer == 0 else layer
                return lay, cellid[0] + 1, cellid[1] + 1  # Add layer and make 1-based
            else:
                return cellid[0] + 1, cellid[1] + 1, cellid[2] + 1  # Make 1-based
        elif self._dis_enum == DisEnum.DISV:
            # With DISV the intersection info is the same as with DISU - a 0-based single integer.
            # Turn it into a 1-based MODFLOW cellid (lay, cell2d) suitable for DISV.
            lay, node = self.modflow_cellid_from_cell_index(cellid)
            if layer != 0:
                lay = layer
            return lay, node
        else:
            return cellid + 1

    def cell_indexes_from_dataframe(self, df):
        """Get a set of cell indexes from a dataframe.

        Args:
            df (Dataframe): pandas dataframe

        Returns:
            (set(int)): The cell indexes in the dataframe
        """
        if self._dis_enum not in [DisEnum.DIS, DisEnum.DISV, DisEnum.DISU]:
            raise ValueError('DIS package type not specified.')

        # "Iterating through pandas objects is generally slow. In many cases,
        # iterating manually over the rows is not needed [...]."
        # ... BUT ...
        # "Premature optimization is the root of all evil", so here we go.
        cell_idxs = set()
        for row in df.itertuples(index=False):
            if self._dis_enum == DisEnum.DIS:
                if 'CELLID' in df:
                    if row.CELLID.upper() != 'NONE':
                        cellid = row.CELLID.split()
                        index = self.cell_index_from_lay_row_col(int(cellid[0]), int(cellid[1]), int(cellid[2]))
                    else:
                        index = None
                else:
                    index = self.cell_index_from_lay_row_col(row.LAY, row.ROW, row.COL)
            elif self._dis_enum == DisEnum.DISV:
                if 'CELLID' in df:
                    cellid = row.CELLID.split()
                    index = self.cell_index_from_lay_cell2d(int(cellid[0]), int(cellid[1]))
                else:
                    index = self.cell_index_from_lay_cell2d(row.LAY, row.CELL2D)
            else:
                index = int(row.CELLID) - 1
            if index is not None:
                cell_idxs.add(index)
        return cell_idxs

    @staticmethod
    def from_grb_file(grb_filepath: Path) -> 'GridInfo':
        """Returns a GridInfo object by reading the info from the binary grid (.grb) file.

        Args:
            grb_filepath (Path): File path to .grb file

        Returns:
            (GridInfo): Number of rows, cols etc.
        """
        if not grb_filepath.is_file():
            return None

        reader = GrbReader(grb_filepath, {'NCELLS', 'NODES', 'NLAY', 'NROW', 'NCOL', '`NCPL', 'NJA', 'NVERT'})
        grb_data = reader.read()
        return GridInfo(
            nrow=grb_data.get('NROW', -1),
            ncol=grb_data.get('NCOL', -1),
            nlay=grb_data.get('NLAY', -1),
            ncpl=grb_data.get('NCPL', -1),
            nvert=grb_data.get('NVERT', -1),
            nodes=grb_data.get('NODES', -1),
            nja=grb_data.get('NJA', -1)
        )


def cell_index_from_lay_cell2d(layer: int, cell2d: int, cells_per_layer: int, one_based: bool = True):
    """Returns a 0-based cell index given the layer and cell2d. No range checking is done.

    This can be used with DIS as well as DISV.

    Args:
        layer: Layer (could be 0 or 1-based depending on one_based arg).
        cell2d: ID of cell in the layer (0 or 1-based, same as layer).
        cells_per_layer: Number of cells per grid layer.
        one_based: True if layer and cell2d are 1-based. False if they are 0-based.

    Returns:
        The 0-based cell index.
    """
    # Amount to subtract from each part to make it 0-based
    shift = 1 if one_based else 0
    return ((layer - shift) * cells_per_layer) + cell2d - shift


def _in_ranges(mfcellid: MfCellId, maxes: list) -> tuple[bool, list[bool]]:
    """Return True if all mfcellid parts are between 1 and maxes, else False and a list showing where the data is bad.

    Args:
        mfcellid: The MfCellId to check.
        maxes: Max values.

    Returns:
        See description.
    """
    ok = True
    good = [True] * len(maxes)
    for i, mx in enumerate(maxes):
        if mfcellid[i] < 1 or mfcellid[i] > mx:
            ok = False
            good[i] = False
    return ok, good
