"""This module reads in a tel file."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import enum
import math
import os
import queue
import uuid

# 2. Third party modules
import h5py
import rtree

# 3. Aquaveo modules
from xms.api.dmi import Query
import xms.constraint as xc
from xms.constraint.rectilinear_geometry import Numbering, Orientation
from xms.coverage.activity import ActivityCoverage
from xms.coverage.grid.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder
from xms.data_objects.parameters import Point, UGrid

# 4. Local modules
from xms.cmsflow.feedback.xmlog import XmLog


@enum.unique
class LocationDirection(enum.IntEnum):
    """Neighbor slots of a cell, in the order a CMS-Flow tel file lists them.

    Slots go clockwise around the cell, two per side, starting on the top side. Each name reads as
    ``loc_<side>_<end>``; e.g. ``loc_top_left`` is the left-hand neighbor along the top side.
    ``loc_num_directions`` is not a direction itself; it is the count of neighbor slots.
    """
    # Top side, left to right.
    loc_top_left = 0
    loc_top_right = 1
    # Right side, top to bottom.
    loc_right_top = 2
    loc_right_bot = 3
    # Bottom side, right to left.
    loc_bot_right = 4
    loc_bot_left = 5
    # Left side, bottom to top.
    loc_left_bot = 6
    loc_left_top = 7
    # Number of slots listed above.
    loc_num_directions = 8


class CMSFlowTelImporter:
    """Importer for CMS-Flow telescoping grid files (.tel, or .worldtel with world-coordinate cell centers).

    Reads the cell definitions from the file and builds a quadtree constrained UGrid from them. Also builds an
    activity coverage from the cells' active/inactive flags. :meth:`read` runs the whole import against SMS.
    """
    # Position of the built grid's cell index within each tuple of ``self.refinement_idx``.
    new_cell_idx_column = 3

    def __init__(self, filename=None, projection=None, cms_version=None, bathymetry=None):
        """Constructs the objects to fill in on reading the tel file.

        Args:
            filename (str): The name and full path of the tel file to read.
            projection (data_objects.parameters.Projection): The projection from cmcards.
            cms_version (list[int]): A list containing the major, then minor, version of CMS-Flow.
            bathymetry (Optional[tuple[str, str]]): The cell-based depth dataset as an (H5 filename, dataset path)
                pair; only the first entry along the dataset's first axis is used (presumably the first time
                step). If not provided, or if the path is not in the file, cell elevations from the tel file are used.
        """
        self.filename = filename
        self.angle = None
        self.projection = projection
        self.origin_x = 0.0
        self.origin_y = 0.0
        self.cell_centers = []
        self.cell_elevations = bathymetry
        self.cell_sizes = []
        self.cell_neighbors = []
        self.refinement = []
        self.refinement_idx = []
        self.cms_version = cms_version if cms_version else []
        # For CMS-Flow major version <= 4 the cell elevation/activity arrays are reordered to the built grid's cell
        # numbering (see _renumber_cell_arrays).
        self._renumber = bool(self.cms_version) and self.cms_version[0] <= 4
        self.activity = []
        self.activity_coverage = None
        self.query = None
        self.num_levels = 0
        self.quad_ugrid = None  # Set by _build_co_quadtree_grid on success.

    def _level_size(self, level):
        """Returns the size of a cell at a refinement level, in units of the finest-level cell size.

        Args:
            level (int): Refinement level, 0 (coarsest) to ``self.num_levels`` (finest).

        Returns:
            int: ``2 ** (self.num_levels - level)``.
        """
        return 1 << (self.num_levels - level)

    @staticmethod
    def _refinement_level(size_ratio):
        """Returns the refinement level of a cell that is ``size_ratio`` times smaller than the largest cell.

        math.log2 is exact for powers of two. math.log(x, 2) may carry a rounding error, which the int() truncation
        could turn into an off-by-one level.

        Args:
            size_ratio (float): Largest cell size divided by this cell's size; rounded to the nearest integer.

        Returns:
            int: The refinement level (0 for a largest-size cell).
        """
        return int(math.log2(round(size_ratio)))

    def _get_neighbor_ijk(self, direction, cell_ijk, cell_level, cell_size, cell_quad, neighbor_level, neighbor_size):
        """Gets the neighbor ijk indices given the parameters.

        Same-level and coarser neighbors are only handled for the even directions (one slot per side). For a coarser
        neighbor the current cell's quadrant must face that neighbor. NOTE(review): the finer/coarser offsets appear
        to assume a one-level (2:1) size difference - confirm.

        Args:
            direction (LocationDirection): The direction of the neighbor.
            cell_ijk (Tuple[int, int, int]): The ijk indices of the current cell.
            cell_level (xc.quadtree.RefineLevel): The level of the current cell.
            cell_size (int): The size of the current cell.
            cell_quad (int): The quadrant of the current cell.
            neighbor_level (xc.quadtree.RefineLevel): The level of the neighbor cell.
            neighbor_size (int): The size of the neighbor cell.

        Returns:
            Tuple[bool, Tuple[int, int, int]]: A tuple containing a boolean indicating if the neighbor exists and the
                                               ijk indices of the neighbor.
        """
        neighbor_ijk = list(cell_ijk)

        if neighbor_level.same_refinement(cell_level):
            assert cell_size == neighbor_size
            if direction == LocationDirection.loc_top_left:
                neighbor_ijk[1] += cell_size
            elif direction == LocationDirection.loc_right_top:
                neighbor_ijk[0] += cell_size
            elif direction == LocationDirection.loc_bot_right:
                neighbor_ijk[1] -= neighbor_size
            elif direction == LocationDirection.loc_left_bot:
                neighbor_ijk[0] -= neighbor_size
            else:
                return False, tuple(neighbor_ijk)
        elif neighbor_level.more_refined_than(cell_level):
            assert neighbor_size < cell_size
            if direction == LocationDirection.loc_top_left:
                neighbor_ijk[1] += cell_size
            elif direction == LocationDirection.loc_top_right:
                neighbor_ijk[0] += neighbor_size
                neighbor_ijk[1] += cell_size
            elif direction == LocationDirection.loc_right_top:
                neighbor_ijk[0] += cell_size
                neighbor_ijk[1] += neighbor_size
            elif direction == LocationDirection.loc_right_bot:
                neighbor_ijk[0] += cell_size
            elif direction == LocationDirection.loc_bot_right:
                neighbor_ijk[0] += neighbor_size
                neighbor_ijk[1] -= neighbor_size
            elif direction == LocationDirection.loc_bot_left:
                neighbor_ijk[1] -= neighbor_size
            elif direction == LocationDirection.loc_left_bot:
                neighbor_ijk[0] -= neighbor_size
            elif direction == LocationDirection.loc_left_top:
                neighbor_ijk[0] -= neighbor_size
                neighbor_ijk[1] += neighbor_size
            else:
                return False, tuple(neighbor_ijk)
        else:
            assert neighbor_size > cell_size
            if cell_quad == 0 and direction == LocationDirection.loc_bot_right:
                neighbor_ijk[1] -= neighbor_size
            elif cell_quad == 0 and direction == LocationDirection.loc_left_bot:
                neighbor_ijk[0] -= neighbor_size
            elif cell_quad == 1 and direction == LocationDirection.loc_right_top:
                neighbor_ijk[0] += cell_size
            elif cell_quad == 1 and direction == LocationDirection.loc_bot_right:
                neighbor_ijk[0] -= cell_size
                neighbor_ijk[1] -= neighbor_size
            elif cell_quad == 2 and direction == LocationDirection.loc_top_left:
                neighbor_ijk[1] += cell_size
            elif cell_quad == 2 and direction == LocationDirection.loc_left_bot:
                neighbor_ijk[0] -= neighbor_size
                neighbor_ijk[1] -= cell_size
            elif cell_quad == 3 and direction == LocationDirection.loc_top_left:
                neighbor_ijk[0] -= cell_size
                neighbor_ijk[1] += cell_size
            elif cell_quad == 3 and direction == LocationDirection.loc_right_top:
                neighbor_ijk[0] += cell_size
                neighbor_ijk[1] -= cell_size
            else:
                return False, tuple(neighbor_ijk)

        return True, tuple(neighbor_ijk)

    def _find_origin_idx(self):
        """Returns the index of the origin (bottom-left) cell, or -1 if there is none.

        The origin cell has no neighbors below it or to its left. If several cells qualify, the last one wins.
        """
        no_neighbor_slots = (
            LocationDirection.loc_bot_right,
            LocationDirection.loc_bot_left,
            LocationDirection.loc_left_bot,
            LocationDirection.loc_left_top,
        )
        origin_idx = -1
        for idx, neighbors in enumerate(self.cell_neighbors):
            if all(neighbors[slot] == -1 for slot in no_neighbor_slots):
                origin_idx = idx
        return origin_idx

    def _count_base_cells(self, origin_idx, direction):
        """Counts base-level cells along one grid edge by walking the neighbor links from the origin cell.

        Args:
            origin_idx (int): Index of the origin cell.
            direction (int): Neighbor slot to follow (e.g. ``LocationDirection.loc_right_top`` for the x direction).

        Returns:
            int | None: The number of base (level 0) cells along the edge. Returns None if the walk does not end
            within the number of cells, i.e. the neighbor links form a cycle.
        """
        total = self._level_size(self.refinement[origin_idx])
        neighbor_idx = self.cell_neighbors[origin_idx][direction]
        steps = 0
        while neighbor_idx != -1:
            total += self._level_size(self.refinement[neighbor_idx])
            neighbor_idx = self.cell_neighbors[neighbor_idx][direction]
            steps += 1
            if steps > len(self.cell_neighbors):
                return None
        return total // self._level_size(0)

    def _mark_cell_refinement(self, cells, origin_idx):
        """Places every cell reachable from the origin cell into ``cells`` with a breadth-first walk over neighbors.

        Each placed cell gets its refinement level and its index into the tel file's cell arrays.

        Args:
            cells (xc.quadtree.QuadGridCells): The cells being built; the origin cell is placed at ijk (0, 0, 0).
            origin_idx (int): Index of the origin cell.

        Returns:
            bool: True on success, False if a neighbor's location could not be determined (an error is logged).
        """
        processed = [False] * len(self.refinement)
        to_process = queue.Queue()
        to_process.put((origin_idx, (0, 0, 0)))
        cells.set_level(0, 0, 0, xc.quadtree.RefineLevel(self.refinement[origin_idx]))
        cells.set_cell_index(0, 0, 0, origin_idx)
        while not to_process.empty():
            cell_idx, cell_ijk = to_process.get()
            processed[cell_idx] = True
            cell_level = xc.quadtree.RefineLevel(self.refinement[cell_idx])
            cell_quad = cells.get_cell_quad_idx(*cell_ijk)
            cell_size = self._level_size(self.refinement[cell_idx])
            # Visit the even slots only: one neighbor per side (top_left, right_top, bot_right, left_bot).
            for direction in range(LocationDirection.loc_top_left, LocationDirection.loc_num_directions, 2):
                neighbor_idx = self.cell_neighbors[cell_idx][direction]
                # Cells already placed are skipped, so their location is not re-checked from this side.
                if neighbor_idx < 0 or processed[neighbor_idx]:
                    continue
                neighbor_level = xc.quadtree.RefineLevel(self.refinement[neighbor_idx])
                neighbor_size = self._level_size(self.refinement[neighbor_idx])
                success, neighbor_ijk = self._get_neighbor_ijk(
                    direction, cell_ijk, cell_level, cell_size, cell_quad, neighbor_level, neighbor_size
                )
                if not success:
                    XmLog().instance.error(
                        f'Invalid direction to neighbor for cell ID {cell_idx + 1} and neighbor '
                        f'ID {neighbor_idx + 1}.'
                    )
                    return False
                cells.set_level(*neighbor_ijk, neighbor_level)
                cells.set_cell_index(*neighbor_ijk, neighbor_idx)
                processed[neighbor_idx] = True
                to_process.put((neighbor_idx, neighbor_ijk))
        return True

    def _grid_steps(self, cells, num_base_cells, along_x):
        """Computes the base-grid edge locations along one axis, starting at 0.0.

        For each base cell along the first row/column, the step is the far edge (center + half size) of the cell at
        that base cell's last finest-level position.

        Args:
            cells (xc.quadtree.QuadGridCells): The placed cells.
            num_base_cells (int): Number of base cells along the axis.
            along_x (bool): True for the x axis (first row), False for the y axis (first column).

        Returns:
            list[float] | None: The edge locations, or None if a grid cell is missing.
        """
        base_size = self._level_size(0)
        steps = [0.0] if num_base_cells > 0 else []
        for base in range(1, num_base_cells + 1):
            offset = base * base_size - 1
            if along_x:
                cell_idx = cells.get_cell_index(offset, 0, 0)
            else:
                cell_idx = cells.get_cell_index(0, offset, 0)
            if cell_idx < 0:
                return None
            center = self.cell_centers[cell_idx]
            size = self.cell_sizes[cell_idx]
            steps.append((center.x + size.x / 2) if along_x else (center.y + size.y / 2))
        return steps

    def _renumber_cell_arrays(self):
        """Reorders the cell elevation and activity arrays from the file's cell order to the built grid's cell order.

        The grid's cells are sorted by (refinement level, j, i, cell index). The n-th value read from the file is
        then assigned to the n-th cell in that order.
        """
        num_cells = self.quad_ugrid.ugrid.cell_count
        new_cells = self.quad_ugrid.quad_grid_cells
        self.refinement_idx = []
        for cell_idx in range(num_cells):
            refine = new_cells.get_cell_refinement(cell_idx).to_integer()
            i, j, _k = new_cells.get_cell_ijk(cell_idx)
            # If you change the order of self.refinement_idx, or change it, update self.new_cell_idx_column.
            self.refinement_idx.append((refine, j, i, cell_idx))
        self.refinement_idx.sort()
        old_cell_elevations = copy.deepcopy(self.cell_elevations)
        old_activity = copy.deepcopy(self.activity)
        for old_index, sort_key in enumerate(self.refinement_idx):
            cell_idx = sort_key[self.new_cell_idx_column]
            self.cell_elevations[cell_idx] = old_cell_elevations[old_index]
            self.activity[cell_idx] = old_activity[old_index]

    def _build_activity_coverage(self):
        """Builds ``self.activity_coverage`` from polygons of the grid cells grouped by their activity value."""
        cov_builder = GridCellToPolygonCoverageBuilder(self.quad_ugrid, self.activity, self.projection, 'Activity')
        new_cov_geom = cov_builder.create_polygons_and_build_coverage()
        self.activity_coverage = ActivityCoverage()
        self.activity_coverage.m_cov = new_cov_geom
        # Polygon id -> True for active (non-zero) cells, False for inactive ones.
        self.activity_coverage.m_activity = {
            poly_id: poly_val != 0
            for poly_val, poly_ids in cov_builder.dataset_polygon_ids.items()
            for poly_id in poly_ids
        }

    def _build_co_quadtree_grid(self):
        """Builds a quadtree constrained UGrid from all the information in the .tel file.

        This replaces the following function in SMS C++: ugBuildQuadGridFromInfo
        self.cell_centers = a_location
        self.cell_elevations = a_elevations
        self.cell_sizes = a_size
        self.cell_neighbors = a_neighborIdxs
        self.refinement = a_refinement

        On success sets ``self.quad_ugrid`` and ``self.activity_coverage``. On an invalid grid, logs an error and
        returns before building anything.
        """
        if not self.cell_centers:
            XmLog().instance.error(f'Zero location values in {self.filename}.')
            return

        array_size = len(self.cell_centers)
        if any(len(array) != array_size for array in (self.cell_sizes, self.cell_neighbors, self.refinement)):
            XmLog().instance.error('Cell array sizes must all match.')
            return

        self.num_levels = max(self.refinement)
        if self.num_levels > 31:
            XmLog().instance.error(f'Maximum number of levels in a .tel file is 31.  This file has {self.num_levels}.')
            return

        cells = xc.quadtree.QuadGridCells()
        cells.number_of_levels = self.num_levels

        origin_idx = self._find_origin_idx()
        if origin_idx == -1:
            XmLog().instance.error('Unable to find origin cell.')
            return

        # Number of base rows (x direction) and columns (y direction).
        rows = self._count_base_cells(origin_idx, LocationDirection.loc_right_top)
        if rows is None:
            XmLog().instance.error('Unable to find number of rows.')
            return
        cols = self._count_base_cells(origin_idx, LocationDirection.loc_top_left)
        if cols is None:
            XmLog().instance.error('Unable to find number of columns.')
            return
        cells.set_base_grid(rows, cols, None, Numbering.kji, False)

        if not self._mark_cell_refinement(cells, origin_idx):
            return

        steps_x = self._grid_steps(cells, cells.number_of_base_rows, along_x=True)
        if steps_x is None:
            XmLog().instance.error('Grid cell missing when calculating grid x steps.')
            return
        steps_y = self._grid_steps(cells, cells.number_of_base_columns, along_x=False)
        if steps_y is None:
            XmLog().instance.error('Grid cell missing when calculating grid y steps.')
            return

        builder = xc.QuadtreeGridBuilder()
        builder.is_2d_grid = True
        builder.locations_x = steps_x
        builder.locations_y = steps_y
        builder.origin = (self.origin_x, self.origin_y, 0.0)
        builder.angle = self.angle
        builder.numbering = Numbering.kji
        builder.orientation = (Orientation.x_increase, Orientation.y_increase)
        builder.quad_grid_cells = cells
        self.quad_ugrid = builder.build_grid()
        self.quad_ugrid.uuid = str(uuid.uuid4())  # Will be used to link dataset selector widgets from .cmcards

        if self._renumber:
            self._renumber_cell_arrays()

        self.quad_ugrid.model_on_off_cells = self.activity
        self.quad_ugrid.cell_elevations = self.cell_elevations
        self._build_activity_coverage()

    def _initialize_cell_elevations(self, num_cells):
        """Initialize the array of cell elevations.

        Args:
            num_cells (int): The number of cells in the Quadtree

        Returns:
            bool: True if cell elevations should be overwritten with those from the .tel file.
        """
        if self.cell_elevations is None:  # Bathymetry dataset not specified
            self.cell_elevations = [0.0] * num_cells
            return True
        filename, group_path = self.cell_elevations
        with h5py.File(filename, 'r') as f:
            if group_path not in f:
                self.cell_elevations = [0.0] * num_cells
                return True
            # Read only the first entry along the first axis instead of loading the whole dataset.
            self.cell_elevations = f[group_path][0].tolist()
        return False

    def _extrapolate_cell_elevations(self, point_search, inactive_cells):
        """Extrapolate cell center elevations to cells that we did not have valid elevation data for.

        Each such cell gets the average elevation of its nearest cell(s) in ``point_search``. A cell keeps its current
        value if the search returns nothing (e.g. an empty index).

        Args:
            point_search (rtree.index.Index): rtree containing the cell center locations of cells that had valid
                elevation data
            inactive_cells (list): The cells we did not have valid elevation data for
                [(cell_idx, (center_x, center_y))]
        """
        XmLog().instance.info('Extrapolating cell elevations to inactive cell centers')
        for cell_idx, coords in inactive_cells:
            nearest_cells = list(point_search.nearest(coords))  # May return several if considered "equidistant"
            if not nearest_cells:
                continue
            total_z = sum((self.cell_elevations[nearest] for nearest in nearest_cells), 0.0)
            self.cell_elevations[cell_idx] = total_z / len(nearest_cells)

    def _tel_angle_is_inverted(self):
        """Returns True if the tel file's grid angle must be inverted, i.e. the CMS-Flow version is <= 5.2.

        Major and minor versions are compared lexicographically. With this comparison 4.5 or 3.75 count as <= 5.2,
        which a per-component ``major <= 5 and minor <= 2`` check would miss.
        """
        return len(self.cms_version) > 1 and tuple(self.cms_version[:2]) <= (5, 2)

    def _grid_is_quadtree(self, max_i_size, max_j_size, min_i_size, min_j_size):
        """Checks whether every cell's size is the largest cell size halved a whole number of times, equally in x and y.

        A cell's size may be off by up to half of the smallest cell size. (AKZ Cartesian support - not sure this works)

        Args:
            max_i_size (float): Largest cell size in x.
            max_j_size (float): Largest cell size in y.
            min_i_size (float): Smallest cell size in x.
            min_j_size (float): Smallest cell size in y.

        Returns:
            bool: True if the cells form a quadtree, False for a cartesian grid.
        """
        for size in self.cell_sizes:
            refinement_x = self._refinement_level(max_i_size / size.x)
            refinement_y = self._refinement_level(max_j_size / size.y)
            if refinement_x != refinement_y:
                return False
            error_x = abs(size.x - max_i_size / (2**refinement_x))
            error_y = abs(size.y - max_j_size / (2**refinement_y))
            if error_x > min_i_size / 2 or error_y > min_j_size / 2:
                return False
        return True

    def _snap_cells_to_quadtree(self, max_i_size, max_j_size):
        """Corrects cell sizes, centroids and refinement levels for the tel file's precision limits.

        Each size is set to the exact largest size divided by 2**level. Each centroid is moved to the center of the
        cell-sized interval (counted from 0, truncated toward zero) it falls in.

        Args:
            max_i_size (float): Largest cell size in x.
            max_j_size (float): Largest cell size in y.
        """
        for i, (size, center) in enumerate(zip(self.cell_sizes, self.cell_centers)):
            refinement = self._refinement_level(max_i_size / size.x)
            self.refinement[i] = refinement
            size.x = max_i_size / (2**refinement)
            size.y = max_j_size / (2**refinement)
            center.x = (int(center.x / size.x) + 0.5) * size.x
            center.y = (int(center.y / size.y) + 0.5) * size.y

    def read_tel_file(self):
        """Reads the tel file and builds the quadtree constrained grid and activity coverage from it.

        In a ``.worldtel`` file, cell centers are in world coordinates. They are translated by the grid origin and
        rotated by -angle into local coordinates. On a missing file, malformed cell line or invalid grid, an error is
        logged and the grid is not built.
        """
        if not os.path.isfile(self.filename):
            XmLog().instance.error(f'Telescoping grid file not found - {self.filename}')
            return

        world_coords = self.filename.lower().endswith('.worldtel')
        with open(self.filename, 'r') as file:
            XmLog().instance.info('Reading grid definition')
            file.readline()  # skip the first line identifier
            grid_def = file.readline().split()
            if self.angle is None:  # Angle in .cmcards file takes precedence if it exists
                self.angle = float(grid_def[0])
                # If reading a <= 5.2 version, invert the grid angle. CMS-Flow and SMS convention is CCW from the east,
                # but we did some smelly stuff in older versions where we would invert on both import and export.
                if self._tel_angle_is_inverted():
                    self.angle = 360.0 - self.angle
            self.origin_x = float(grid_def[1])
            self.origin_y = float(grid_def[2])
            # Terms of the world -> local transform applied to .worldtel cell centers.
            rad_theta = math.radians(360.0 - self.angle)
            costheta = math.cos(rad_theta)
            sintheta = math.sin(rad_theta)
            x0costheta = self.origin_x * costheta
            x0sintheta = self.origin_x * sintheta
            y0costheta = self.origin_y * costheta
            y0sintheta = self.origin_y * sintheta
            num_cells = int(grid_def[3])
            self.cell_centers = [Point(0.0, 0.0) for _ in range(num_cells)]
            override_elevations = self._initialize_cell_elevations(num_cells)
            self.activity = [1] * num_cells
            self.cell_sizes = [Point(0.0, 0.0) for _ in range(num_cells)]
            self.cell_neighbors = [[] for _ in range(num_cells)]
            self.refinement = [1] * num_cells
            max_i_size = 0.0
            max_j_size = 0.0
            min_i_size = math.inf
            min_j_size = math.inf
            XmLog().instance.info('Reading grid cell definitions')
            # Cell centers with and without valid elevations. The former feed an rtree used to extrapolate elevations
            # to the latter, if we don't have elevation data for those cells from another source.
            cells_with_z = []  # [(cell_idx, (center_x, center_y))]
            cells_no_z = []  # [(cell_idx, (center_x, center_y))]

            for i in range(num_cells):
                cell_data = file.readline().split()
                # Fields 1-13 are required (center, size, 8 neighbor IDs, elevation); field 14 (refinement) is optional.
                if len(cell_data) < 14:
                    XmLog().instance.error(f'Invalid cell definition on line {i + 3} of {self.filename}.')
                    return

                size_x = float(cell_data[3])
                size_y = float(cell_data[4])
                max_i_size = max(size_x, max_i_size)
                max_j_size = max(size_y, max_j_size)
                min_i_size = min(size_x, min_i_size)
                min_j_size = min(size_y, min_j_size)

                # store the refinement level if it is there
                if len(cell_data) > 14:
                    self.refinement[i] = int(cell_data[14])

                self.cell_sizes[i] = Point(size_x, size_y)

                # If we are reading a .worldtel file, convert coordinates from global to local space.
                center_x = float(cell_data[1])
                center_y = float(cell_data[2])
                if world_coords:
                    x_world = center_x
                    y_world = center_y
                    center_x = x_world * costheta - y_world * sintheta - x0costheta + y0sintheta
                    center_y = x_world * sintheta + y_world * costheta - x0sintheta - y0costheta
                self.cell_centers[i] = Point(center_x, center_y)

                # Assign the tel file's elevation to the cell only if we don't have elevation data from another source.
                # If we have actual elevation values at cell centers, we want to use that when building the geometry
                # because the elevation will be -999.0 in the tel file if the cell is inactive.
                tel_z = float(cell_data[13])
                z_value = tel_z if override_elevations else self.cell_elevations[i]
                if math.isclose(z_value, 0.0, abs_tol=1e-6):  # Make elevation 0.0 if close within garbage precision
                    z_value = 0.0
                    z_is_null = False
                else:
                    z_is_null = math.isclose(z_value, -999.0, abs_tol=1e-6)
                if z_is_null:  # We don't have real elevation data for this cell, need to extrapolate later.
                    cells_no_z.append((i, (center_x, center_y)))
                else:  # We have real elevation data for this cell, it will contribute to the extrapolation.
                    cells_with_z.append((i, (center_x, center_y)))

                # Invert depths to be negative elevations, but not 0.0 or the inactive value. Also don't do this if we
                # are reading a .worldtel file because the depths are already negative.
                if not world_coords and z_value and not z_is_null:
                    z_value *= -1.0
                self.cell_elevations[i] = z_value

                # Use the elevation from the tel file to determine activity, not the elevation we assigned the cell's
                # center. We may have elevation data for inactive cells from another source.
                self.activity[i] = 0 if z_is_null or math.isclose(tel_z, -999.0, abs_tol=1e-6) else 1

                # Neighbor IDs are 1-based in the file (0 = no neighbor); store them 0-based (-1 = no neighbor).
                self.cell_neighbors[i] = [int(cell_data[idx]) - 1 for idx in range(5, 13)]

        # Determine if this grid is a quadtree or a cartesian grid (AKZ Cartesian support - not sure this works)
        if self._grid_is_quadtree(max_i_size, max_j_size, min_i_size, min_j_size):
            self._snap_cells_to_quadtree(max_i_size, max_j_size)
        else:
            self.refinement[:] = [0] * num_cells

        # Extrapolation needs at least one cell with valid elevation. Without one, averaging zero neighbors would
        # divide by zero, so such cells keep their null value.
        if cells_no_z and cells_with_z:
            XmLog().instance.info('Building rtree of cell centers with valid elevation data')
            point_search = rtree.index.Index()
            for cell_idx, coords in cells_with_z:
                point_search.insert(cell_idx, coords)
            self._extrapolate_cell_elevations(point_search, cells_no_z)
        XmLog().instance.info('Building the Quadtree constrained grid')
        self._build_co_quadtree_grid()

    def get_sms_data(self):
        """Creates the SMS Query and takes the path of the file to import from it."""
        self.query = Query()
        self.filename = self.query.read_file

    def send_sms_data(self):
        """Send the imported data to SMS.

        Writes the grid to the Query's process temp directory, creating the directory if needed. Adds the grid, plus
        the activity coverage if one was built, and sends. Does nothing if there is no Query (see get_sms_data) or the
        grid was not built, e.g. because reading the tel file failed and an error was already logged.
        """
        if not self.query or self.quad_ugrid is None:
            return
        temp_dir = self.query.process_temp_directory
        os.makedirs(temp_dir, exist_ok=True)
        xmc_file = os.path.join(temp_dir, f'{self.quad_ugrid.uuid}.xmc')
        self.quad_ugrid.write_to_file(xmc_file)
        ugrid = UGrid(xmc_file, uuid=self.quad_ugrid.uuid, projection=self.projection)
        self.query.add_ugrid(ugrid)
        if self.activity_coverage:
            self.query.add_coverage(self.activity_coverage)
        self.query.send()

    def read(self):
        """Top-level entry point: reads the tel file chosen in SMS and sends the resulting grid back to SMS.

        Gets the file path from SMS, reads the tel file into a quadtree UGrid and activity coverage, then sends them.
        """
        self.get_sms_data()
        self.read_tel_file()
        self.send_sms_data()
