"""Module for the WeirStructureWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['RowOrColumnStructureWriter']

# 1. Standard Python modules
from abc import ABC, abstractmethod
from functools import cached_property
import itertools
import logging
from typing import cast, TextIO

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import GridType, QuadtreeGrid2d, RectilinearGrid2d, UGrid2d
from xms.data_objects.parameters import Arc, Coverage
from xms.gmi.data_bases.coverage_base_data import CoverageBaseData
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType
from xms.snap import SnapInteriorArc

# 4. Local modules
from xms.cmsflow.data.model import get_model
from xms.cmsflow.file_io.card_writer import CardWriter
from xms.cmsflow.file_io.writer_utils import cells_are_contiguous, neighbors_of


class RowOrColumnStructureWriter(ABC):
    """Class for writing all the data needed by a weir or tide gate in a structure coverage."""
    def __init__(
        self, coverage: Coverage, data: CoverageBaseData, ugrid: UGrid2d, logger: logging.Logger, wrote_header: bool,
        cards: TextIO
    ):
        """
        Set up a writer for the structure arcs of one coverage.

        Args:
            coverage: Coverage containing the arcs to write.
            data: Data manager for the coverage. Should have its component_id_map initialized.
            ugrid: The QuadTree or CGrid to snap the coverage to.
            logger: Where to log any warnings or errors.
            wrote_header: Whether the structures header has already been written.
            cards: Where to write cards to. Typically obtained by calling `open(...)` on the *.cmcards file.

        Raises:
            AssertionError: If `ugrid` is neither a rectilinear grid nor a QuadTree.
        """
        # Both names start out empty. The display name appears in log messages and the group name selects the
        # parameter group that decides whether an arc is written; subclasses are presumably expected to set them.
        self._structure_display_name: str = ''
        self._structure_group_name: str = ''

        self._coverage = coverage
        self._data = data
        self._logger = logger
        self._wrote_header = wrote_header

        # Keep the grid under its original type as well as a typed view of it; exactly one of the typed views is set.
        self._cogrid = ugrid
        grid_type = ugrid.grid_type
        if grid_type == GridType.rectilinear_2d:
            typed_views = (cast(RectilinearGrid2d, ugrid), None)
        elif grid_type == GridType.quadtree_2d:
            typed_views = (None, cast(QuadtreeGrid2d, ugrid))
        else:
            raise AssertionError('Unknown grid type')  # pragma: nocover  Only happens due to programmer mistake.
        self._cgrid, self._quadtree = typed_views

        self._ugrid = ugrid.ugrid
        self._activity = ugrid.model_on_off_cells  # May be empty if UGrid has no activity (all cells are active)
        self._cards = CardWriter(cards)
        self._section = get_model().arc_parameters
    def write(self) -> bool:
        """
        Write all the weir or tide gate data needed for the coverage.

        Returns:
            Whether the `!Structures` header was written (either before or as a result of calling this function).
        """
        for arc in self._coverage.arcs:
            self._write_structure(arc)

        return self._wrote_header

    def _ensure_header_written(self):
        """
        Ensure the header for this structure section is written.

        Does nothing after the first time it was called.
        """
        if self._wrote_header:
            return

        self._logger.info(f'Writing {self._structure_display_name} structures')
        self._wrote_header = True
        self._cards.write('!Structures', indent=0)

    def _write_structure(self, arc: Arc):
        """
        Write a structure if necessary.

        The arc will not be written if it is invalid or the wrong type. If it is invalid, a warning will be logged.

        Args:
            arc: The arc to write.
        """
        values = self._data.feature_values(TargetType.arc, feature_id=arc.id)
        self._section.restore_values(values)
        structure_group = self._section.group(self._structure_group_name)
        if not structure_group.is_active:
            return

        cells = self._snapper.get_snapped_points(arc)['id']
        if not self._check_cells(cells, arc.id):
            return

        self._ensure_header_written()
        self._write_specific_structure(cells)

    @abstractmethod
    def _write_specific_structure(self, cells: list[int]):
        """
        Write structure-specific information.

        Called by `_write_structure` after the cells have passed `_check_cells` and the `!Structures` header has been
        ensured. Concrete writers must override this.

        Args:
            cells: The cells that the structure's arc snapped to.
        """

    @cached_property
    def _snapper(self) -> SnapInteriorArc:
        """Set up the arc snapper."""
        snapper = SnapInteriorArc()
        ugrid = cast(UGrid, self._cogrid)  # Inspections can't see UGrid2d is a UGrid because the inheritance is in C++.
        snapper.set_grid(ugrid, target_cells=True)
        return snapper

    def _check_cells(self, cells: list[int], feature_id: int) -> bool:
        """
        Check if a list of cells is a valid structure.

        Logs an error if they are not a valid structure.

        Args:
            cells: Cells to check.
            feature_id: Feature ID of the arc the cells represent. Used for error reporting.

        Returns:
            Whether the cells comprise a valid structure.
        """
        if not cells:
            self._skip(feature_id, 'arc does not intersect the grid')
            return False

        if len(cells) < 2:
            self._skip(feature_id, 'arc is degenerate')
            return False

        if not cells_are_contiguous(self._ugrid, cells):
            self._skip(feature_id, 'arc cells are not contiguous')
            return False

        # Intersecting cells with differing resolution can make the diagonal check fail, so it needs to come first.
        if not self._cells_same_resolution(set(cells)):
            self._skip(feature_id, 'arc intersects cells with different refinement')
            return False

        # Assumes the cells are all the same resolution, so it needs to come after the same resolution check.
        if self._cells_are_diagonal(cells):
            self._skip(feature_id, 'arc is not in single row or column of grid')
            return False

        neighbors = self._get_perpendicular_neighbors(cells)
        # Make sure we compare to a structure cell too, or we'll miss the case where a line of same-refinement cells is
        # between two lines of cells that are all the same refinement as each other, but different from the structure.
        neighbors.add(cells[0])
        if not self._cells_same_resolution(neighbors):
            self._skip(feature_id, 'arc adjacent to cells with different refinement')
            return False
        neighbors.discard(cells[0])  # Don't confuse later checks with an extra neighbor.

        # At this point, we know the structure and its adjacent cells are all the same refinement. This means that
        # either there are exactly two perpendicular neighbors per cell, or the structure is adjacent to the boundary.
        if len(neighbors) != 2 * len(cells):
            self._skip(feature_id, 'arc is adjacent to grid boundary')
            return False

        # Empty indicates no activity, so only do the activity check if there's activity.
        if self._activity and not all(self._activity[cell] for cell in cells):
            self._skip(feature_id, 'arc intersects inactive cells')
            return False

        # Same as above: No activity implies all active.
        if self._activity and not all(self._activity[cell] for cell in neighbors):
            self._skip(feature_id, 'arc is adjacent to inactive cells')
            return False

        return True

    def _get_perpendicular_neighbors(self, cells: list[int]) -> set[int]:
        """
        Get the neighbors that are perpendicular to the arc.

        This might not behave well for diagonal or stair-step arcs.

        Some cells may be missing if either end of the structure has perpendicular neighbors of different refinement to
        the endcap itself. This can be ruled out by checking if everything has the same refinement.

        Args:
            cells: The cells that make up the arc.

        Returns:
            The cells that are neighbors of the input cells, minus the ones that are off the ends of the structure.
        """
        cells_to_check = self._get_endcap_neighbors(cells[0], cells[1])
        cells_to_check.update(self._get_endcap_neighbors(cells[-1], cells[-2]))

        for cell in cells[1:-1]:
            cells_to_check.update(neighbors_of(self._ugrid, cell))

        cells_to_check.difference_update(cells)
        return cells_to_check

    def _get_endcap_neighbors(self, endcap: int, structure_cell: int) -> set[int]:
        """
        Get the perpendicular neighbors of an endcap cell.

        Endcap cells are those at the start and end of a structure. They have exactly one neighbor which is a structure
        cell, in contrast to interior ones, which have exactly two neighbors that are structure cells (in typical cases;
        the counts may be off if the structure is U or spiral shaped, but those can be rejected by a diagonal test).

        Surprising (but normal) behavior might be observed if the endcap has neighbors of different refinement. If one
        of its neighbors is less refined, then it's possible that one or more of the returned neighbors will also be
        neighbors of structure_cell. If one of its neighbors is less refined, then only one of the refined neighbors per
        side will be returned, which means a neighbor might be missed. Both of these cases can be ruled out by checking
        if the returned cells are the same refinement as the endcap cell.

        Args:
            endcap: The endcap cell to find the neighbors of.
            structure_cell: The other structure cell that is attached to the endcap. This may be the other endcap if the
                structure is only two cells long; this method will work correctly whether structure_cell is another
                endcap or not.

        Returns:
            The cells which are neighbors of the endcap, minus the structure_cell cell and the neighbor(s) on the
            opposite side of the endcap cell.
        """
        # Given a grid like this, where E denotes the endcap and W denotes the structure cell:
        #
        # +---+---+---+
        # | 0 | P | N |
        # +---+---+---+
        # | 0 | E | S |
        # +---+---+---+
        # | 0 | P | N |
        # +---+---+---+
        #
        # We want to find the two P cells, while ignoring everything else. This algorithm is based on the observation
        # that the P cells are both neighbors of the N cells, which are in turn neighbors of the S cell, but none of the
        # 0 cells have this property. So we find the neighbors of E which are also second-order neighbors of S, and
        # declare them to be the endcap neighbors.
        #
        # Things get more complicated with QuadTrees:
        #
        # +---+---------------+
        # | 0 |       L       |
        # +---+-------+-------+
        # | 0 |   E   |   S   |
        # +---+---+---+-------+
        # | 0 | M | H | N | N |
        # +---+---+---+---+---+
        #
        # In this case, L is a neighbor we should check, but it doesn't have our property. If the structure is at least
        # three cells long, then all of S's non-structure neighbors will be checked later, so missing L this time around
        # isn't a problem since it will be picked up later as a neighbor of S. But if the weir is only two cells long,
        # then we'll just end up here again with E and S switched, which will miss L again. Since both cells would be
        # short a neighbor, later code could not distinguish this case from the one where the structure is adjacent to
        # the grid boundary. To avoid that, we also include shared neighbors in the result, which will give later stages
        # a chance to recognize the mismatched resolution. Another check enforces that E and S are adjacent to each
        # other, so the only remaining way for them to share a neighbor is if the neighbor is above or below them, which
        # means shared neighbors are always in the direction we want to check.
        #
        # Another shortcoming is that we'll also fail to report the M cell as a neighbor of E, so later code won't ever
        # check M, even though it's a neighbor of E and in the right direction. But all that code would have done with M
        # is report that it's a different resolution from E and skip the arc, which is exactly what will happen when it
        # sees H, which we will pick up, so the absence of M won't matter as long as we report H.

        # These might both include parallel neighbors.
        endcap_neighbors = set(neighbors_of(self._ugrid, endcap)) - {structure_cell}
        structure_neighbors = set(neighbors_of(self._ugrid, structure_cell)) - {endcap}
        # These are the L cells from the above block comment.
        shared_neighbors = endcap_neighbors & structure_neighbors

        second_order_neighbor_generator = (neighbors_of(self._ugrid, neighbor) for neighbor in structure_neighbors)
        second_order_neighbors = set(itertools.chain.from_iterable(second_order_neighbor_generator))
        second_order_neighbors -= {endcap, structure_cell}

        # All of the endcap's neighbors that are also structure neighbors, plus the shared neighbors.
        perpendicular_neighbors = (endcap_neighbors & second_order_neighbors) | shared_neighbors
        return perpendicular_neighbors

    def _cells_are_diagonal(self, cells: list[int]) -> bool:
        """
        Check whether a list of cells cross multiple rows and columns in the grid.

        Args:
            cells: Cells to check.

        Returns:
            Whether the cells form a diagonal line in the grid (i.e. are not all in the same row and are not in the same
            column).
        """
        if self._cgrid:
            # Rectilinear 2D grids don't have a .quad_grid_cells attribute since they aren't QuadTrees, but their
            # .get_cell_ijk_from_index() method yields sensible results, so we'll use it.
            ijks = [self._cgrid.get_cell_ijk_from_index(cell) for cell in cells]
        elif self._quadtree:
            # QuadTrees have a .get_cell_ijk_from_index() method too, but the results are garbage when the grid has
            # refinement, which is kind of the point of having a QuadTree, so we use .quad_grid_cells instead.
            qc = self._quadtree.quad_grid_cells
            ijks = [qc.get_cell_ijk(cell) for cell in cells]
        else:
            raise AssertionError('Unknown grid type')  # pragma: nocover  Only happens due to programmer mistake.

        reference_i = ijks[0][0]
        reference_j = ijks[0][1]

        same_i = all(ijk[0] == reference_i for ijk in ijks)
        same_j = all(ijk[1] == reference_j for ijk in ijks)

        return not (same_i or same_j)

    def _cells_same_resolution(self, cells: set[int]) -> bool:
        """
        Check whether a list of cells are in the same row or column within the grid.

        Args:
            cells: Cells to check.

        Returns:
            Whether the arc was good.
        """
        if self._cgrid:
            return True  # Every cell in a rectilinear grid is the same resolution.

        qc = self._quadtree.quad_grid_cells
        reference_refinement = qc.get_cell_refinement(next(iter(cells)))
        all_same = all(qc.get_cell_refinement(cell) == reference_refinement for cell in cells)
        return all_same

    def _skip(self, feature_id: int, reason: str):
        """
        Create a message for an arc that is being skipped.

        Args:
            feature_id: The ID of the arc the message is for.
            reason: A short reason message for why the arc is bad.
        """
        self._logger.warning(f'Skipping {self._structure_display_name} arc {feature_id}, {reason}.')
