"""Compares DIS* packages with the linked UGrids."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from dataclasses import dataclass, field
import numbers
from typing import Any

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.constraint import Grid
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.data.mfsim_data import MfsimData
from xms.mf6.data.model_data_base import ModelDataBase
from xms.mf6.geom import ugrid_builder
from xms.mf6.geom.ugrid_builder import DisGeom, DisvDisuGeom

# Module-level constants
_MAX_DIFFS = 5  # Max number of differing values recorded (and reported) per DIS* property before stopping
_TAB = ' ' * 8  # Indentation (8 spaces) prepended to the per-property detail lines of an error


def get_sim_mismatches(mfsim: MfsimData) -> list[str]:
    """Checks every model in the simulation against the grid currently linked to it in XMS.

    Args:
        mfsim: The simulation whose models are checked.

    Returns:
        List of error strings; empty when no mismatches are found.
    """
    return DisGridComparer().get_sim_mismatches(mfsim)


def get_model_mismatches(model: ModelDataBase, cogrid: Grid) -> list[str]:
    """Checks one model's DIS* package against the given constrained grid.

    Args:
        model: The model whose DIS* package is checked.
        cogrid: The constrained grid currently linked to the model in XMS.

    Returns:
        List of error strings; empty when no mismatches are found.
    """
    return DisGridComparer().get_model_mismatches(model, cogrid)


@dataclass
class Mismatch:
    """Describes one property where the DIS* package and the UGrid disagree.

    Two forms are used:

    * Value mismatch: "indexes" holds the positions that differ, and "dis_vals" / "grid_vals" hold the DIS* and
      UGrid values at those positions (parallel lists).
    * Size mismatch: when a DIS* list (e.g., DELR) and the corresponding UGrid list have different lengths,
      "indexes" is left empty, "dis_vals[0]" is the size of the DIS* list and "grid_vals[0]" is the size of the
      UGrid list.
    """
    # Name of the DIS* property being compared (e.g., 'TOP', 'DELR')
    dis_property: str = field(default='')
    # Positions where the values differ (empty for a size mismatch)
    indexes: list[int] = field(default_factory=list)
    # DIS* package values at those positions, or [DIS* list size]
    dis_vals: list[Any] = field(default_factory=list)
    # UGrid values at those positions, or [UGrid list size]
    grid_vals: list[Any] = field(default_factory=list)


class DisGridComparer:
    """Compares the DIS* package against the currently linked grid in XMS."""
    def __init__(self):
        """Initializer."""
        # The following are so we don't have to pass so many args around
        self._model = None
        self._dis = None
        self._tol = 1.0e-6  # Floating point comparison tolerance
        self._ugrid: UGrid | None = None  # So we only call cogrid.ugrid 1x

    def get_sim_mismatches(self, mfsim: MfsimData | None) -> list[str]:
        """Checks that the linked UGrids match the DIS packages in each model, and returns errors if not.

        Returns:
            List of error strings.
        """
        errors: list[str] = []
        for model in mfsim.models:
            # Get current linked grid from XMS
            cogrid = model.get_cogrid()
            errors.extend(self.get_model_mismatches(model, cogrid))
        return errors

    def get_model_mismatches(self, model: ModelDataBase, cogrid: Grid) -> list[str]:
        """Returns mismatch errors for this model/cogrid combo.

        Args:
            model: The model.
            cogrid: The constrained grid.

        Returns:
            See description.
        """
        errors = []
        self._model = model
        self._dis = model.get_dis()
        if not self._dis:
            return []

        dis_ftype = self._dis.ftype
        grid_ftype = dis_builder.get_grid_dis_type(cogrid)
        if dis_ftype != grid_ftype:
            errors.append(self._error_first_line(cogrid))
            errors.append(f'{_TAB}DIS* package is {dis_ftype} but linked UGrid requires {grid_ftype}.')
        else:
            # Check all the things, adding to list of mismatches if they don't match
            mismatches: list[Mismatch] = []
            if dis_ftype == 'DIS6':
                dis_geom = ugrid_builder.geom_from_dis(self._dis)
                self._check('TOP', dis_geom, cogrid, mismatches)
                self._check('BOTM', dis_geom, cogrid, mismatches)
                self._check('DELR', dis_geom, cogrid, mismatches)
                self._check('DELC', dis_geom, cogrid, mismatches)
                self._check('IDOMAIN', dis_geom, cogrid, mismatches)
                self._check('NLAY', dis_geom, cogrid, mismatches)
            elif dis_ftype == 'DISV6':
                dis_geom = ugrid_builder.geom_from_disv_disu(self._dis, transform=True)
                self._check('TOP', dis_geom, cogrid, mismatches)
                self._check('BOTM' if dis_ftype != 'DISU6' else 'BOT', dis_geom, cogrid, mismatches)
                # Can't compare VERTICES. UGrid vertices won't necessarily match due to warping and ordering
                # self._check('VERTICES', dis_geom, cogrid, mismatches)
                self._check('CELL2D', dis_geom, cogrid, mismatches)
            # We don't currently support DISU6

            if mismatches:
                errors.append(self._error_first_line(cogrid))
            # Convert mismatches to errors
            for mismatch in mismatches:
                errors.append(_error_from_mismatch(mismatch))
        return errors

    def _check(
        self, dis_property: str, dis_geom: DisGeom | DisvDisuGeom, cogrid: Grid, mismatches: list[Mismatch]
    ) -> None:
        """Checks the dis_property of the DIS* package and the UGrid.

        Args:
            dis_property: The DIS* package property we are considering.
            dis_geom: The geometry part of the DIS package.
            cogrid: The constrained grid.
            mismatches: List of mismatches to append to.
        """
        mismatch = None
        match dis_property:
            case 'TOP':
                mismatch = self._compare(dis_property, dis_geom.tops, cogrid.get_cell_tops())
            case 'BOT' | 'BOTM':
                mismatch = self._compare(dis_property, dis_geom.bots, cogrid.get_cell_bottoms())
            case 'IDOMAIN':
                mismatch = self._compare(dis_property, dis_geom.model_on_off_cells, cogrid.model_on_off_cells)
            case 'DELR':
                mismatch = self._compare(dis_property, dis_geom.delr_list, cogrid.locations_x)
            case 'DELC':
                mismatch = self._compare(dis_property, dis_geom.delc_list, cogrid.locations_y)
            case 'NLAY':
                if len(cogrid.locations_z) != self._dis.grid_info().nlay + 1:
                    mismatch = Mismatch(dis_property, [self._dis.grid_info().nlay + 1], [len(cogrid.locations_z)])
            case 'CELL2D':
                dis_vals = dis_geom.centers2d
                if not self._ugrid:
                    self._ugrid = cogrid.ugrid
                cell_centers2d = dis_builder.get_cell_centers2d(cogrid, self._ugrid)
                mismatch = self._compare(dis_property, dis_vals, cell_centers2d[:len(dis_vals)])
        if mismatch:
            mismatches.append(mismatch)

    def _compare(self, dis_property: str, dis_vals: list[Any], grid_vals: list[Any]) -> Mismatch | None:
        """Compare two lists of things, returning [] if the same, otherwise list of indices where different.

        Args:
            dis_property: The DIS* package property we are considering.
            dis_vals: DIS* package values.
            grid_vals: Grid values.
        """
        # If list is empty, return None
        if len(dis_vals) == 0:
            return None

        # See if sizes match
        if len(dis_vals) != len(grid_vals):
            return Mismatch(dis_property=dis_property, dis_vals=[len(dis_vals)], grid_vals=[len(grid_vals)])

        # Compare the two lists of values
        count = 0
        mismatch = None
        for i, (dis_val, grid_val) in enumerate(zip(dis_vals, grid_vals)):
            if isinstance(dis_vals[0], float) or isinstance(dis_vals[0], int):
                equal = abs(dis_val - grid_val) <= self._tol
            else:
                equal = abs(dis_val[0] - grid_val[0]) <= self._tol and abs(dis_val[1] - grid_val[1]) <= self._tol

            if not equal:
                if not mismatch:
                    mismatch = Mismatch(dis_property)
                mismatch.indexes.append(i)
                mismatch.dis_vals.append(dis_val)
                mismatch.grid_vals.append(grid_val)
                count += 1
                if count == _MAX_DIFFS:
                    break
        return mismatch

    def _error_first_line(self, cogrid: Grid) -> str:
        """Returns the first line of the error message.

        Returns:
            See description.
        """
        dis_uuid = self._dis.tree_node.uuid
        dis_path = tree_util.build_tree_path(self._model.tree_node, dis_uuid)
        if cogrid:
            ugrid_ptr_node = tree_util.descendants_of_type(
                self._model.tree_node, xms_types=['TI_UGRID_PTR'], allow_pointers=True, only_first=True
            )
            ugrid_path = tree_util.build_tree_path(self._model.tree_node, ugrid_ptr_node.uuid)
            error = f'DIS* package "{dis_path}" does not match linked UGrid "{ugrid_path}"'
        else:
            error = f'DIS* package "{dis_path}" does not match UGrid.'
        return error


def _error_from_mismatch(mismatch: Mismatch) -> str:
    """Returns the indented, human-readable error line describing a mismatch.

    Args:
        mismatch: The mismatch. If it has indexes, at most _MAX_DIFFS of them are listed together with the DIS* and
            UGrid values at each; otherwise dis_vals[0] and grid_vals[0] are reported as the two list sizes.

    Returns:
        See description.
    """
    # Get diffs line first
    if mismatch.indexes:
        indexes = ", ".join(map(str, mismatch.indexes[:_MAX_DIFFS]))
        dis_vals = ", ".join(map(str, mismatch.dis_vals[:_MAX_DIFFS]))
        grid_vals = ", ".join(map(str, mismatch.grid_vals[:_MAX_DIFFS]))
        diffs = f'At indexes: {indexes}; DIS* value(s): {dis_vals}; UGrid value(s): {grid_vals}'
        # Note the truncation whenever the limit was reached or exceeded: comparisons stop scanning at the limit,
        # and any extra indexes beyond it are not displayed above.
        if len(mismatch.indexes) >= _MAX_DIFFS:
            diffs += f' (only first {_MAX_DIFFS} shown)'
    else:
        diffs = f'DIS* size: {mismatch.dis_vals[0]}, UGrid size: {mismatch.grid_vals[0]}'

    return f'{_TAB}{mismatch.dis_property}: {diffs}'
