"""DisBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import Any, Sequence

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import Grid, GridType
from xms.constraint.modflow import DisuCalculator, get_stacked_grid_2d_topology
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.mf6.data.array import Array
from xms.mf6.data.data_type_aliases import DisX
from xms.mf6.data.disu_data import DisuData
from xms.mf6.data.disv_data import DisvData
from xms.mf6.file_io.disv_reader import DisReaderBase
from xms.mf6.misc import log_util

# Type aliases
Locs = list[list[float]] | dict[tuple[float, float]:int]  # DISV uses the 1st, DISU uses the 2nd
IntLists = Sequence[Sequence[int]]  # Like ((3, 2, 1, 0), (5, 4, 2)) or [[3, 2, 1, 0], [5, 4, 2]]
XyzTuples = list[tuple[float, float, float]]  # A list of tuples of 2 floats
VerticesList = list[list[Any]]  # List for VERTICES block
Cell2dList = list[list[Any]]  # List for use with CELL2D block


def build_dis_package(cogrid: Grid, ugrid: UGrid, package, units):
    """Build a DIS, DISV or DISU package from a grid.

    Convenience wrapper that constructs a ``DisBuilder`` and runs it.

    Args:
        cogrid (Grid): The constrained grid.
        ugrid: The ugrid of the cogrid. May be None, in which case it is taken from ``cogrid.ugrid`` when needed.
        package (DisData): The dis package to fill in; its ``ftype`` selects DIS, DISV or DISU.
        units (str): The length units.

    Returns:
        (DisData): The same package object that was passed in, now populated.
    """
    return DisBuilder(cogrid, ugrid, package, units).build_dis_package()


def get_grid_dis_type(cogrid: Grid) -> str:
    """Returns either 'DIS6', 'DISV6', or 'DISU6' based on the type of grid.

    Args:
        cogrid: The constrained grid, or None.

    Returns:
        'DIS6' when there is no grid or the grid is a 3D rectilinear grid, 'DISU6' when the grid is not a stacked
        grid, and 'DISV6' otherwise.
    """
    # No grid, or a rectilinear grid: no need to run the stacked-grid check at all.
    if not cogrid or cogrid.grid_type == GridType.rectilinear_3d:  # CO_RECTILINEAR_3D
        return 'DIS6'
    # NOTE(review): this relies on check_is_stacked_grid() returning None for a non-stacked grid. A previous comment
    # claimed it returns a (bool, layer_count) tuple, which would never be None -- confirm against xms.constraint.
    if cogrid.check_is_stacked_grid() is None:
        return 'DISU6'
    return 'DISV6'


class DisBuilder:
    """Creates a DIS, DISV or DISU package file from a UGrid.

    Which package is built is decided by ``package.ftype``: 'DIS6', 'DISV6', and anything else is treated as 'DISU6'.
    The package passed to the constructor is filled in place and returned by ``build_dis_package``.
    """
    def __init__(self, cogrid: Grid, ugrid: UGrid, package: DisX, units: str):
        """Initializes the class.

        Args:
            cogrid: The constrained grid.
            ugrid: The ugrid of the cogrid. If None, it is fetched from ``cogrid.ugrid`` the first time it is needed.
            package: The dis package to fill in.
            units: The length units. None, empty or unrecognized units result in LENGTH_UNITS UNKNOWN.
        """
        self._cogrid = cogrid
        self._ugrid = ugrid  # Use the 'ugrid' @property/@setter. We use this to only call cogrid.ugrid 1x.
        self._package = package
        self._units = units
        self._log = log_util.get_logger()

    @property
    def ugrid(self):
        """Returns the ugrid, taking it from the cogrid on first use if none was supplied."""
        # We do it this way so cogrid.ugrid is only called 1x.
        if self._ugrid is None:
            self._ugrid = self._cogrid.ugrid
        return self._ugrid

    @ugrid.setter
    def ugrid(self, ugrid: UGrid) -> None:
        """Sets the ugrid.

        Args:
            ugrid: The ugrid
        """
        self._ugrid = ugrid

    def build_dis_package(self):
        """Creates a DIS, DISV or DISU package file from a UGrid.

        Returns:
            (DisData): the new dis package (the same object passed to the constructor)
        """
        dis_type = self._package.ftype
        if dis_type == 'DIS6':
            self._build_dis()
        elif dis_type == 'DISV6':
            self._build_disv()
        else:  # 'DISU6'
            self._build_disu()
        return self._package

    def _set_array(self, array: Array, values, shape, use_constant=False):
        """Set the values of an array as a single, non-layered layer 0.

        Args:
            array: The array data.
            values: The values.
            shape: The shape.
            use_constant (bool): If true and all values are equal, it will use the CONSTANT option.
        """
        array.layer(0).layered = False
        array_layer = array.layer(0)
        array_layer.set_values(values, shape, use_constant)

    def _build_dis(self):
        """Builds the DIS package.

        Sets the origin/rotation options, NROW/NCOL/NLAY, DELR/DELC, the cell elevations and IDOMAIN from the grid.
        """
        self._log.info('Building DIS package')
        # Get stuff from the Grid
        xloc = self._cogrid.locations_x
        yloc = self._cogrid.locations_y
        zloc = self._cogrid.locations_z  # Only used for the number of layers
        origin = self._cogrid.origin
        angle = self._cogrid.angle
        tops = self._cogrid.get_cell_tops()
        bottoms = self._cogrid.get_cell_bottoms()
        idomain = self._cogrid.model_on_off_cells

        # Cell sizes are the differences between consecutive grid line locations
        delr = [xloc[i] - xloc[i - 1] for i in range(1, len(xloc))]  # <delr(ncol)> column spacing in the row direction
        delc = [yloc[i] - yloc[i - 1] for i in range(1, len(yloc))]  # <delc(nrow)> row spacing in the column direction
        # NOTE(review): presumably reversed because locations_y run opposite to MODFLOW row numbering -- confirm.
        delc.reverse()
        layer_count = len(zloc) - 1

        # Start changing data to match the Grid
        self._set_data_units()
        self._package.options_block.set('XORIGIN', on=True, value=str(origin[0]))
        self._package.options_block.set('YORIGIN', on=True, value=str(origin[1]))
        self._package.options_block.set('ANGROT', on=True, value=str(angle))
        self._package.grid_info().nrow = len(delc)  # <delc(nrow)> row spacing in the column direction
        self._package.grid_info().ncol = len(delr)  # <delr(ncol)> column spacing in the row direction
        self._package.grid_info().nlay = layer_count

        griddata = self._package.block('GRIDDATA')
        griddata.array('DELR').layer(0).set_values(delr, (len(delr), 1))
        griddata.array('DELC').layer(0).set_values(delc, (len(delc), 1))

        shape = (self._package.grid_info().nrow, self._package.grid_info().ncol)
        self._set_data_elevations(tops, bottoms)
        self._set_data_idomain(self._package, idomain, shape)

    def _build_disv(self):
        """Build DISV package data for a stacked grid.

        VERTICES and CELL2D describe a single layer; NLAY is the total cell count divided by the cells per layer.
        """
        self._log.info('Building DISV package')

        # locs and cells2d describe just one layer of the stacked grid.
        locs, cells2d = get_stacked_grid_2d_topology(self.ugrid)
        vertices = vertices_list_from_locations(locs, one_based=True)
        save_vertices_data(self._package, vertices, ftype='DISV6')
        cell_centers = get_cell_centers2d(self._cogrid, self.ugrid)
        max_ncvert = _compute_max_ncvert(cells2d)
        cell2d = cell2d_list_from_grid_data(cells2d, cell_centers, one_based=True)
        save_cell2d_data(cell2d, max_ncvert, self._package)

        # Set grid info
        cells_per_layer = len(cells2d)
        layer_count = self.ugrid.cell_count // cells_per_layer  # '//' rounds down to nearest whole number
        self._package.grid_info().nlay = layer_count
        self._package.grid_info().ncpl = cells_per_layer
        self._package.grid_info().nvert = len(locs)

        # Set data to match the Grid
        self._set_data_units()

        shape = (cells_per_layer, 1)
        tops = self._cogrid.get_cell_tops()
        bottoms = self._cogrid.get_cell_bottoms()
        self._set_data_elevations(tops, bottoms)

        idomain = self._cogrid.model_on_off_cells
        self._set_data_idomain(self._package, idomain, shape)

    def _build_disu(self):
        """Build DISU package data for a grid.

        Uses DisuCalculator for the connection arrays, sorts them so JA is ascending per cell, and makes HWVA
        symmetric before storing everything in the package.
        """
        self._log.info('Building DISU package')

        locs, cells2d = get_disu_topology(self.ugrid)
        vertices = vertices_list_from_locations(locs, one_based=True)
        save_vertices_data(self._package, vertices, ftype='DISU6')
        cell_centers = get_cell_centers2d(self._cogrid, self.ugrid)
        max_ncvert = _compute_max_ncvert(cells2d)
        cell2d = cell2d_list_from_grid_data(cells2d, cell_centers, one_based=True)
        save_cell2d_data(cell2d, max_ncvert, self._package)

        # Calculate DISU data
        tops = self._cogrid.get_cell_tops()
        bottoms = self._cogrid.get_cell_bottoms()
        calculator = DisuCalculator(self.ugrid)
        calculator.calculate(tops, bottoms, cell_centers, 0.0)

        # These are size num cells
        iac = calculator.get_adjacent_counts()  # Adjacent counts
        plan_view_area = calculator.get_plan_view_area()  # AREA

        # These are all size NJA
        ja = calculator.get_adjacent_cells()  # Adjacent cell list
        ihc = calculator.get_connection_direction()  # Connection direction
        cl12 = calculator.get_interface_length()  # Interface length
        hwva = calculator.get_interface_area()  # Interface area
        angldegx = calculator.get_interface_angle()  # Interface angle

        # JA adjacent cell lists must be in ascending order
        ja, ihc, cl12, hwva, angldegx = _sort_ja(iac, (ja, ihc, cl12, hwva, angldegx))

        # "Entries in the HWVA array must be symmetric". Sometimes it's not due to round off error.
        hwva = make_symmetric(iac, ja, hwva)

        # Compute and set grid size
        cell_count = self.ugrid.cell_count
        nja = len(ja)
        self._package.grid_info().nodes = cell_count
        self._package.grid_info().nja = nja

        # Set JA length values
        shape = (nja, 1)
        connectiondata = self._package.block('CONNECTIONDATA')
        self._set_array(connectiondata.array('JA'), ja, shape, use_constant=False)
        self._set_array(connectiondata.array('IHC'), ihc, shape, use_constant=False)
        self._set_array(connectiondata.array('CL12'), cl12, shape, use_constant=False)
        self._set_array(connectiondata.array('HWVA'), hwva, shape, use_constant=False)
        self._set_array(connectiondata.array('ANGLDEGX'), angldegx, shape, use_constant=False)

        # DISU has no layers so treat as one layer
        cells_per_layer = cell_count
        shape = (cells_per_layer, 1)

        # Set NODES length values
        griddata = self._package.block('GRIDDATA')
        self._set_array(griddata.array('AREA'), plan_view_area, shape, use_constant=True)
        self._set_array(connectiondata.array('IAC'), iac, shape, use_constant=False)

        # Set data elevations
        self._set_data_units()
        self._set_data_elevations(tops, bottoms)

        # Set IDOMAIN
        idomain = self._cogrid.model_on_off_cells
        self._set_data_idomain(self._package, idomain, shape)

    def _set_data_units(self):
        """Sets the LENGTH_UNITS option on the package from the units string.

        Matching is case-insensitive. Any units containing 'FEET' map to FEET; exactly 'METERS' or 'CENTIMETERS' map
        to themselves; anything else, including None or empty units, maps to UNKNOWN.
        """
        units = (self._units or '').upper()
        if 'FEET' in units:
            value = 'FEET'
        elif units in ('METERS', 'CENTIMETERS'):
            value = units
        else:
            value = 'UNKNOWN'
        self._package.options_block.set('LENGTH_UNITS', on=True, value=value)

    def _set_data_elevations(self, tops, bottoms):
        """Set elevations on DIS package.

        Args:
            tops (iter): Array of tops for each grid cell.
            bottoms (iter): Array of bottoms for each grid cell.
        """
        self._package.set_tops(tops)
        self._package.set_bottoms(bottoms)

    def _set_data_idomain(self, package, idomain, shape):
        """Set IDOMAIN on grid package, or remove it when the grid has no on/off data.

        Args:
            package (GriddataBase): Package data.
            idomain: Array of IDOMAIN values. Falsy means there is no IDOMAIN.
            shape (tuple): Shape of layered data.
        """
        block = package.block('GRIDDATA')
        if not idomain:
            if block.has('IDOMAIN'):
                block.delete_array('IDOMAIN')
        else:
            if block.array('IDOMAIN') is None:
                layered = package.ftype != 'DISU6'  # DISU has no layers
                block.add_array(package.new_array('IDOMAIN', layered=layered))
            array = block.array('IDOMAIN')
            array.set_values(idomain, shape, False)


def get_disu_topology(ugrid: UGrid):
    """Build the plan-view point list and per-cell point indices used by the DISU package.

    Each distinct (x, y) location from the cells' plan-view polygons becomes one point, numbered 0, 1, 2, ... in the
    order it is first encountered.

    Args:
        ugrid: The ugrid from the cogrid.

    Returns:
        A tuple ``(locations, cells_2d)``. ``locations`` is a dict mapping (x, y) to its 0-based point index (its
        insertion order matches the index order). ``cells_2d`` holds one list of 0-based point indices per cell, in
        the reverse of the polygon's order.
    """
    point_ids = {}
    cells = []
    for cell_idx in range(ugrid.cell_count):
        _ok, polygon = ugrid.get_cell_plan_view_polygon(cell_idx)
        ids = []
        for xyz in polygon:
            xy = (xyz[0], xyz[1])
            if xy not in point_ids:
                point_ids[xy] = len(point_ids)
            ids.append(point_ids[xy])
        # NOTE(review): presumably reversed to get the clockwise vertex order MODFLOW expects -- confirm polygon winding.
        cells.append(ids[::-1])
    return point_ids, cells


def get_cell_centers2d(cogrid: Grid, ugrid: UGrid) -> XyzTuples:
    """Return the cell centers for the grid.

    Custom cell centers stored on the cogrid take precedence. Otherwise the ugrid cell centroids are used.

    Args:
        cogrid: The constrained grid.
        ugrid: The ugrid from the cogrid. If falsy, ``cogrid.ugrid`` is used instead.

    Returns:
        (list): The custom centers as stored on the cogrid, or a list of ugrid centroids (x, y, z), one per cell.
    """
    centers = cogrid.cell_centers
    if centers is None:
        source = ugrid if ugrid else cogrid.ugrid
        centers = [source.get_cell_centroid(i)[1] for i in range(source.cell_count)]
    return centers


def get_cell_centers3d(cogrid: Grid):
    """Return 3D cell centers, computing the z value ourselves.

    The x and y come from the ugrid cell centroid. The z is the midpoint of the cell top and bottom when the grid
    has tops and bottoms; otherwise it is the average z of the cell's points.

    Args:
        cogrid: The constrained grid.

    Returns:
        List of [x, y, z] cell centers, one per cell.
    """
    tops = bottoms = None
    if cogrid.has_cell_tops_and_bottoms():
        tops = cogrid.get_cell_tops()
        bottoms = cogrid.get_cell_bottoms()

    ugrid = cogrid.ugrid
    locations = ugrid.locations
    centers: list[list[float]] = []
    for idx in range(ugrid.cell_count):
        xyz = list(ugrid.get_cell_centroid(idx)[1])
        xyz[2] = (tops[idx] + bottoms[idx]) / 2 if tops else _average_cell_points_zs(locations, ugrid, idx)
        centers.append(xyz)
    return centers


def _average_cell_points_zs(point_locations, ugrid: UGrid, cell_idx: int) -> float:
    """Returns the average z value of all the points of the cell.

    Args:
        point_locations: xyz locations of the ugrid points.
        ugrid: The ugrid.
        cell_idx (int): The cell index.
    """
    cell_points = ugrid.get_cell_points(cell_idx)
    z_sum = sum(point_locations[point][2] for point in cell_points)
    return z_sum / len(cell_points)


def vertices_list_from_locations(locs: Locs, one_based: bool) -> VerticesList:
    """Build the VERTICES rows exactly as they appear in the file.

    Args:
        locs: Point locations. With a dict, the keys (x, y) are iterated in insertion order.
        one_based: True to number vertices from 1, False to number them from 0.

    Returns:
        list of [<iv> <xv> <yv>], e.g. [[0, 0., 0.], [1, 0., 100.],[2, 100., 100.]]
    """
    first_id = 1 if one_based else 0
    return [[first_id + i, xy[0], xy[1]] for i, xy in enumerate(locs)]


def save_vertices_data(data, vertices: VerticesList, ftype):
    """Write the VERTICES rows to an external temp file and point the package at it.

    Also sets NVERT on the package's grid info when there is at least one vertex.

    Args:
        data (DisData): Package data; ``data.list_blocks['VERTICES']`` receives the external file name.
        vertices: Rows for the VERTICES block (e.g. from ``vertices_list_from_locations``).
        ftype (str): The package file type, e.g. 'DISV6' or 'DISU6'.
    """
    lines = []
    for row in vertices:
        lines.append(' '.join(str(item) for item in row))
    filename = DisReaderBase(ftype=ftype).list_to_external_file(list_lines=lines, column_count=3)
    data.list_blocks['VERTICES'] = filename
    vertex_count = len(vertices)
    if vertex_count > 0:
        data.grid_info().nvert = vertex_count


def cell2d_list_from_grid_data(cells2d: IntLists, cell_centers: XyzTuples, one_based: bool) -> list[list[float]]:
    """Build the CELL2D rows exactly as they appear in the file.

    Each row is ``[icell2d, xc, yc, ncvert, icvert...]``.

    Args:
        cells2d: Integer point IDs for each cell. MUST BE 0-BASED.
        cell_centers: The calculated or custom grid cell centers, indexed like ``cells2d``.
        one_based: True to number cells and points from 1, False to keep them 0-based.

    Returns:
        One CELL2D row per cell in ``cells2d``.
    """
    offset = 1 if one_based else 0
    rows = []
    for idx, point_ids in enumerate(cells2d):
        center = cell_centers[idx]
        ids = [pid + offset for pid in point_ids]
        rows.append([idx + offset, center[0], center[1], len(ids), *ids])
    return rows


def save_cell2d_data(cell2d: Cell2dList, max_ncvert: int, data: DisvData | DisuData) -> None:
    """Write the CELL2D rows to an external temp file and point the package at it.

    Args:
        cell2d: CELL2D rows (e.g. from ``cell2d_list_from_grid_data``).
        max_ncvert: Maximum number of vertices used to define a cell.
        data: DIS* package; ``data.list_blocks['CELL2D']`` receives the external file name.
    """
    rows_as_text = [' '.join(str(value) for value in row) for row in cell2d]
    # icell2d, xc, yc, ncvert, followed by up to max_ncvert vertex numbers
    column_count = max_ncvert + 4
    reader = DisReaderBase(ftype=data.ftype)
    data.list_blocks['CELL2D'] = reader.list_to_external_file(rows_as_text, column_count)


def _compute_max_ncvert(cells2d: IntLists) -> int:
    """Computes and returns the maximum number of vertices used to define a cell.

    Args:
        cells2d: Integer point IDs for each cell. MUST BE 0-BASED.
    """
    max_ncvert = 0
    for cell2d in cells2d:
        max_ncvert = max(max_ncvert, len(cell2d))
    return max_ncvert


def ia_from_iac(iac) -> list[int]:
    """Returns an IA array given the IAC array.

    IA is what the binary grid (.grb) file stores. It is a running total of IAC that starts at 1, so it has one more
    entry than IAC, which has one entry per cell. IAC comes from the DISU package or the DISU calculator and holds,
    for each cell, its number of adjacent cells plus 1.

    Args:
        iac: Adjacent counts (IAC).

    Returns:
        The IA array, ``len(iac) + 1`` entries long.
    """
    total = 1
    ia = [total]
    for count in iac:
        total += count
        ia.append(total)
    return ia


def iac_from_ia(ia) -> list[int]:
    """Returns an IAC array given the IA array.

    This is the inverse of ``ia_from_iac``. IA, from the binary grid (.grb) file, starts at 1 and accumulates the
    adjacency counts, so it has one entry more than the number of cells. IAC (as in the DISU package or the DISU
    calculator) is the number of adjacent cells of each cell, plus 1.

    Args:
        ia: IA array from a binary grid file (.grb): accumulated adjacency counts, size cell count + 1.

    Returns:
        The IAC array, one entry per cell (the differences between consecutive IA entries).
    """
    return [following - current for current, following in zip(ia, ia[1:])]


def _make_ja_dict(iac, ja) -> dict[int, dict]:
    """Returns a dict that gives you the index into the ja array given two cell IDs.

    This is used to make the HWVA array symmetric. Might be useful for other things too?

    Example:
        ja array showing index numbers:
        [0] -1
        [1] 2
        [2] 3
        [3] -2
        [4] 1
        [5] 3
        [6] -3
        [7] 1
        [8] 2

        ja_dict: {1: {2: 1, 3: 2}, 2: {1: 4, 3: 5}, 3: {1: 7, 2: 8}}
        ja_dict[1][2] = 1
        ja_dict[1][3] = 2
        ja_dict[2][1] = 4
        ja_dict[2][3] = 5
        ...

    Args:
        iac: Adjacent counts (IAC).
        ja: The adjacent cells (JA).
    """
    ja_dict = {}
    start = 1
    for cell_idx, count in enumerate(iac):
        end = start + count - 1
        ja_dict[cell_idx + 1] = {}
        for i in range(start, end):
            ja_dict[cell_idx + 1][ja[i]] = i
        start += count
    return ja_dict


def make_symmetric(iac, ja, nja_array) -> list:
    """Returns a copy of nja_array in which the values of each pair of connected cells agree.

    For every connection n -> m whose reverse m -> n also appears in JA, the m -> n entry is overwritten with the
    n -> m entry. Cells are visited in ascending order, so for each pair the value stored with the lower-numbered cell
    wins. Connections with no reverse entry in JA are left unchanged.

    Args:
        iac: Adjacent counts (IAC).
        ja: The adjacent cells (JA), using 1-based cell numbers.
        nja_array: An array the same size as JA (NJA) whose entries must be symmetric, e.g. HWVA. (IAC is one entry
            per cell, not NJA-sized, so it is not a valid input.)

    Returns:
        The new symmetric array as a list; nja_array itself is not modified.
    """
    result = list(nja_array)
    ja_dict = _make_ja_dict(iac, ja)
    for cell_id, neighbors in ja_dict.items():
        for neighbor_id, idx in neighbors.items():
            reverse = ja_dict[neighbor_id]  # Entries of the neighbor's JA block
            if cell_id in reverse:
                result[reverse[cell_id]] = result[idx]
    return result


def _get_sorting_array(iac: list[int], ja: list[int]) -> list[int] | None:
    """If ja is out of order, returns an array of how it should be reordered, or else None.

    Uses "Decorate-Sort-Undecorate" https://docs.python.org/3/howto/sorting.html#decorate-sort-undecorate

    Args:
        iac: Adjacent counts (IAC).
        ja: The adjacent cells (JA).

    Returns:
        The old-to-new array (array containing old indices) or None if it's already sorted.
    """
    decorated = [(i, cell) for i, cell in enumerate(ja)]
    start = 1
    for count in iac:
        end = start + count - 1
        decorated[start:end] = sorted(decorated[start:end], key=lambda pair: pair[1])
        start += count

    is_sorted = all(decorated[i][0] <= decorated[i + 1][0] for i in range(len(decorated) - 1))
    return None if is_sorted else [i for i, _cell in decorated]


def _reorder(old_to_new, values) -> list:
    """Returns a new array that has been reordered.

    Args:
        old_to_new: The old-to-new array (array containing old indices).
        values: A list or tuple of values.
    """
    assert len(old_to_new) == len(values)
    new_array = list(values).copy()  # Cast to list in case it's a tuple
    for i in range(len(old_to_new)):
        new_array[i] = values[old_to_new[i]]
    return new_array


def _sort_ja(iac, nja_arrays) -> list:
    """Sorts the JA array and all other arrays that are dependent on it.

    Each cell's list of adjacent cells in JA must be in ascending order. If JA is already in order the arrays are
    returned as they are; otherwise every NJA array is reordered the same way JA is.

    Args:
        iac: Adjacent counts (IAC).
        nja_arrays: Sequence of NJA-sized arrays, starting with JA.

    Returns:
        A list with one array per input array, in the same order as ``nja_arrays`` (always a list, even when nothing
        needed sorting).
    """
    old_to_new = _get_sorting_array(iac, nja_arrays[0])
    if old_to_new is None:
        return list(nja_arrays)
    return [_reorder(old_to_new, array) for array in nja_arrays]


def _dis_elevs_match_grid(dis_values, grid_values) -> list[int]:
    """Compares DIS* package elevations to the grid and returns list of cell IDs of mismatches.

    A simple tolerance of 1e-6 is used for the comparison.

    Args:
        dis: DIS* package.
        cogrid: The constrained grid.

    Returns:
        See description.
    """
    tol = 1e-6
    mismatches = []
    for i, (dis_value, grid_value) in enumerate(zip(dis_values, grid_values)):
        if abs(dis_value - grid_value) > tol:
            mismatches.append(i + 1)
    return mismatches
