"""UGridBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.constraint import Grid, Orientation, RectilinearGridBuilder, UGrid3d
from xms.constraint.modflow import build_stacked_grid_from_2d
from xms.data_objects.parameters import Projection, UGrid as DoGrid
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.mf6.data.array import Array
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.mfsim_data import MfsimData
from xms.mf6.data.model_data_base import ModelDataBase
from xms.mf6.file_io import io_util
from xms.mf6.geom.disu_ugrid_stream_builder import DisuUgridStreamBuilder
from xms.mf6.misc import log_util


def _same_none(a, b) -> True:
    """Return True if both a and b are None, or both are not None.

    Args:
        a: An object.
        b: Another object.

    Returns:
        See description.
    """
    return (a is None and b is None) or (a is not None and b is not None)


class DisGeom:
    """Geom from the DIS package used to build the UGrid."""
    def __init__(
        self,
        xorigin: float,
        yorigin: float,
        angrot: float,
        delr_list: list[float],
        delc_list: list[float],
        tops: list[float],
        bots: list[float],
        model_on_off_cells: list[int] | None = None
    ):
        """Initializes the class.

        Args:
            xorigin: x origin.
            yorigin: y origin.
            angrot: angle of rotation.
            delr_list: column spacing in the row direction.
            delc_list: row spacing in the column direction.
            tops: Top elevations.
            bots: Bottom elevations.
            model_on_off_cells: IDOMAIN.
        """
        self.xorigin = xorigin
        self.yorigin = yorigin
        self.angrot = angrot
        self.delr_list = delr_list
        self.delc_list = delc_list
        self.tops = tops
        self.bots = bots
        self.model_on_off_cells = model_on_off_cells

    def __key(self) -> tuple:
        """Returns something that can be passed to hash().

        Returns:
            (tuple): See description.
        """
        model_on_off_cells = tuple(self.model_on_off_cells) if self.model_on_off_cells is not None else None
        return (
            self.xorigin, self.yorigin, self.angrot, tuple(self.delr_list), tuple(self.delc_list), tuple(self.tops),
            tuple(self.bots), model_on_off_cells
        )

    def __hash__(self) -> int:
        """Returns a hash key.

        Returns:
            (int): A hash key
        """
        hash_key = hash(self.__key())
        return hash_key

    def __eq__(self, other) -> bool:
        """Return True if the other is equal to self.

        Args:
            other: Some object.

        Returns:
            (bool): See description.
        """
        if self is other:
            return True

        if not isinstance(other, DisGeom):
            return NotImplemented

        if not _same_none(self.model_on_off_cells, other.model_on_off_cells):
            return False

        my_key, other_key = self.__key(), other.__key()
        return my_key == other_key


class DisvDisuGeom:
    """Geom from the DISV/DISU packages used to build the UGrid."""
    def __init__(
        self,
        xorigin: float,
        yorigin: float,
        angrot: float,
        tops: list[float],
        bots: list[float],
        centers2d: list[list[float]],
        icvert: list[list[int]],
        locations2d: list[list[float]],
        model_on_off_cells: list[int] | None = None
    ):
        """Initializes the class.

        Args:
            xorigin: x origin.
            yorigin: y origin.
            angrot: angle of rotation.
            tops: Top elevations.
            bots: Bottom elevations.
            centers2d: List of points.
            icvert: array of vertex numbers (in the VERTICES block) used to define the cell.
            locations2d: List of points.
            model_on_off_cells: IDOMAIN.
        """
        self.xorigin = xorigin
        self.yorigin = yorigin
        self.angrot = angrot
        self.tops = tops
        self.bots = bots
        self.centers2d = centers2d
        self.icvert = icvert
        self.locations2d = locations2d
        self.model_on_off_cells = model_on_off_cells

    def __key(self) -> tuple:
        """Returns something that can be passed to hash().

        Returns:
            (tuple): See description.
        """
        centers2d = tuple([tuple(center) for center in self.centers2d])
        icvert = tuple(tuple(icvert_list) for icvert_list in self.icvert)
        locations2d = tuple(tuple(location) for location in self.locations2d)
        model_on_off_cells = tuple(self.model_on_off_cells) if self.model_on_off_cells is not None else None
        return (
            self.xorigin, self.yorigin, self.angrot, tuple(self.tops), tuple(self.bots), centers2d, icvert, locations2d,
            model_on_off_cells
        )

    def __hash__(self) -> int:
        """Returns a hash key.

        Returns:
            (int): A hash key
        """
        hash_key = hash(self.__key())
        return hash_key

    def __eq__(self, other) -> bool:
        """Return True if the other is equal to self.

        Args:
            other: Some object.

        Returns:
            (bool): See description.
        """
        if self is other:
            return True

        if not isinstance(other, DisvDisuGeom):
            return NotImplemented

        if not _same_none(self.model_on_off_cells, other.model_on_off_cells):
            return False

        my_key, other_key = self.__key(), other.__key()
        return my_key == other_key


def geom_from_dis(dis_package) -> DisGeom | None:
    """Returns the geom needed to build the grid from the DIS package.

    Args:
        dis_package (DisData): DIS package data

    Returns:
        (DisGeom | None): The geom, or None if it could not be read (a warning is logged).
    """
    try:
        xorigin, yorigin, angrot = _read_origin_and_angle(dis_package)
        # x, y locations of the grid lines
        grid_info = dis_package.grid_info()
        griddata = dis_package.block('GRIDDATA')
        delr_list = _dis_locations_from_offsets(griddata.array('DELR'), grid_info.ncol)
        delc_list = _dis_locations_from_offsets(griddata.array('DELC'), grid_info.nrow, reverse=True)
        tops, bots = _elevations_from_dis(dis_package)
        model_on_off_cells = _model_on_off_from_dis(dis_package)
    except Exception as ex:  # Best effort: an unreadable package just means no grid gets built
        # Report through the package logger (like geom_from_disv_disu) rather than printing to stdout
        log = log_util.get_logger()
        log.warning(f'Could not read the grid geometry from the DIS package: {ex}')
        return None
    return DisGeom(xorigin, yorigin, angrot, delr_list, delc_list, tops, bots, model_on_off_cells)


def geom_from_disv_disu(dis_package, transform: bool) -> DisvDisuGeom | None:
    """Returns the geom needed to build the grid from a DISV or DISU package.

    Args:
        dis_package (DisvData or DisuData): DISV or DISU package data
        transform: True to transform the points using the origin and rotation.

    Returns:
        (DisvDisuGeom | None): POD class with the geom, or None if it could not be read (a warning is logged).
    """
    if 'VERTICES' not in dis_package.list_blocks or 'CELL2D' not in dis_package.list_blocks:
        # "If NVERT is not specified or is specified as zero, then the VERTICES and CELL2D
        # blocks below are not read."
        log = log_util.get_logger()
        log.warning('VERTICES or CELL2D block missing from file. Could not create UGrid.')
        return None

    try:
        xorigin, yorigin, angrot = _read_origin_and_angle(dis_package)
        locations2d = _read_vertices_block(dis_package.list_blocks['VERTICES'])
        centers2d, icvert = _read_cell2d_block(dis_package.list_blocks['CELL2D'])
        tops, bots = _elevations_from_dis(dis_package)
        model_on_off_cells = _model_on_off_from_dis(dis_package)
        geom = DisvDisuGeom(xorigin, yorigin, angrot, tops, bots, centers2d, icvert, locations2d, model_on_off_cells)

        # Transform the points if needed
        if transform:
            _transform(geom)

    except Exception as ex:  # Best effort: an unreadable package just means no grid gets built
        # Report through the package logger, consistent with the missing-block warning above
        log = log_util.get_logger()
        log.warning(f'Could not read the grid geometry from the DISV/DISU package: {ex}')
        return None
    return geom


class UGridBuilder:
    """Creates UGrids from DIS, DISV, or DISU package files.

    Grids are cached by geometry, so packages with identical geometry share one data_objects UGrid.
    """
    def __init__(self, binary_arrays=True):
        """Initializes the class.

        Args:
            binary_arrays (bool): If True, UGrid files are written with binary array data.
        """
        self._binary_arrays = binary_arrays
        # Hash table of data_object UGrids so models can share a UGrid. Keys are DisGeom objects, or
        # DisvDisuGeom objects whose points have NOT been transformed by the origin/rotation.
        self._dogrids_dict = {}

    def build(self, mfsim: MfsimData, model: ModelDataBase | None = None) -> list[DoGrid]:
        """Build UGrids from DIS, DISV, or DISU, adds them to the model, and returns a list of grids created.

        Models that already have a data_objects UGrid, or have no DIS/DISV/DISU package, are skipped.

        Args:
            mfsim: modflow simulation
            model: If given, only this model is processed; otherwise every model in mfsim is.

        Returns:
            List of data_object grids.
        """
        dogrids = []
        for curr_model in mfsim.models:
            if model and curr_model != model:
                continue
            if curr_model.get_dogrid() is not None:
                continue

            dis = curr_model.get_dis()
            if not dis or dis.ftype not in ['DIS6', 'DISV6', 'DISU6']:
                continue

            dogrid, cogrid = self.build_dis_grid(dis)
            if not dogrid:
                continue
            dogrids.append(dogrid)

            # Set the projection. Use curr_model, not the 'model' argument, which is None when building every model.
            projection = _get_projection(curr_model)
            if projection:
                dogrid.projection = projection

            # Associate with curr_model
            curr_model.set_dogrid(dogrid)
            curr_model.set_cogrid(cogrid)
        return dogrids

    def build_dis_grid(self, dis: BaseFileData) -> tuple[DoGrid | None, Grid | None]:
        """Build a data_objects UGrid and a CoGrid from a DIS, DISV or DISU package.

        Args:
            dis (DisData): The DIS package data

        Returns:
            (tuple[DoGrid, CoGrid]): See description. Both are None for an unsupported ftype.
        """
        dogrid = None
        cogrid = None
        if dis.ftype == 'DIS6':
            dogrid, cogrid = self._build_dis_ugrid(dis)
        elif dis.ftype == 'DISV6':
            dogrid, cogrid = self._build_disv_ugrid(dis)
        elif dis.ftype == 'DISU6':
            dogrid, cogrid = self._build_disu_ugrid(dis)
        return dogrid, cogrid

    def set_dis_grid_as_existing(self, tree_node, dis):
        """Add an existing DIS package's UGrid to the list of known UGrids.

        Args:
            tree_node (TreeNode): The tree node.
            dis (DisData): The DIS package data
        """
        if not tree_node:
            return  # Not going to be able to find a linked UGrid
        geom = geom_from_dis(dis) if dis.ftype == 'DIS6' else geom_from_disv_disu(dis, transform=False)
        if geom is None:
            return  # Geometry couldn't be read, so there is nothing to match later packages against
        if self._dogrids_dict.get(geom):
            return  # We have already encountered this geometry
        link_item = tree_util.descendants_of_type(  # Find the linked UGrid of the model (DIS item's parent)
            tree_node.parent, xms_types=['TI_UGRID_PTR'], allow_pointers=True, only_first=True
        )
        if link_item:  # Make a dummy data_object that just has the UUID of the existing UGrid so we can link it
            self._dogrids_dict[geom] = DoGrid('', uuid=link_item.uuid)

    def _build_dis_ugrid(self, dis_package) -> tuple[DoGrid | None, Grid | None]:
        """Create a structured UGrid from a DIS package.

        Args:
            dis_package (DisData): DIS package data

        Returns:
            (tuple): The data_objects UGrid and the CoGrid. The CoGrid is None when a UGrid with identical geometry
            is reused; both are None when the geometry could not be read.
        """
        geom = geom_from_dis(dis_package)
        if geom is None:
            return None, None
        # If we already encountered an identical DIS package, use that UGrid.
        dogrid = self._dogrids_dict.get(geom)
        if dogrid:
            return dogrid, None

        # z locations, this will be overridden by the top, bottom elevations
        nlay = dis_package.grid_info().nlay
        z_loc = [0.0] + [(i + 1) * 10 for i in range(nlay)]

        # build the UGrid
        builder = RectilinearGridBuilder()
        builder.is_3d_grid = True
        builder.origin = (geom.xorigin, geom.yorigin, 0.0)
        builder.orientation = (Orientation.y_decrease, Orientation.x_increase, Orientation.z_decrease)
        builder.angle = geom.angrot
        builder.locations_x = geom.delr_list
        builder.locations_y = geom.delc_list
        builder.locations_z = z_loc
        cogrid = builder.build_grid()
        cogrid.set_cell_tops_and_bottoms(geom.tops, geom.bots)
        if geom.model_on_off_cells:
            cogrid.model_on_off_cells = geom.model_on_off_cells
        return self._cache_dogrid(geom, cogrid), cogrid

    def _build_disv_ugrid(self, disv_package) -> tuple[DoGrid | None, Grid | None]:
        """Create an unstructured UGrid from a DISV package.

        Args:
            disv_package (DisvData): DISV package data

        Returns:
            (tuple): The data_objects UGrid and the CoGrid. The CoGrid is None when a UGrid with identical geometry
            is reused; both are None when the geometry could not be read.
        """
        key_geom, geom, dogrid = self._transformed_builder_geom_from_disv_disu(disv_package)
        if key_geom is None:
            return None, None
        # If we already encountered an identical DISV package, use that UGrid.
        if dogrid:
            return dogrid, None

        # Build the cogrid
        nlay = disv_package.grid_info().nlay
        cogrid = build_stacked_grid_from_2d(geom.locations2d, geom.icvert, nlay)
        cogrid.set_cell_tops_and_bottoms(geom.tops, geom.bots)
        cogrid.cell_centers = geom.centers2d * nlay  # The same 2D centers, repeated once per layer
        if geom.model_on_off_cells:
            cogrid.model_on_off_cells = geom.model_on_off_cells
        return self._cache_dogrid(key_geom, cogrid), cogrid

    def _build_disu_ugrid(self, disu_package) -> tuple[DoGrid | None, Grid | None]:
        """Create an unstructured UGrid from a DISU package.

        Args:
            disu_package (DisuData): DISU package data

        Returns:
            (tuple): The data_objects UGrid and the CoGrid. The CoGrid is None when a UGrid with identical geometry
            is reused; both are None when the geometry could not be read.
        """
        key_geom, geom, dogrid = self._transformed_builder_geom_from_disv_disu(disu_package)
        if key_geom is None:
            return None, None
        # If we already encountered an identical DISU package, use that UGrid.
        if dogrid:
            return dogrid, None

        # Build the cogrid
        connectiondata = disu_package.block('CONNECTIONDATA')
        iac = connectiondata.array('IAC').get_values()
        ja = connectiondata.array('JA').get_values()
        ihc = connectiondata.array('IHC').get_values()

        disu_to_ugrid = DisuUgridStreamBuilder(geom.locations2d, geom.icvert, geom.tops, geom.bots, iac, ja, ihc)
        disu_to_ugrid.build_ugrid_stream()
        ugrid = UGrid(disu_to_ugrid.ug_cell_pts, disu_to_ugrid.ug_cell_stream)
        cogrid = UGrid3d(ugrid=ugrid)
        cogrid.cell_centers = geom.centers2d
        cogrid.set_cell_tops_and_bottoms(geom.tops, geom.bots)
        if geom.model_on_off_cells:
            cogrid.model_on_off_cells = geom.model_on_off_cells
        return self._cache_dogrid(key_geom, cogrid), cogrid

    def _transformed_builder_geom_from_disv_disu(
        self, dis_package
    ) -> tuple[DisvDisuGeom | None, DisvDisuGeom | None, DoGrid | None]:
        """Returns the cache-key geom, the transformed geom to build from, and any cached UGrid.

        The cache is keyed on the untransformed geom, which is also the key set_dis_grid_as_existing uses. Points are
        transformed on a copy so the key is left alone. Otherwise a grid with a nonzero origin or rotation would be
        stored under transformed points and never found by later lookups.

        Args:
            dis_package (DisvData or DisuData): DISV or DISU package data

        Returns:
            (tuple): tuple containing:
                - (DisvDisuGeom): The untransformed geom (cache key), or None if it could not be read.
                - (DisvDisuGeom): The transformed geom to build from, or None if a cached UGrid was found.
                - The cached data_objects UGrid, or None.
        """
        key_geom = geom_from_disv_disu(dis_package, transform=False)
        if key_geom is None:
            return None, None, None
        dogrid = self._dogrids_dict.get(key_geom)
        if dogrid:
            return key_geom, None, dogrid
        if key_geom.xorigin == 0.0 and key_geom.yorigin == 0.0 and key_geom.angrot == 0.0:
            return key_geom, key_geom, None  # No origin offset or rotation, so the points need no transforming

        # Copy the point lists (tops, bots, icvert are not touched by _transform and can be shared)
        geom = DisvDisuGeom(
            key_geom.xorigin, key_geom.yorigin, key_geom.angrot, key_geom.tops, key_geom.bots,
            [list(center) for center in key_geom.centers2d], key_geom.icvert,
            [list(location) for location in key_geom.locations2d], key_geom.model_on_off_cells
        )
        _transform(geom)
        return key_geom, geom, None

    def _cache_dogrid(self, key_geom, cogrid) -> DoGrid:
        """Writes cogrid to a temp .xmc file, wraps it in a data_objects UGrid and caches it under key_geom.

        Args:
            key_geom (DisGeom or DisvDisuGeom): The cache key (untransformed geometry).
            cogrid: The constraint grid to write.

        Returns:
            The new data_objects UGrid.
        """
        temp_filename = io_util.get_temp_filename(folder=io_util.get_xms_temp_directory(), suffix='.xmc')
        cogrid.write_to_file(temp_filename, self._binary_arrays)
        dogrid = _dogrid_from_file(temp_filename)
        self._dogrids_dict[key_geom] = dogrid
        return dogrid


def _get_projection(model: ModelDataBase | None) -> Projection | None:
    """Returns the projection object, only if we found a projection (.prj file) when reading.

    Args:
        model: The model (may be None).

    Returns:
        (Projection | None): The projection built from the model's well-known text, or None if there is none.
    """
    # Only the well-known text saved with the model when the model/DIS* package was read is used. We used to also
    # set the horizontal units from LENGTH_UNITS in DIS*, but "FEET" is ambiguous: "FEET (U.S. SURVEY)" or
    # "FEET (INTERNATIONAL)". If there's existing data in GMS and the display projection doesn't match the guess,
    # the grid gets reprojected, and the DIS* package then no longer matches the UGrid in dis_grid_comparer.py.
    # So that fallback was removed.
    if model and model.projection_wkt:  # Treat None the same as an empty string
        return Projection(wkt=model.projection_wkt)
    return None


def _read_origin_and_angle(data):
    """Reads the XORIGIN, YORIGIN, ANGROT options and returns them as a tuple.

    Args:
        data:

    Returns:
        (tuple): tuple containing:
            - xorigin (float): X origin.
            - yorigin (float): Y origin.
            - angrot (float): Angle of rotation counterclockwise from X axis in degrees.
    """
    xorigin = float(data.options_block.get('XORIGIN', default=0.0))
    yorigin = float(data.options_block.get('YORIGIN', default=0.0))
    angrot = float(data.options_block.get('ANGROT', default=0.0))
    return xorigin, yorigin, angrot


def _rotate(origin, point, angle):
    """Rotate a point counterclockwise by a given angle around a given origin.

    The angle should be given in radians. See https://stackoverflow.com/questions/34372480
    """
    ox, oy = origin
    px, py = point

    qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
    qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
    return [qx, qy]


def _transform(geom: DisvDisuGeom) -> None:
    """Translates and rotates geom's 2D points in place, unless there is no origin offset or rotation.

    Args:
        geom: The geometry whose locations2d and centers2d are transformed.
    """
    is_identity = geom.xorigin == 0.0 and geom.yorigin == 0.0 and geom.angrot == 0.0
    if is_identity:
        return
    _transform_points(geom.xorigin, geom.yorigin, geom.angrot, geom.locations2d, geom.centers2d)


def _transform_points(xorigin: float, yorigin: float, angrot: float, locations2d, centers2d) -> None:
    """Transform the points in place: rotate about (0, 0), then translate by the origin.

    Each entry of both lists is replaced with a new [x, y] list.

    Args:
        xorigin: X value of origin.
        yorigin: Y value of origin.
        angrot: Angle of rotation counterclockwise, in degrees.
        locations2d: List of [x, y] points, modified in place.
        centers2d: List of [x, y] points, modified in place.
    """
    radians = math.radians(angrot)  # Loop invariant, so convert once instead of per point
    for points in (locations2d, centers2d):
        for i, point in enumerate(points):
            x, y = _rotate((0, 0), point, radians)
            points[i] = [x + xorigin, y + yorigin]


def _elevations_from_dis(dis_package) -> tuple[list[float], list[float]]:
    """Get the top and bottom elevations from the DIS package.

    The returned tops are actually the same size as bots which is what the cogrid needs.

    Args:
        dis_package (DisData): DIS package data

    Returns:
        (tuple): tuple containing:
            - tops (float): tops.
            - bots (float): bottoms.
    """
    tops = dis_package.get_tops()
    bottoms = dis_package.get_bottoms()
    return tops, bottoms


def _model_on_off_from_dis(dis_package):
    """Get the model on/off values from the DIS package.

    Args:
        dis_package (DisData): DIS package data

    Returns:
        (list): array values
    """
    idomain = []
    griddata = dis_package.block('GRIDDATA')
    if griddata.has('IDOMAIN'):
        idomain_array = griddata.array('IDOMAIN')
        if idomain_array:
            idomain = idomain_array.get_values()

    idomain_int = [int(i) for i in idomain]
    return idomain_int


def _read_vertices_block(filename):
    """Reads the vertices block.

    Args:
        filename: The filename.

    Returns:
        List of locations.
    """
    locations2d = []
    with open(filename, 'r') as file:
        for line in file:
            words = line.split()
            locations2d.append([float(words[1]), float(words[2])])
    return locations2d


def _read_cell2d_block(filename):
    """Reads the cell2d block.

    Args:
        filename: The filename.

    Returns:
        (tuple): tuple containing:
            - centers2d (list[(float, float)]): List of points.
            - icvert: array of vertex numbers (in the VERTICES block) used to define the cell.
    """
    centers2d = []
    icvert = []
    with open(filename, 'r') as file:
        for line in file:
            words = line.split()
            centers2d.append([float(words[1]), float(words[2])])
            ncvert = int(words[3])
            icvert.append([int(word) - 1 for word in words[4:4 + ncvert]])
    return centers2d, icvert


def _dogrid_from_file(file_name: str) -> DoGrid:
    """Wraps a UGrid file in a new data_objects UGrid with a freshly generated UUID.

    Args:
        file_name: Path to the UGrid file.

    Returns:
        The new data_objects UGrid.
    """
    new_uuid = str(uuid.uuid4())
    result = DoGrid(file_name)
    result.uuid = new_uuid
    return result


def _dis_locations_from_offsets(array: Array, num_vals, reverse=False):
    """Creates the x or y locations for the grid based on data from DELR or DELC.

    Args:
        array: class with array information
        num_vals (int): either nrow or ncol of the grid
        reverse (bool): reverse the array values before creating offsets

    Returns:
        (list): x or y locations
    """
    values = array.get_values()
    out_list = [0.0]
    if reverse:
        values.reverse()
    for i in range(num_vals):
        out_list.append(out_list[-1] + values[i])
    return out_list
