from abc import ABC
from enum import IntEnum

from .refine_level import RefineLevel
from .. import _xmsconstraint
from ..rectilinear_geometry import Numbering


class NeighborDirection(IntEnum):
    r"""Directions from a cell to its (up to) twelve smoothed quadtree neighbors.

    The directions are numbered counter-clockwise (as drawn, with I growing
    downward and J growing to the right), starting at the top-left corner:

    ::

        Neighbors of cell assuming smoothed quadtree grid with I, J grid origin
        at top left (I rows, J cols).
        -------------------------
        |  0  | 11  | 10  |  9  |
        -------------------------
        |  1  |           |  8  |
        -------   cell    -------
        |  2  |           |  7  |
        -------------------------
        |  3  |  4  |  5  |  6  |
        -------------------------

    ``num_neighbor_directions`` is not itself a direction; it equals the
    number of directions listed before it (12).
    """
    # Top-left corner, then down the J-minus (left) side.
    i_minus_j_minus_corner = 0
    j_minus_i_minus = 1
    j_minus_i_plus = 2

    # Bottom-left corner, then along the I-plus (bottom) side.
    i_plus_j_minus_corner = 3
    i_plus_j_minus = 4
    i_plus_j_plus = 5

    # Bottom-right corner, then up the J-plus (right) side.
    j_plus_i_plus_corner = 6
    j_plus_i_plus = 7
    j_plus_i_minus = 8

    # Top-right corner, then back along the I-minus (top) side.
    i_minus_j_plus_corner = 9
    i_minus_j_plus = 10
    i_minus_j_minus = 11

    # Count of the directions above (not a direction).
    num_neighbor_directions = 12


class QuadGridCells(ABC):
    """Tracks refinement and cell indices of a quad tree constrained UGrid.

    This class forwards every operation to a wrapped C++ instance
    (``_xmsconstraint.quadtree.QuadGridCells`` by default).

    Most per-cell accessors take *refinement grid coordinates*: a row ``i``
    in ``[0, number_of_refined_rows)``, a column ``j`` in
    ``[0, number_of_refined_columns)`` and a layer ``k`` in
    ``[0, number_of_base_layers)``. Out-of-range coordinates raise
    ``ValueError`` before reaching the wrapped instance.
    """

    EXCLUDED = 2 ** 31 - 1  # maximum 32-bit signed integer
    """Cell index value for cells that are excluded."""

    INCLUDED = EXCLUDED - 1
    """Cell index value for included cells with no assigned cell index."""

    def __init__(self, num_i=None, num_j=None, num_k=None,
                 numbering=Numbering.kij, is_3d=None, number_of_levels=None,
                 instance=None):
        """Initializer that can optionally setup the grid dimensions in one
        call.

        When no dimensions are given, only ``number_of_levels`` (if not None)
        is applied; ``numbering`` and ``is_3d`` are ignored in that case.

        Args:
            num_i: The number of cells in I direction.
            num_j: The number of cells in J direction.
            num_k: The number of cells in K direction (or None for 2D).
            numbering: The cell numbering.
            is_3d: Is the grid 2D or 3D? When None, the grid is treated as
                3D if num_k is given and 2D otherwise.
            number_of_levels: The maximum number of refinement levels.
            instance: The C++ wrapped instance. A new one is created when
                None.

        Raises:
            ValueError: If num_i or num_j is missing while another dimension
                is given, or if is_3d is truthy but num_k is not given.
        """
        if instance is None:
            instance = _xmsconstraint.quadtree.QuadGridCells()
        self._instance = instance

        if num_i is None and num_j is None and num_k is None:
            # No base grid requested. Still honor an explicitly requested
            # level count rather than silently discarding it.
            if number_of_levels is not None:
                self.number_of_levels = number_of_levels
            return

        if num_i is None or num_j is None:
            # Keep the original message for the case it describes; give an
            # accurate one when num_k was not involved.
            if num_k is not None:
                raise ValueError(
                    'When specifying num_k, must specify num_i and num_j.')
            raise ValueError('Must specify both num_i and num_j.')

        if num_k is None:
            if is_3d:
                raise ValueError('Must specify num_k for a 3D grid.')
            is_3d = False
            num_k = 1
        elif is_3d is None:
            is_3d = True

        self.set_base_grid(num_i, num_j, num_k, numbering, is_3d)

        # Applied after the base grid is set up.
        if number_of_levels is not None:
            self.number_of_levels = number_of_levels

    def _description(self):
        """Build the text shared by __str__ and __repr__."""
        num_i = self.number_of_base_rows
        num_j = self.number_of_base_columns
        num_k = self.number_of_base_layers
        return f'<QuadGridCells: num_i={num_i}, num_j={num_j}, num_k={num_k}>'

    def __str__(self):
        """Get string description."""
        return self._description()

    def __repr__(self):
        """Get string representation."""
        return self._description()

    @property
    def number_of_levels(self):
        """The maximum number of refinement levels."""
        return self._instance.GetNumLevels()

    @number_of_levels.setter
    def number_of_levels(self, value):
        """The maximum number of refinement levels."""
        self._instance.SetNumLevels(value)

    def get_level(self, i, j, k):
        """Get the refinement level of a cell in refinement grid coordinates.

        Refinement grid coordinates span ``number_of_refined_rows`` by
        ``number_of_refined_columns`` cells (per layer).

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.

        Returns:
            The refinement level of the cell, as a RefineLevel.

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        return RefineLevel(instance=self._instance.GetLevel(i, j, k))

    def set_level(self, i, j, k, refine_level):
        """Set the refinement level of a cell in refinement grid coordinates.

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.
            refine_level: The amount of cell refinement (a RefineLevel; its
                wrapped instance is passed through).

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        self._instance.SetLevel(i, j, k, refine_level._instance)

    def get_included(self, i, j, k):
        """Determine if a given cell is included in the grid.

        Uses refinement grid coordinates. When included there is an existing
        cell in the UGrid. Otherwise, there is no cell and the index
        numbering skips that location.

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.

        Returns:
            True if the cell is included.

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        return self._instance.GetIncluded(i, j, k)

    def set_included(self, i, j, k):
        """Set a given cell as included in the grid.

        Uses refinement grid coordinates. When included there is an existing
        cell in the UGrid. Otherwise, there is no cell and the index
        numbering skips that location.

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        self._instance.SetIncluded(i, j, k)

    def set_excluded(self, i, j, k):
        """Set a given cell as excluded (not included) in the grid.

        Uses refinement grid coordinates. When excluded there is not an
        existing cell in the UGrid.

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        self._instance.SetExcluded(i, j, k)

    def set_all_included(self):
        """Set all cells included in the generated UGrid."""
        self._instance.SetAllIncluded()

    def get_cell_index(self, i, j, k):
        """Get the index of a cell in the UGrid.

        Uses refinement grid coordinates.

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.

        Returns:
            The index of the cell in the UGrid.

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        return self._instance.GetCellIdx(i, j, k)

    def set_cell_index(self, i, j, k, index):
        """Set the index of a cell in the UGrid.

        Uses refinement grid coordinates.

        Args:
            i: The row of the cell in refinement grid coordinates.
            j: The column of the cell in refinement grid coordinates.
            k: The layer of the cell.
            index: The index of the cell.

        Raises:
            ValueError: If the coordinates are outside the refinement grid.
        """
        self._require_valid_refine_coordinates(i, j, k)
        self._instance.SetCellIdx(i, j, k, index)

    def has_cell_index(self, index):
        """Determine if a cell index exists in the UGrid.

        Args:
            index: The index of the cell.

        Returns:
            If the given cell index is in the UGrid.
        """
        return self._instance.HasCellIdx(index)

    def get_cell_ijk(self, index):
        """Get the I, J, and K value of a given cell index in the UGrid.

        Args:
            index: The index of the UGrid cell.

        Returns:
            (tuple): A tuple of the I, J, and K index of the cell in
            refinement coordinates.

        Raises:
            ValueError: If the wrapped instance reports a negative I, J or K,
                i.e. the index is out of range.
        """
        ijk = self._instance.GetCellIjk(index)
        if ijk[0] < 0 or ijk[1] < 0 or ijk[2] < 0:
            raise ValueError('Index is out of range.')
        return ijk

    def get_cell_refinement(self, index):
        """Get the cell refinement level for the given UGrid cell.

        Args:
            index: The index of the UGrid cell.

        Returns:
            The refinement level at the given index, as a RefineLevel.

        Raises:
            IndexError: If the wrapped instance returns None for the index.
        """
        level = self._instance.GetCellRefinement(index)
        if level is None:
            raise IndexError('Index is out of range.')
        # The wrapped call returns a sequence; the level is its first item.
        level = level[0]
        return RefineLevel(instance=level)

    @property
    def most_refined(self):
        """The maximum refinement level of any cell in the UGrid."""
        return RefineLevel(instance=self._instance.MostRefined())

    @property
    def number_of_refined_rows(self):
        """The number of rows in refinement coordinates."""
        return self._instance.GetNumRefinedRows()

    @property
    def number_of_refined_columns(self):
        """The number of columns in refinement coordinates."""
        return self._instance.GetNumRefinedCols()

    @property
    def number_of_base_rows(self):
        """The number of rows in the base grid or unrefined grid."""
        return self._instance.GetNumBaseRows()

    @property
    def number_of_base_columns(self):
        """The number of columns in the base grid or unrefined grid."""
        return self._instance.GetNumBaseCols()

    @property
    def number_of_base_layers(self):
        """The number of layers in the base grid."""
        return self._instance.GetNumBaseLays()

    def set_base_grid(self, num_i, num_j, num_k=None, numbering=Numbering.kij,
                      is_3d=False):
        """Set the dimensions of the base grid.

        Args:
            num_i: The number of cells in I direction.
            num_j: The number of cells in J direction.
            num_k: The number of cells in K direction (or None for 2D, which
                uses a single layer).
            numbering: The cell numbering.
            is_3d: Is the grid 2D or 3D?

        Raises:
            ValueError: If is_3d is truthy but num_k is None.
        """
        if num_k is None:
            if is_3d:
                raise ValueError('Must specify num_k for a 3D grid.')
            is_3d = False
            num_k = 1
        # The wrapped call takes the numbering as a plain integer.
        self._instance.SetBaseGrid(num_i, num_j, num_k, int(numbering), is_3d)

    def get_base_ijk(self, index):
        """Get the I, J, K location of the given UGrid cell in the base grid.

        Args:
            index: The index of the UGrid cell.

        Returns:
            tuple: The I, J, K location of the UGrid cell within the base grid.
        """
        return self._instance.GetBaseIjk(index)

    @property
    def numbering(self):
        """The cell numbering (KIJ or KJI). The order the cells are numbered
        in."""
        return Numbering(self._instance.GetNumbering())

    def get_cell_quad_idx(self, i, j, k):
        """Get the cell quadrant index.

        Note: unlike the other per-cell accessors, the coordinates are not
        validated here before being passed to the wrapped instance.

        Args:
            i: The row of the cell.
            j: The column of the cell.
            k: The layer of the cell.

        Returns:
            The cell quadrant index.
        """
        return self._instance.GetCellQuadIdx(i, j, k)

    def get_smoothed_neighbor(self, index, direction):
        """Get index of neighbor in given direction, or -1 if cell doesn't
        exist.

        Args:
            index: The index of the UGrid cell.
            direction: The direction to the neighbor (a NeighborDirection
                value).

        Returns:
            The cell index of the neighbor or -1 if there is no neighbor.
        """
        return self._instance.GetSmoothedNeighbor(index, direction)

    def get_smoothed_neighbors(self, index):
        """Get indices of all neighbors of cell, or -1 for cells that don't
        exist.

        Args:
            index: The index of the UGrid cell.

        Returns:
            An iterable of cell indices of the 12 potential neighbors in
            NeighborDirection order. Returns -1 for neighbors that don't
            exist.
        """
        return self._instance.GetSmoothedNeighbors(index)

    @property
    def get_quad_grid_array(self):
        """A 1-dimensional array that represents cell indices and cell
        refinement.

        Note: despite its ``get_`` prefix this is a property; read it as
        ``cells.get_quad_grid_array`` without calling it.

        The first row by column by layer elements contain UGrid indices of
        the base grid if unrefined. If the cell is refined then the element
        has a negative value that gives the index of refined quad indices
        toward the end of the array. A refined quad is 5 elements long. The
        first element is a negative value giving the index of the parent
        base cell or refined quad. For excluded cells the element value is
        QuadGridCells.EXCLUDED.

        Returns:
            A 1-dimensional array describing the cell indices and refinement.
        """
        return self._instance.GetQuadGridArray()

    def set_quad_grid_array(self, quad_array):
        """Set cell refinement using 1-dimensional array.

        See the get_quad_grid_array property for a description of the layout.

        Args:
            quad_array: A 1-dimensional array that describes the cell
                refinement. See get_quad_grid_array.
        """
        self._instance.SetQuadGridArray(quad_array)

    def _require_valid_refine_coordinates(self, i, j, k):
        """Raise ValueError unless (i, j, k) lies inside the refinement grid.

        Valid ranges are ``0 <= i < number_of_refined_rows``,
        ``0 <= j < number_of_refined_columns`` and
        ``0 <= k < number_of_base_layers``.
        """
        invalid = i < 0 or i >= self.number_of_refined_rows
        invalid = invalid or j < 0 or j >= self.number_of_refined_columns
        invalid = invalid or k < 0 or k >= self.number_of_base_layers
        if invalid:
            raise ValueError('Invalid refined grid coordinates.')