"""GridWriter class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import logging
from pathlib import Path
from typing import TextIO

# 2. Third party modules

# 3. Aquaveo modules
from xms.grid.geometry import geometry
from xms.grid.ugrid import UGrid

# 4. Local modules


def write(co_grid, mesh_filepath: Path):
    """Writes the grid to a mesh file for HGS.

    Convenience wrapper that builds a GridWriter and runs it.

    Args:
        co_grid: The grid.
        mesh_filepath (Path): Filepath of the file to be written.
    """
    GridWriter(co_grid, mesh_filepath).write()


class GridWriter:
    """Writes a grid to a .3dm mesh file for HGS."""
    def __init__(self, co_grid, mesh_filepath: Path) -> None:
        """Stores the grid and the output location; nothing is written until write() is called.

        Args:
            co_grid: The grid.
            mesh_filepath (Path): Filepath of .3dm file to be written.
        """
        self._log = logging.getLogger('xms.hgs')
        self._mesh_filepath: Path = mesh_filepath
        self._co_grid = co_grid
        self._ugrid = None  # Assigned from co_grid.ugrid at the start of write()
        self._width: int = 8  # Minimum width of the right-aligned integer columns (8 digits: ids below 100 million)

    def write(self) -> None:
        """Writes the .3dm file: a 'MESH3D' header line, then the cells, then the points.

        The grid is assumed to already be numbered the way HGS requires. Does nothing if there is no grid.
        """
        if not self._co_grid:  # No grid (can happen when testing), so there is nothing to write
            return

        # Fetch the ugrid a single time and keep it, since getting it is costly
        self._ugrid = self._co_grid.ugrid

        # Problems found by the check are logged. NOTE(review): the returned flag is not used here, so the file
        # is still written when the check fails - confirm that is intended.
        self._check_grid_for_hgs_compatibility()

        with self._mesh_filepath.open('w') as mesh_file:
            mesh_file.write('MESH3D\n')
            self._write_grid_cells(mesh_file)
            self._write_grid_points(mesh_file)

    def _check_grid_for_hgs_compatibility(self) -> bool:
        """Logs an error for each HGS requirement that the grid does not meet.

        Checked, in order: cell ordering increases upward, all cells are vertically prismatic, the grid is
        stacked, and the cells are all hexahedrons or all wedges.

        Returns:
            (bool): True if no errors were logged, False otherwise.
        """
        compatible = True

        self._ugrid.calculate_cell_ordering()
        ordering = self._ugrid.get_cell_ordering()
        if ordering != UGrid.cell_ordering_enum.CELL_ORDER_INCREASING_UP:
            self._log.error('Grid cell/point numbers must increase from bottom to top.')
            compatible = False

        prismatic = self._co_grid.check_all_cells_vertically_prismatic()
        if not prismatic:
            self._log.error('Grid cells are not all vertically prismatic.')
            compatible = False

        # A missing result, or a first element equal to 0, is treated as "not stacked"
        stacked_result = self._co_grid.check_is_stacked_grid()
        if stacked_result is None or stacked_result[0] == 0:
            self._log.error(
                'Grid is not a "stacked" grid (no vertical sub-discretization of layers and the'
                ' horizontal discretization of all layers is the same).'
            )
            compatible = False

        # The grid must be made of a single cell type: either hexahedrons or wedges
        hexahedrons_only = self._co_grid.check_all_cells_are_of_type(UGrid.cell_type_enum.HEXAHEDRON)
        wedges_only = self._co_grid.check_all_cells_are_of_type(UGrid.cell_type_enum.WEDGE)
        if not (hexahedrons_only or wedges_only):
            self._log.error('Grid cells are either not all hexahedron or not all triangular prisms.')
            compatible = False

        return compatible

    def _write_grid_cells(self, file: TextIO) -> None:
        """Writes one line per cell to the .3dm file.

        Each line holds 'E8H' (8 points) or 'E6W' (otherwise), the 1-based cell id, the 1-based point numbers,
        and a material number of 1. Logs an error and stops at the first cell that is neither a hexahedron nor
        a wedge; lines for the cells before it have already been written.

        Args:
            file (TextIO): The open .3dm file.
        """
        stream = self._ugrid.cellstream
        xyz = self._ugrid.locations
        width = self._width
        supported_types = {UGrid.cell_type_enum.HEXAHEDRON, UGrid.cell_type_enum.WEDGE}
        material = f'{1:>{width}}\n'  # The material number is always written as 1
        stream_size = len(stream)

        pos = 0  # Index into the cellstream; each cell is stored as: type, point count, point indices
        cell_id = 0
        while pos < stream_size:
            cell_id += 1
            cell_type = stream[pos]
            if cell_type not in supported_types:
                self._log.error(f'Unsupported cell type for cell {cell_id}: "{cell_type!s}".')
                return

            point_count = stream[pos + 1]
            start = pos + 2  # Step over the cell type and the point count
            pos = start + point_count  # Now at the start of the next cell
            nodes = list(stream[start:pos])  # A list copy, so the order can be changed

            # Wedges: make both triangles ordered how HGS expects - CCW looking down. A negative area
            # triggers swapping the last two points of that triangle.
            if point_count == 6:
                for a, b, c in ((0, 1, 2), (3, 4, 5)):
                    triangle = [xyz[nodes[a]], xyz[nodes[b]], xyz[nodes[c]]]
                    if geometry.polygon_area_2d(triangle) < 0.0:
                        nodes[b], nodes[c] = nodes[c], nodes[b]

            tag = 'E8H' if point_count == 8 else 'E6W'
            fields = ''.join(f' {node + 1:>{width}}' for node in nodes)  # + 1 to go from 0-based to 1-based
            file.write(f'{tag} {cell_id:>{width}}{fields}{material}')

    def _write_grid_points(self, file: TextIO) -> None:
        """Writes one 'ND' line per grid point: the 1-based point number followed by its three coordinates.

        Args:
            file (TextIO): The open .3dm file.
        """
        width = self._width
        file.writelines(
            f'ND {node_id:>{width}} {point[0]} {point[1]} {point[2]}\n'
            for node_id, point in enumerate(self._ugrid.locations, start=1)
        )
