"""Ch3dGridReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.grid.ugrid import UGrid as XmUGrid

# 4. Local modules
from xms.tool.algorithms.ugrids import curvilinear_grid_ij as cgij
from xms.tool.file_io.curvilinear.curvilinear_reader_base import CurvilinearReaderBase as ReaderBase


class Ch3dGridReader(ReaderBase):
    """Reader for CH3D curvilinear grid files.

    File layout as parsed by this reader:
        line 1: grid name, used only when the caller does not supply an output grid name
        line 2: the number of node locations in the i-direction and in the j-direction
        line 3+: one whitespace-delimited ``x y z`` record per corner node, ordered with i varying fastest, then j.
            ``z`` is read as a depth and negated into an elevation. Values equal to ``cgij.NULL_COORD`` are read as
            NaN; a node whose x coordinate is NaN is treated as inactive. Columns past the third are ignored.

    A quad cell is created for every active node whose right (i+1), top (j+1) and top-right (i+1, j+1) neighbors are
    also active, with that node as the bottom-left corner of the cell.
    """

    def __init__(self, filename, grid_name, logger):
        """Initializes the class.

        Args:
            filename (str): Path to the CH3D curvilinear grid file
            grid_name (str): Optional user input for output grid name. If not
                specified, will try to read from file.
            logger (logging.Logger): The tool logger
        """
        super().__init__(grid_name, logger)
        self._filename = filename
        self._numi = 0  # Number of node locations in the i-direction
        self._numj = 0  # Number of node locations in the j-direction
        # i -> {j -> 0-based XmUGrid point index}. Inactive (or already processed) locations map to cgij.NULL_POINT.
        self._lookup = {}

    def _read_locations(self):
        """Parse the grid corner nodes and their i-j coordinates from the file.

        Returns:
            pd.DataFrame: The grid point location DataFrame (columns x, y, z), with 1-based i-j coordinates as the
            MultiIndex. z has been negated from a depth into an elevation.

        Raises:
            ValueError: If the dimension line cannot be parsed as two integers, or if the number of point records in
                the file does not equal the declared number of i-j locations.
        """
        self.logger.info('Reading grid definition from file...')
        # Line 1 holds the grid name (see _find_grid_name); line 2 holds the i-j dimensions.
        with open(self._filename) as f:
            f.readline()
            numi, numj = (int(val) for val in f.readline().split())

        df = pd.read_csv(self._filename, sep='\\s+', skiprows=[0, 1], header=None, names=['x', 'y', 'z'],
                         na_values=cgij.NULL_COORD, usecols=[0, 1, 2], dtype={"x": float, "y": float, "z": float})
        expected_count = numi * numj
        if len(df) != expected_count:
            # Fail with a descriptive message instead of pandas' generic length-mismatch error on assignment below.
            raise ValueError(
                f'CH3D grid file declares {numi} x {numj} = {expected_count} points, '
                f'but {len(df)} point records were read.'
            )
        df['z'] *= -1  # Invert Z depth to an elevation
        # Records are ordered with i varying fastest: (1, 1), (2, 1), ..., (numi, 1), (1, 2), ..., (numi, numj).
        df['i'] = np.tile(np.arange(1, numi + 1, dtype=np.int64), numj)
        df['j'] = np.repeat(np.arange(1, numj + 1, dtype=np.int64), numi)
        # add the MultiIndex
        return df.set_index(['i', 'j'])

    def _find_grid_name(self):
        """Read the grid name from the first line of the file if the user did not specify an output grid name.

        Raises:
            RuntimeError: If the grid name could not be read from the file.
        """
        if self.grid_name:  # Output grid name is an optional user input
            return
        try:
            self.logger.info('Reading grid name from file...')
            with open(self._filename, 'r') as f:
                self.grid_name = f.readline().strip()
        except Exception as e:
            raise RuntimeError('Error reading CH3D grid name:\n') from e

    def _build_ch3d_grid(self, df):
        """Build the XmUGrid definition from imported CH3D file data.

        Args:
            df (DataFrame): The grid data read from the file, indexed by an (i, j) MultiIndex.
        """
        # The number of distinct i and j index values gives the grid dimensions.
        self._numi = df.index.levshape[0]
        self._numj = df.index.levshape[1]
        self._extract_ch3d_locations(df)
        self._build_ch3d_quads()

    def _get_neighbors(self, i, j):
        """Get the neighbor points of a given point.

        Args:
            i (int): i-coordinate of the point to get neighbors of
            j (int): j-coordinate of the point to get neighbors of

        Returns:
            tuple(list, list): The point's neighbor point indices ordered [top, left, bottom, right], and a parallel
            list of the neighbors' i-j coordinates. A neighbor that is outside the grid, inactive, or already
            processed is reported as cgij.NULL_POINT (its i-j coordinate is still returned).
        """
        neighbors = [
            self._lookup.get(i, {}).get(j + 1, cgij.NULL_POINT),  # Neighbor above (j-direction)
            self._lookup.get(i - 1, {}).get(j, cgij.NULL_POINT),  # Neighbor to the left (i-direction)
            self._lookup.get(i, {}).get(j - 1, cgij.NULL_POINT),  # Neighbor below (j-direction)
            self._lookup.get(i + 1, {}).get(j, cgij.NULL_POINT),  # Neighbor to the right (i-direction)
        ]
        ij_coords = [(i, j + 1), (i - 1, j), (i, j - 1), (i + 1, j)]
        return neighbors, ij_coords

    def _extract_ch3d_locations(self, df):
        """Add the coordinates of all active points to the point list and map every i-j location to its point index.

        Inactive locations (NaN x coordinate) are mapped to cgij.NULL_POINT in the lookup.

        Args:
            df (DataFrame): The grid data read from the file, indexed by an (i, j) MultiIndex.
        """
        # First we need to add all the defined point locations. These are corner nodes in CH3D format. Build a mapping
        # of i-j coordinates of the active points to their 0-based index in the XmUGrid point list.
        # NOTE(review): self._points is presumably initialized by the base class - confirm.
        self.logger.info('Creating active point locations...')
        for row in df.itertuples():  # itertuples() was much faster than iterrows()
            i, j = row.Index
            imap = self._lookup.setdefault(i, {})
            if np.isnan(row.x):
                imap[j] = cgij.NULL_POINT  # This point will not be included in the grid
            else:
                imap[j] = len(self._points)
                self._points.append((row.x, row.y, row.z))

    def _build_ch3d_quads(self):
        """Create quad cells from the corner node coordinates and i-j order specified in the CH3D file."""
        # Loop through all the points, creating any possible quads from its neighbors. The outer loop walks the
        # i-direction and the inner loop walks up the j-direction.
        self.logger.info('Generating quad cells from point neighbors...')
        for i in range(1, self._numi + 1):
            imap = self._lookup.get(i, {})
            for j in range(1, self._numj + 1):
                pt_index = imap.get(j, cgij.NULL_POINT)
                if pt_index == cgij.NULL_POINT:
                    continue  # This location is inactive or has already been processed.
                if self._add_ch3d_quad_for_point(pt_index, i, j):
                    # If we added a quad, add its i-j coordinates to the output datasets. Since we are marching in
                    # ascending i-j order and building the quads from the bottom left corner, the point i-j should
                    # match the cell i-j for any point we process that ends up adding a quad to the cellstream.
                    self._i_values.append(i)
                    self._j_values.append(j)

    def _add_ch3d_quad_for_point(self, pt_index, i, j):
        """Create a quad cell with the specified point in the bottom left corner, if possible.

        The point is marked as processed (set to cgij.NULL_POINT in the lookup) whether or not a cell was added.

        Args:
            pt_index (int): 0-based point index in the XmUGrid point list
            i (int): i-coordinate of the point
            j (int): j-coordinate of the point

        Returns:
            bool: True if we added a quad cell, False if no valid quad was made
        """
        neighbor_indices, neighbor_ij = self._get_neighbors(i, j)
        neighbor_index = neighbor_indices[cgij.NEIGHBOR_LOC_RIGHT]
        added_cell = False
        if neighbor_index != cgij.NULL_POINT:
            ni, nj = neighbor_ij[cgij.NEIGHBOR_LOC_RIGHT]
            cell_pts = self._find_quad_cell_points(pt_index, neighbor_indices, ni, nj)
            # Only a quad whose four corners are all active is added to the cellstream.
            if cgij.NULL_POINT not in cell_pts:
                self._cellstream.extend([XmUGrid.cell_type_enum.QUAD, 4, *cell_pts])
                added_cell = True
        # Mark this location as processed so it is skipped if visited again.
        self._lookup[i][j] = cgij.NULL_POINT
        return added_cell

    def _find_quad_cell_points(self, origin_index, origin_neighbors, ni, nj):
        """Initialize the quad cell point array given an origin point and its neighbor to the right in the i-direction.

        Args:
            origin_index (int): The origin point's 0-based index from the XmUGrid point list
            origin_neighbors (list[int]): The origin point's neighbor point indices
            ni (int): i-coordinate of the neighbor point
            nj (int): j-coordinate of the neighbor point

        Returns:
            list[int]: The quad cell point array with the origin point and its neighbor assigned to the correct
            locations. Ordered [LOC_BOTTOM_LEFT, LOC_BOTTOM_RIGHT, LOC_TOP_RIGHT, LOC_TOP_LEFT] - CCW order.
            Corners that are inactive or already processed are cgij.NULL_POINT.
        """
        cell_pts = [cgij.NULL_POINT] * 4
        cell_pts[cgij.LOC_BOTTOM_LEFT] = origin_index  # Origin point is always the bottom left corner of the quad
        cell_pts[cgij.LOC_BOTTOM_RIGHT] = self._lookup[ni][nj]  # Always march to the right in the i-direction
        # Find the other two points of the quad (the points above the origin and its neighbor in the j-direction).
        neighbor_neighbors, _ = self._get_neighbors(ni, nj)
        cell_pts[cgij.LOC_TOP_LEFT] = origin_neighbors[cgij.NEIGHBOR_LOC_TOP]
        cell_pts[cgij.LOC_TOP_RIGHT] = neighbor_neighbors[cgij.NEIGHBOR_LOC_TOP]
        return cell_pts

    def read(self):
        """Import a curvilinear grid from a CH3D formatted file.

        Raises:
            RuntimeError: If the grid name must be read from the file and cannot be.
            ValueError: If the file's dimensions or point records are malformed.
        """
        self._find_grid_name()
        df = self._read_locations()
        if df is not None:
            self._build_ch3d_grid(df)
