"""UGridRenumberer class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from typing import Sequence

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import Grid, UGrid3d
from xms.grid.ugrid import UGrid

# 4. Local modules


class UGridRenumberer:
    r"""Renumber a UGrid the way HydroGeoSphere requires.

    ::

           8----------7    4---------6
          /|         /|    |\       /|
         / |        / |    | \     / |
        5----------6  |    |  \   /  |
        |  4-------|--3    1---\ /---3
        | /        | /      \   5   /
        |/         |/        \  |  /
        1----------2          \ | /
                               \|/
                                2

    Cell nodes must be numbered as shown, bottom to top, counterclockwise looking down, and
    cells and nodes must be numbered in layers, bottom to top. See HydroGeoSphere reference manual.
    """
    def __init__(self, co_grid: Grid, ugrid: UGrid) -> None:
        """Initialize the class.

        Args:
            co_grid: The input co_grid.
            ugrid: The input ugrid (the ugrid of the co_grid).
        """
        self._co_grid = co_grid
        self._ugrid = ugrid  # Do 'self._co_grid.ugrid' only once as it's costly
        self._point_locations: Sequence[Sequence[float]] | None = None  # Filled in by build_cogrid()

    def build_cogrid(self) -> UGrid3d:
        """Builds and returns the new, renumbered, UGrid.

        Points and cells are both reordered layer by layer, bottom to top. Note that this recalculates and sets the
        cell ordering of the input ugrid (see _make_cell_stream).

        Returns:
              (UGrid3d): The new, renumbered co_grid.

        Raises:
            RuntimeError: If the points or cells do not form vertical columns of equal length, or if a cell has no
                top face or no bottom face.
        """
        self._point_locations = self._ugrid.locations
        cell_centers = self._get_cell_centers()

        # Points
        points, _sheet_count = _order_in_layers_bottom_to_top(self._point_locations)
        old_to_new_points = _make_old_to_new_list(points)
        xyzs = [self._point_locations[index] for index in points]

        # Cells
        cells, layer_count = _order_in_layers_bottom_to_top(cell_centers)
        old_to_new_cells = _make_old_to_new_list(cells)
        cell_stream = self._make_cell_stream(old_to_new_points, cells)

        # Create the grid
        new_ugrid = UGrid(xyzs, cell_stream)
        new_ugrid.set_cell_ordering(UGrid.cell_ordering_enum.CELL_ORDER_INCREASING_UP)
        new_co_grid = UGrid3d(new_ugrid)
        new_co_grid.cell_layers = _get_cell_layers(len(cells), layer_count)
        self._add_cell_tops_and_bottoms(old_to_new_cells, new_co_grid)
        _reset_point_locations(new_ugrid, xyzs)
        return new_co_grid

    def _get_top_and_bottom_faces(self, cell_idx: int) -> tuple[int, int]:
        """Returns the top and bottom face indexes of the cell.

        Args:
            cell_idx (int): The cell index.

        Returns:
            (tuple[int, int]): top and bottom face indexes; -1 for a face whose orientation was not found.
        """
        face_count = self._ugrid.get_cell_3d_face_count(cell_idx)
        top_face = -1
        bottom_face = -1
        for face in range(face_count):
            if top_face > -1 and bottom_face > -1:
                break  # pragma no cover - tests never hit this line and that's fine
            orientation = self._ugrid.get_cell_3d_face_orientation(cell_idx, face)
            if orientation == UGrid.face_orientation_enum.ORIENTATION_TOP:
                top_face = face
            elif orientation == UGrid.face_orientation_enum.ORIENTATION_BOTTOM:
                bottom_face = face
        return top_face, bottom_face

    def _make_cell_stream(self, old_to_new_points: list[int], cells: list[int]) -> list[int]:
        """Returns the cell stream in the HGS order.

        Side effect: recalculates and sets the cell ordering of the input ugrid, which is needed to find the top and
        bottom faces of each cell.

        Args:
            old_to_new_points: list of indices, where index is old index, and value is the new.
            cells: List of cell indexes in HGS order.

        Returns:
            The cell stream: for each cell, its type, its point count, then its bottom points followed by its top
            points (as new point indexes).

        Raises:
            RuntimeError: If a cell has no top face or no bottom face.
        """
        # need correct cell ordering to get top and bottom face points
        self._ugrid.set_cell_ordering(UGrid.cell_ordering_enum.CELL_ORDER_UNKNOWN)
        cell_ordering = self._ugrid.calculate_cell_ordering()
        self._ugrid.set_cell_ordering(cell_ordering)

        cell_stream: list[int] = []
        for cell_idx in cells:
            # Decide the type per cell (8 points -> hexahedron, otherwise wedge) instead of from cell 0 only, so an
            # empty grid never indexes a missing cell and each cell's header matches the points written after it.
            cell_points_count = len(self._ugrid.get_cell_points(cell_idx))
            is_hexahedron = cell_points_count == 8
            cell_type = UGrid.cell_type_enum.HEXAHEDRON if is_hexahedron else UGrid.cell_type_enum.WEDGE

            top_face, bottom_face = self._get_top_and_bottom_faces(cell_idx)
            if top_face < 0 or bottom_face < 0:
                # Passing -1 as a face index would silently produce a wrong stream
                raise RuntimeError(f'Could not find both a top and a bottom face for cell {cell_idx + 1}.')
            top_points = self._ugrid.get_cell_3d_face_points(cell_idx, top_face)
            bottom_points = self._ugrid.get_cell_3d_face_points(cell_idx, bottom_face)
            if is_hexahedron:
                # top points for hexahedron are counter clock-wise like the face
                new_top_points = [old_to_new_points[old_idx] for old_idx in top_points]
                # bottom points for hexahedron are clock-wise so reverse and rotate by 1
                temp_bottom_points = [old_to_new_points[old_idx] for old_idx in reversed(bottom_points)]
                new_bottom_points = temp_bottom_points[-1:] + temp_bottom_points[:-1]
            else:
                # We can't make wedges exactly like HGS wants them because UGrids follow VTK format. When we write
                # the HGS model native files, we adjust the wedge cell point order to make HGS happy.
                # top points for wedge are clock-wise so reverse and rotate by 1
                temp_top_points = [old_to_new_points[old_idx] for old_idx in reversed(top_points)]
                new_top_points = temp_top_points[-1:] + temp_top_points[:-1]
                # bottom points for wedge are counter clock-wise like the face
                new_bottom_points = [old_to_new_points[old_idx] for old_idx in bottom_points]
            cell_stream.extend([int(cell_type), cell_points_count, *new_bottom_points, *new_top_points])
        return cell_stream

    def _average_cell_points_zs(self, cell_idx: int) -> float:
        """Returns the average z value of all the points of the cell.

        Args:
            cell_idx (int): The cell index.

        Returns:
            (float): Mean of the z coordinates of the cell's points.
        """
        cell_points = self._ugrid.get_cell_points(cell_idx)
        z_sum = sum(self._point_locations[point][2] for point in cell_points)
        return z_sum / len(cell_points)

    def _get_cell_centers(self) -> list[list[float]]:
        """Have to calculate the z value of the cell centroids ourselves.

        The x and y come from the ugrid's cell centroid. If the grid has tops and bottoms, the z is the midpoint of the
        cell top and bottom. If not, it is the average of the Zs of the cell points.

        Returns:
            List of cell centroids.
        """
        cell_centers: list[list[float]] = []

        # If the grid has layer tops and bottoms defined, use them
        cell_tops = None
        cell_bottoms = None
        if self._co_grid.has_cell_tops_and_bottoms():
            cell_tops = self._co_grid.get_cell_tops()
            cell_bottoms = self._co_grid.get_cell_bottoms()

        for i in range(self._ugrid.cell_count):
            centroid = list(self._ugrid.get_cell_centroid(i)[1])
            # Test for None explicitly: truthiness of an array-like return value is ambiguous
            if cell_tops is not None:
                centroid[2] = (cell_tops[i] + cell_bottoms[i]) / 2
            else:
                centroid[2] = self._average_cell_points_zs(i)
            cell_centers.append(centroid)
        return cell_centers

    def _add_cell_tops_and_bottoms(self, old_to_new_cells: list[int], new_co_grid: UGrid3d) -> None:
        """Adds the cell tops and bottoms from the original grid to the new grid if original had them.

        Args:
            old_to_new_cells (list[int]): list of indices, where index is old index, and value is the new.
            new_co_grid (UGrid3d): The new co_grid.
        """
        if self._co_grid.has_cell_tops_and_bottoms():
            cell_tops = self._co_grid.get_cell_tops()
            cell_bottoms = self._co_grid.get_cell_bottoms()
            new_cell_tops = _apply_old_to_new_list(old_to_new_cells, cell_tops)
            new_cell_bottoms = _apply_old_to_new_list(old_to_new_cells, cell_bottoms)
            new_co_grid.set_cell_tops_and_bottoms(new_cell_tops, new_cell_bottoms)


def _order_in_layers_bottom_to_top(locations: Sequence[Sequence[float]]) -> tuple[list[int], int]:
    """Returns the indexes of the locations in HGS order (layer by layer, bottom to top) and the layer count.

    The locations are grouped into vertical columns by their exact xy, every column is checked to hold the same number
    of entries, each column is sorted by z, and the indexes are then read off one layer at a time.

    Args:
        locations (Sequence[Sequence[float]]): The xyz locations (e.g. points or cell centers).

    Returns:
        (tuple[list[int], int]): Indexes into locations in HGS order, and the number of layers (sheets for points).

    Raises:
        RuntimeError: If the vertical columns do not all have the same number of locations.
    """
    columns = _make_vertical_columns(locations)

    # Every column must have the same number of sheets
    mismatched = _columns_with_different_lengths(columns)
    if mismatched:
        _raise_different_lengths_error(mismatched)

    _sort_columns_bottom_to_top(locations, columns)
    return _order_by_layers(columns)


def _columns_with_different_lengths(columns: dict[str, list[int]]) -> list[tuple[str, int]]:
    """Return list of all columns where column length differs from first column, as tuple of (column key, length).

    If list is not empty, the first item will be the first column.

    Args:
        columns: The columns.

    Returns:
        See description.
    """
    if not columns:
        return []

    key1, len1 = '', None
    differs_list: list[tuple[str, int]] = []
    for key, column in columns.items():
        if len1 is None:
            key1 = key
            len1 = len(column)
        elif len(column) != len1:
            differs_list.append((key, len(column)))

    if differs_list:
        differs_list.insert(0, (key1, len1))  # Put the first column at the front of the list
    return differs_list


def _raise_different_lengths_error(differs_list: list[tuple[str, int]]) -> None:
    """Raise an error about columns having different lengths.

    Args:
        differs_list: Result of _columns_with_different_lengths().
    """
    dl = differs_list  # for short
    msg = f'Number of points at the following locations differs from the first location ({dl[0][0]}) {dl[0][1]}:\n'
    differs = ', '.join([f'({item[0]}) {item[1]}' for item in dl[1:]])
    msg += differs
    raise RuntimeError(msg)


def _order_by_layers(columns: dict[str, list[int]]) -> tuple[list[int], int]:
    """Flatten the sorted vertical columns into layer order, bottom to top, and return it with the layer count.

    Entry 0 of every column comes first (in column order), then entry 1 of every column, and so on. For points the
    layers are "sheets". The layer count comes from _get_layer_count().

    Args:
        columns (dict[str, list[int]]): The vertical point columns, now sorted.

    Returns:
        (tuple[list[int], int]): The indexes in layer order, and the number of layers.
    """
    layer_count = _get_layer_count(columns)
    ordered = [column[layer] for layer in range(layer_count) for column in columns.values()]
    return ordered, layer_count


def _get_layer_count(columns: dict[str, list[int]]) -> int:
    """Returns the number of layers in the columns.

    Args:
        columns (dict[str, list[int]]): Vertical columns of xyz indices.

    Returns:
        (int): See description.
    """
    first_column = next(iter(columns.values()))
    return len(first_column)


def _sort_columns_bottom_to_top(locations: Sequence[Sequence[float]], columns: dict[str, list[int]]) -> None:
    """Sorts the xyz columns from bottom to top.

    The hard to read python below works like this: (https://stackoverflow.com/questions/6618515)
    1. zip the two lists.
    2. create a new, sorted list based on the zip using sorted().
    3. using a list comprehension extract the second elements of each pair from the sorted, zipped list.

    Args:
        locations (Sequence[Sequence[float]]): list like of xyz locations.
        columns (dict[str, list[int]]): Vertical columns of xyz indices.
    """
    for key, column in columns.items():
        zs = [locations[point_idx][2] for point_idx in column]
        columns[key] = [point_idx for _, point_idx in sorted(zip(zs, column), key=lambda pair: pair[0])]


def _make_vertical_columns(locations: Sequence[Sequence[float]]) -> dict[str, list[int]]:
    """Given list of xyz and assuming a stacked grid, returns xyzs sorted into unordered vertical columns.

    Vertical columns will not yet be sorted from bottom to top or top to bottom. Assumes a stacked grid and is
    currently implemented assuming the stacked xyz values are exactly the same (no tolerance used).

    Args:
        locations (Sequence[Sequence[float]]): The xyz points.

    Returns:
        See description.
    """
    columns: dict[str, list[int]] = {}
    for i, xyz in enumerate(locations):
        key = f'{xyz[0]} {xyz[1]}'  # Just use string of xy as hash key
        column = columns.get(key)
        if column is None:
            columns[key] = [i]
        else:
            column.append(i)
    return columns


def _make_old_to_new_list(indexes: list[int]) -> list[int]:
    """Returns a list of ints where list index is the old index, and value at index is the new index.

    Args:
        indexes (list[int]): Indexes, in sheet order, bottom to top.

    Returns:
        (list[int]): See description.
    """
    old_to_new = [-1] * len(indexes)
    for i in range(len(indexes)):
        old_to_new[indexes[i]] = i
    return old_to_new


def _apply_old_to_new_list(old_to_new_list, old_values):
    """Given the old_to_new_list and list of old values, returns a new list of values ordered by old_to_new_list.

    Args:
        old_to_new_list (list[int]): List of indices, where index is old index, and value is the new.
        old_values (list[int]): Old list of values that will be returned in the new order.
    """
    new_values = [-1.0] * len(old_values)
    for old_index, new_index in enumerate(old_to_new_list):
        new_values[new_index] = old_values[old_index]
    return new_values


def _get_cell_layers(cell_count, layer_count) -> list[int]:
    """Returns the list of cell layers.

    Args:
        cell_count (int): Number of cells.
        layer_count (int): Number of layers.

    Returns:
        (list[int]): Cell layers
    """
    cells_per_layer = cell_count // layer_count
    new_cell_layers = []
    for layer in range(layer_count):
        layer_numbers = [layer + 1] * cells_per_layer
        new_cell_layers.extend(layer_numbers)
    return new_cell_layers


def _reset_point_locations(ugrid: UGrid, xyzs) -> None:
    """Set the location of each point of ugrid from xyzs, matching points by index.

    NOTE(review): build_cogrid() already passes these same xyzs to the UGrid constructor; the reason for applying
    them again afterwards isn't visible here — confirm before removing.

    Args:
        ugrid (UGrid): The grid whose point locations are overwritten.
        xyzs: Point locations in point-index order.
    """
    for point_idx, location in enumerate(xyzs):
        ugrid.set_point_location(point_idx, location)
