"""EfdcGridWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from collections import deque
from io import StringIO
import math
import os
import shutil

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
try:
    from xms.api.dmi import XmsEnvironment as XmEnv
    xm_env_supported = True
except ImportError:  # pragma no cover - optional import
    xm_env_supported = False
from xms.constraint.ugrid_boundaries import UGridBoundaries
from xms.grid.geometry import geometry as geom

# 4. Local modules


class EfdcGridWriter:
    """Writer for EFDC curvilinear grid files.

    Produces three text files from a UGrid and its cell i-j dataset: dxdy.inp (cell sizes and elevations),
    lxly.inp (cell centroids and rotation vectors), and cell.inp (the cell type map).
    """
    # cell.inp code assigned to every cell listed in the i-j dataset (legend: 5 - quadrilateral water cell).
    ACTIVE_CELL = 5
    # cell.inp code assigned to neighbors of grid-boundary cells (legend: 9 - dry land cell bordering a water cell).
    BORDER_CELL = 9
    MAX_LINE_LEN = 640  # Maximum number of i-columns per j-row in the cell.inp file.

    def __init__(self, dir, dxdy_filename, lxly_filename, cellinp_filename, ugrid, ij_df, cell_elevations, logger):
        """Store the inputs, resolve the output paths, and allocate zeroed per-cell work arrays.

        Args:
            dir (str): Path to the folder the output files are saved in
            dxdy_filename (str): Name of the dxdy.inp file (joined onto dir)
            lxly_filename (str): Name of the lxly.inp file (joined onto dir)
            cellinp_filename (str): Name of the cell.inp file (joined onto dir)
            ugrid (UGrid): The input ugrid
            ij_df (pd.DataFrame): The cell i-j coordinate dataset, indexed by (i, j)
            cell_elevations (list): The cell elevations, or None to use the average point elevation of each cell
            logger (logging.Logger): The tool logger
        """
        # Inputs
        self._ugrid = ugrid
        self._ij_df = ij_df
        self._cell_elevations = cell_elevations
        self._logger = logger

        # Output locations
        self._dxdy_filename, self._lxly_filename, self._cellinp_filename = (
            os.path.join(dir, name) for name in (dxdy_filename, lxly_filename, cellinp_filename)
        )

        # Text of the file currently being assembled; swapped for a fresh stream after every flush.
        self._ss = StringIO()

        # Per-cell results, indexed by cell index.
        cell_total = ugrid.cell_count
        self._dx, self._dy, self._cue, self._cun, self._cve, self._cvn = (np.zeros(cell_total) for _ in range(6))
        self._center = np.zeros((cell_total, 3))

        # Cell type map is stored as [j row][i column]. The four extra rows and columns are there to leave a
        # buffer of two cells on each side of the grid.
        i_levels, j_levels = ij_df.index.levshape[0], ij_df.index.levshape[1]
        self._cell_map = np.zeros((j_levels + 4, i_levels + 4), dtype=int)

        # Column ruler annotations for cell.inp, built on demand.
        self._major_ticks = ''
        self._minor_ticks = ''

    def _compute_variables(self):
        """Fill the Dx, Dy, CUE, CUN, CVE and CVN arrays and mark every dataset cell as active in the cell map.

        Notes:
            Cell points should be counterclockwise starting with the bottom left corner. The points are rotated
            according to each cell's orientation before the sizes and rotation vectors are computed.
        """
        self._logger.info('Computing Dx, Dy, CCUE, CCUN, CCVE and CCVN from cell centroids...')
        for rec in self._ij_df.itertuples():
            idx = rec.cell_idx
            shift = int(rec.cell_orientation + 2) % 4
            # Also records the cell's centroid and elevation as a side effect.
            corners = deque(self._compute_centroid(idx))
            if shift:
                # Bring the curvilinear lower-left corner to the front of the point list.
                corners.rotate(shift)

            dx, dy = self._compute_dx_dy(idx, corners)
            self._dx[idx] = dx
            self._dy[idx] = dy
            cue, cun, cve, cvn = self._compute_rotation_vectors(corners)
            self._cue[idx] = cue
            self._cun[idx] = cun
            self._cve[idx] = cve
            self._cvn[idx] = cvn

            # The map is stored [j][i]; the i-j coordinates in the DataFrame index are 1-based.
            i_coord, j_coord = rec.Index[0], rec.Index[1]
            self._cell_map[j_coord - 1][i_coord - 1] = self.ACTIVE_CELL  # Mark cell as wet/active

    def _compute_centroid(self, cell_idx):
        """Record the centroid and elevation of a cell and return the cell's point locations.

        Args:
            cell_idx (int): The cellstream index of the cell

        Returns:
            np.ndarray: The points of the cell as x,y,z coordinate tuples
        """
        grid = self._ugrid
        _, center_xyz = grid.get_cell_centroid(cell_idx)
        self._center[cell_idx] = center_xyz
        corner_pts = grid.get_cell_locations(cell_idx)
        # Prefer the supplied cell elevation; fall back to the average elevation of the cell's points.
        if self._cell_elevations is None:
            elevation = np.mean(corner_pts[:, -1])
        else:
            elevation = self._cell_elevations[cell_idx]
        self._center[cell_idx][-1] = elevation
        return corner_pts

    def _compute_dx_dy(self, cell_idx, points):
        """Compute Dx (di) and Dy (dj) for a cell as the distances between opposite edge midpoints.

        Args:
            cell_idx (int): The cellstream index of the cell (not used in the computation)
            points (x,y,z points): The cell points starting from curvilinear lower left

        Returns:
            tuple(float, float) Dx, Dy
        """
        lower_left, lower_right, upper_right, upper_left = points[0], points[1], points[2], points[3]
        # Dx spans the midpoints of the left and right edges.
        left_edge_mid = (lower_left + upper_left) / 2.0
        right_edge_mid = (lower_right + upper_right) / 2.0
        # Dy spans the midpoints of the lower and upper edges.
        lower_edge_mid = (lower_left + lower_right) / 2.0
        upper_edge_mid = (upper_right + upper_left) / 2.0
        return geom.distance_2d(left_edge_mid, right_edge_mid), geom.distance_2d(lower_edge_mid, upper_edge_mid)

    def _compute_rotation_vectors(self, points):
        """Compute the unit direction vectors of a cell's i and j axes.

        Args:
            points (np.ndarray): The points of the cell as x,y,z coordinate tuples in counterclockwise order starting
                with the bottom left corner.

        Returns:
            tuple(float, float, float, float): CUE, CUN, CVE, CVN
        """
        # Adapted from the old SMS 11.2 C++ code. Go ask AKZ.
        bl_x, bl_y = points[0][0], points[0][1]
        br_x, br_y = points[1][0], points[1][1]
        tr_x, tr_y = points[2][0], points[2][1]
        tl_x, tl_y = points[3][0], points[3][1]

        # Average change in x and y along the i direction (left to right) ...
        i_dx = 0.5 * (tr_x - tl_x + br_x - bl_x)
        i_dy = 0.5 * (tr_y - tl_y + br_y - bl_y)
        # ... and along the j direction (bottom to top).
        j_dx = 0.5 * (tr_x - br_x + tl_x - bl_x)
        j_dy = 0.5 * (tr_y - br_y + tl_y - bl_y)

        # Normalize each direction by its length.
        i_len = math.sqrt(i_dx * i_dx + i_dy * i_dy)
        j_len = math.sqrt(j_dx * j_dx + j_dy * j_dy)
        return i_dx / i_len, i_dy / i_len, j_dx / j_len, j_dy / j_len

    def _find_boundaries(self):
        """Find all the boundary cells and mark their neighbors as border cells.

        Notes:
            Neighbors might actually be active but will get flagged as such later. There needs to be a row/column of
            inactive cells on each side of the grid. All inactive cells that are neighbors to cells on the UGrid
            boundary need to be flagged as border cells.

        Raises:
            IndexError: If a boundary cell has no row in the i-j dataset, or a neighbor falls past the end of the
                cell map.
        """
        self._logger.info('Finding boundary cells...')
        boundary = UGridBoundaries(self._ugrid, target_cells=True)
        loops = boundary.get_loops()

        # Build the cell index -> (i, j) lookup once rather than scanning the whole DataFrame for every boundary
        # cell. setdefault keeps the first row seen for a cell index, which is the row a positional lookup returns.
        ij_by_cell = {}
        for ij, mapped_cell in zip(self._ij_df.index, self._ij_df.cell_idx):
            ij_by_cell.setdefault(mapped_cell, ij)

        # The eight (dj, di) offsets surrounding a cell; the cell itself is left untouched.
        neighbor_offsets = [(dj, di) for dj in (-1, 0, 1) for di in (-1, 0, 1) if (dj, di) != (0, 0)]

        for loop in loops.values():
            for cell_idx in loop['id']:
                try:
                    ij = ij_by_cell[cell_idx]
                except KeyError:
                    raise IndexError(f'Boundary cell {cell_idx} has no i-j coordinates in the cell dataset.') from None
                ii = ij[0] - 1  # ij coordinates stored as 1-base in DataFrame Index
                jj = ij[1] - 1
                # Assign all neighboring cells as borders. Buffer rows/columns should prevent out-of-bounds here.
                for dj, di in neighbor_offsets:
                    self._cell_map[jj + dj][ii + di] = self.BORDER_CELL

    def _write_dxdy_header(self):
        """Write the four comment lines that open the dxdy.inp file: title, project name and column labels."""
        self._logger.info('Writing dxdy.inp header...')
        active_cells = self._ugrid.cell_count
        written_by = f'{self._app_name()} {self._app_version()} (64-bit)'
        project_name = os.path.basename(self._project_path())
        title = f'C DXDY.INP - GRID CELL SIZE AND IC DATA, {active_cells} ACTIVE CELLS - written by: {written_by}\n'
        column_labels = (
            'C                                                   WATER         BOTTOM                           VEG\n'
            'C    I     J             DX             DY          DEPTH           ELEV         ZROUGH           TYPE\n'
        )
        self._ss.write(title + f'C Project: {project_name}\n' + column_labels)

    def _write_dxdy_cells(self):
        """Write one dxdy.inp line per row of the i-j dataset."""
        self._logger.info('Writing dxdy.inp cells...')
        # Don't think these are really fixed-format, just matched my examples.
        # Fields: I, J, DX, DY, DEPTH (optional), BOTTOM ELEV, ZROUGH (optional), VEG TYPE (optional)
        line_fmt = '{:6d}{:6d}{:15.3f}{:15.3f}{:15.3f}{:15.3f}{:15.3f}{:15d}\n'
        self._ss.writelines(
            line_fmt.format(
                rec.Index[0],
                rec.Index[1],
                self._dx[rec.cell_idx],
                self._dy[rec.cell_idx],
                rec.depth,
                self._center[rec.cell_idx][-1],
                rec.zrough,
                int(rec.veg_type),
            )
            for rec in self._ij_df.itertuples()
        )

    def _write_lxly_header(self):
        """Write the four comment lines that open the lxly.inp file: title, project name and column labels."""
        self._logger.info('Writing lxly.inp header...')
        active_cells = self._ugrid.cell_count
        written_by = f'{self._app_name()} {self._app_version()} (64-bit)'
        project_name = os.path.basename(self._project_path())
        title = (
            f'C LXLY.INP - GRID CELL CENTROID AND ROTATION, {active_cells} ACTIVE CELLS - '
            f'written by: {written_by}\n'
        )
        column_labels = (
            'C                                                                                       WIND\n'
            'C    I     J              X              Y       CUE       CVE       CUN      CVN    SHELTER\n'
        )
        self._ss.write(title + f'C Project: {project_name}\n' + column_labels)

    def _write_lxly_cells(self):
        """Write one lxly.inp line per row of the i-j dataset."""
        self._logger.info('Writing lxly.inp cells...')
        # Don't think these are really fixed-format, just matched my examples.
        # Fields: I, J, X (centroid), Y (centroid), CUE, CVE, CUN, CVN, WIND SHELTER (optional)
        line_fmt = '{:6d}{:6d}{:15.3f}{:15.3f}{:10.5f}{:10.5f}{:10.5f}{:10.5f}{:10.2f}\n'
        self._ss.writelines(
            line_fmt.format(
                rec.Index[0],
                rec.Index[1],
                self._center[rec.cell_idx][0],
                self._center[rec.cell_idx][1],
                self._cue[rec.cell_idx],
                self._cve[rec.cell_idx],
                self._cun[rec.cell_idx],
                self._cvn[rec.cell_idx],
                rec.wind_shelter,
            )
            for rec in self._ij_df.itertuples()
        )

    def _set_cellinp_ticks(self, num_i_cols):
        """Extend the column ruler strings used in the cell.inp annotations.

        The minor ruler repeats the digits 1-0; the major ruler carries the block number right-aligned in each
        10-character block. Both grow in whole 10-character blocks until they cover the requested width.

        Args:
            num_i_cols (int): The number of i levels in the grid (usually i means rows but they are columns in cell.inp)
        """
        width = min(num_i_cols, self.MAX_LINE_LEN)  # A line never holds more than MAX_LINE_LEN columns
        shortfall = width - len(self._major_ticks)
        if shortfall <= 0:
            return
        blocks = -(-shortfall // 10)  # Ceiling division: each block adds 10 characters
        self._major_ticks += ''.join('0        1' if n == 1 else f'{n:10d}' for n in range(1, blocks + 1))
        self._minor_ticks += '1234567890' * blocks

    def _write_cellinp_header(self):
        """Write the four header lines for the cell.inp file.

        The lines are a title (file name, map dimensions and writing application), the project name, and the major
        and minor column rulers. The dimensions reported are those of the cell map, buffer rows/columns included.
        """
        self._logger.info('Writing cell.inp header...')
        num_j_rows = self._cell_map.shape[0]
        num_i_cols = self._cell_map.shape[1]
        self._set_cellinp_ticks(num_i_cols)
        self._ss.write(
            # The title must name this file (CELL.INP), not LXLY.INP.
            f'C CELL.INP - GRID CELL TYPE MAP, {num_i_cols} I COLUMNS AND {num_j_rows} J ROWS - '
            f'written by: {self._app_name()} {self._app_version()} (64-bit)\n'
            f'C Project: {os.path.basename(self._project_path())}\n'
            # The 5-character 'C    ' prefix lines the rulers up with the 'jjjj ' prefix of the cell rows.
            f'C    {self._major_ticks}\n'
            f'C    {self._minor_ticks}\n'
        )

    def _write_cellinp_footer(self):
        """Write the cell.inp footer: the column rulers again, followed by a legend of the cell codes.

        SMS 11.2 wrote this trailer, so it is kept here.
        """
        self._logger.info('Writing cell.inp footer...')
        rulers = ('C', f'C    {self._minor_ticks}', f'C    {self._major_ticks}')
        legend = (
            'C cell codes',
            'C 0 - dry land cell not bordering a water cell on a side or corner',
            'C 1 - triangular water cell with land to the northeast',
            'C 2 - triangular water cell with land to the southeast',
            'C 3 - triangular water cell with land to the southwest',
            'C 4 - triangular water cell with land to the northwest',
            'C 5 - quadrilateral water cell',
            'C 9 - dry land cell bordering a water cell on a side or corner or',
            'C       a fictitious dry land cell bordering an open boundary cell',
            'C       on a side or a corner',
        )
        self._ss.write('\n'.join(rulers + legend) + '\n')

    def _write_cellinp_cells(self):
        """Write the cell type map lines in the cell.inp file.

        The map is written in column chunks of at most MAX_LINE_LEN codes per line. Within each chunk the j rows
        are written from the highest j down to 1, each prefixed with its 1-based j number.
        """
        self._logger.info('Writing cell.inp cells...')
        # EFDC can only read 640 i-column integers per line.
        num_j_rows, num_i_cols = self._cell_map.shape
        chunk_width = self.MAX_LINE_LEN
        for first_col in range(0, num_i_cols, chunk_width):
            codes = self._cell_map[:, first_col:first_col + chunk_width].astype(str)
            for j_number in range(num_j_rows, 0, -1):
                row_text = ''.join(codes[j_number - 1])
                self._ss.write(f'{j_number:4d} {row_text}\n')

    def _write_dxdy_inp(self):
        """Assemble the dxdy.inp header and cell lines, then save them to the dxdy.inp path."""
        self._logger.info('Writing dxdy.inp file...')
        for write_section in (self._write_dxdy_header, self._write_dxdy_cells):
            write_section()
        self._flush(self._dxdy_filename)

    def _write_lxly_inp(self):
        """Assemble the lxly.inp header and cell lines, then save them to the lxly.inp path."""
        self._logger.info('Writing lxly.inp file...')
        for write_section in (self._write_lxly_header, self._write_lxly_cells):
            write_section()
        self._flush(self._lxly_filename)

    def _write_cell_inp(self):
        """Assemble the cell.inp header, cell type map and footer, then save them to the cell.inp path."""
        self._logger.info('Writing cell.inp file...')
        sections = (self._write_cellinp_header, self._write_cellinp_cells, self._write_cellinp_footer)
        for write_section in sections:
            write_section()
        self._flush(self._cellinp_filename)

    def _flush(self, filename):
        """Save the buffered text to a file and start a fresh buffer.

        Args:
            filename (str): Path of the file to write; an existing file is overwritten
        """
        self._logger.info('Flushing to disk.')
        with open(filename, 'w') as out_file:
            out_file.write(self._ss.getvalue())
        self._ss = StringIO()  # Clear stream for the next file

    def _app_name(self):
        """Return the XMS application name, or an empty string when the XMS environment module is unavailable."""
        return XmEnv.xms_environ_app_name() if xm_env_supported else ''

    def _app_version(self):
        """Return the XMS application version, or an empty string when the XMS environment module is unavailable."""
        return XmEnv.xms_environ_app_version() if xm_env_supported else ''

    def _project_path(self):
        """Return the XMS project path, or an empty string when the XMS environment module is unavailable."""
        return XmEnv.xms_environ_project_path() if xm_env_supported else ''

    def write(self):
        """Write the dxdy.inp, lxly.inp and cell.inp files.

        Border cells are flagged before the per-cell values are computed, so cells that are actually active
        overwrite any border flag they were given.
        """
        steps = (
            self._find_boundaries,
            self._compute_variables,
            self._write_dxdy_inp,
            self._write_lxly_inp,
            self._write_cell_inp,
        )
        for run_step in steps:
            run_step()
