"""Ch3dGridWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from collections import deque
from io import StringIO
import os
import shutil

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules

# 4. Local modules
from xms.tool.algorithms.ugrids import curvilinear_grid_ij as cgij


class Ch3dGridWriter:
    """Writer for CH3D curvilinear grid files.

    The output is a text file with a header (grid name, then the i and j point counts) followed by one line per i-j
    grid point giving that point's x, y, z location and its 1-based i-j coordinates. Each line is the bottom left
    point of the cell with those i-j coordinates.
    """

    def __init__(self, filename, ugrid, ij_df, grid_name, logger):
        """Initializes the class.

        Args:
            filename (str): Path to the CH3D curvilinear grid file. Missing parent directories are created on write.
            ugrid (UGrid): The input ugrid
            ij_df (pd.DataFrame): The cell i-j coordinate dataset, indexed by 1-based (i, j) cell coordinates with
                'cell_idx' and 'cell_orientation' columns.
            grid_name (str): Grid name, written verbatim as the first line of the file header.
                NOTE(review): no default is substituted here; a None value would be written as the text 'None'.
            logger (logging.Logger): The tool logger
        """
        self._filename = filename
        self._ugrid = ugrid
        self._ij_df = ij_df
        self._grid_name = grid_name
        self._logger = logger
        self._ss = StringIO()
        # One more point than cells in each direction: every cell is written as its bottom left point, so the top and
        # right UGrid boundaries need an extra row/column of buffer points (see _create_buffer).
        # NOTE(review): levshape counts distinct i/j values, so this assumes i-j coordinates are contiguous from 1.
        self._num_i = self._ij_df.index.levshape[0] + 1
        self._num_j = self._ij_df.index.levshape[1] + 1
        # Point locations indexed [i][j] (0-based); points never assigned keep the NULL_COORD fill value.
        self._locations = np.full((self._num_i, self._num_j, 3), cgij.NULL_COORD)
        # _write_cell negates z on output, so pre-negate the fill value's z to write it back out unchanged.
        self._locations[:, :, 2] *= -1.0

    def _create_buffer(self):
        """Create imaginary cells to define the boundary points and to fill out a jagged grid.

        Notes:
            We define cells in the CH3D format by their bottom left point. This means that we need to create imaginary
            cells at the boundary such that each boundary point is the bottom left point of an imaginary cell. CH3D also
            requires the entire i-j grid, i.e. for each i row there needs to be j number of cells. i-j coordinates that
            fall outside the bounds of our UGrid and the boundary buffer keep the inactive fill value.
        """
        self._logger.info('Creating buffer cells at UGrid boundaries...')
        for i in range(self._num_i):
            cell_i = i + 1  # 1-based cell i-j coordinates (index in the DataFrame)
            for j in range(self._num_j):
                cell_j = j + 1
                index = (cell_i, cell_j)
                if index not in self._ij_df.index:
                    # No UGrid cell has these i-j coordinates. The point is either set by a neighbor's buffer cell or
                    # keeps the fill value.
                    continue
                row = self._ij_df.loc[index]
                cell_locs = self._ugrid.get_cell_locations(int(row.cell_idx)).tolist()
                # Rotate the cell's points based on its orientation so that the first one is the bottom left point.
                # deque.rotate(0) is a no-op, so unrotated cells need no special case.
                lower_left_cell_locs = deque(cell_locs)
                lower_left_cell_locs.rotate(int(row.cell_orientation + 2) % 4)
                # Create boundary buffer cells if this cell is on our UGrid boundary.
                self._buffer_cell(cell_i, cell_j, lower_left_cell_locs)
                # A cell is identified by its bottom left point, so that is the only one we need to store.
                self._locations[i][j] = lower_left_cell_locs[0]

    def _buffer_cell(self, cell_i, cell_j, cell_locs):
        """Create buffer cells to define the points of our boundary cells.

        Args:
            cell_i (int): i-coordinate of the cell (1-based)
            cell_j (int): j-coordinate of the cell (1-based)
            cell_locs (Sequence): The points of the cell. Should be in CCW order, starting at the bottom left.
        """
        i = cell_i - 1  # 0-based arrays, 1-based i-j coordinates (index in the DataFrame)
        j = cell_j - 1
        bottom_right, top_right, top_left = cell_locs[1], cell_locs[2], cell_locs[3]
        # No UGrid cell to our right: our right edge must be defined by buffer points.
        if cell_i + 1 <= self._num_i and (cell_i + 1, cell_j) not in self._ij_df.index:
            self._locations[i + 1][j] = bottom_right  # Bottom left point of the buffer cell to our right
            self._locations[i + 1][j + 1] = top_right  # Top left point of the buffer cell to our right
        # No UGrid cell above us: our top edge must be defined by buffer points.
        if cell_j + 1 <= self._num_j and (cell_i, cell_j + 1) not in self._ij_df.index:
            self._locations[i][j + 1] = top_left  # Bottom left point of the buffer cell above us
            self._locations[i + 1][j + 1] = top_right  # Bottom right point of the buffer cell above us

    def _write_cell(self, i, j, cell_loc):
        """Write a single cell row to the file.

        Args:
            i (int): i-coordinate of the cell (1-based)
            j (int): j-coordinate of the cell (1-based)
            cell_loc (Sequence): x,y,z coordinates of the bottom left corner of the cell
        """
        self._ss.write(
            f'{cell_loc[0]:28.6f}'         # BL corner x
            f'{cell_loc[1]:28.6f}'         # BL corner y
            f'{cell_loc[2] * -1.0:28.6f}'  # BL corner z, invert elevation
            f'{i:6d}'                      # cell i
            f'{j:6d}\n'                    # cell j
        )  # There is an optional sixth column we don't currently support. Think it's a material type.

    def _write_locations(self):
        """Write the cell location lines, with i varying fastest and j slowest."""
        self._logger.info('Exporting locations...')
        for j in range(self._num_j):
            for i in range(self._num_i):
                self._write_cell(i + 1, j + 1, self._locations[i][j])

    def _write_header(self):
        """Write the file header: the grid name line, then the i and j point counts."""
        self._logger.info('Exporting header...')
        # Field widths match the example files; unclear how strictly CH3D requires fixed-width formatting.
        self._ss.write(f'{self._grid_name}\n{self._num_i:6d}{self._num_j:6d}\n')

    def _flush(self):
        """Flush string stream to disk, creating the parent directory if needed."""
        self._logger.info('Flushing to disk.')
        directory = os.path.dirname(self._filename)
        if directory:  # A bare filename has no parent to create, and os.makedirs('') raises FileNotFoundError.
            os.makedirs(directory, exist_ok=True)
        with open(self._filename, 'w') as f:
            self._ss.seek(0)
            shutil.copyfileobj(self._ss, f, 100000)

    def write(self):
        """Write the grid.inp file."""
        # Start from an empty buffer so calling write() again does not duplicate the file contents.
        self._ss = StringIO()
        self._create_buffer()
        self._write_header()
        self._write_locations()
        self._flush()
