"""This module writes out the CMS-Flow telescoping grid (.tel) file."""

# 1. Standard Python modules
from io import StringIO
import math
import os
import shutil
import threading

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import Numbering, Orientation, QuadtreeGrid2d, QuadtreeGridBuilder, read_grid_from_file
from xms.constraint.quadtree import QuadGridCells
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules


def convert_constraint(co_grid):
    """Return a Quadtree 2D constrained version of ``co_grid``.

    A falsy ``co_grid``, or one whose type is exactly ``QuadtreeGrid2d``, is returned unchanged. Otherwise a new
    quadtree is built from the grid's x/y locations, origin and angle, with every base cell included. Cell
    elevations are copied when present. Model on/off flags are copied when present and sized to the base grid.

    Args:
        co_grid (CoGrid): The constrained UGrid

    Returns:
        QuadtreeGrid2d: The Quadtree 2D constrained object
    """
    # Nothing to convert: there is no grid, or it already carries a Quadtree 2D constraint.
    if not co_grid or type(co_grid) is QuadtreeGrid2d:
        return co_grid

    grid_builder = QuadtreeGridBuilder()
    cell_numbering = Numbering.kji

    # One quadtree cell per rectilinear cell, all included and unrefined.
    base_cells = QuadGridCells()
    xs = co_grid.locations_x
    ys = co_grid.locations_y
    cols = len(xs) - 1
    rows = len(ys) - 1
    base_cells.set_base_grid(num_i=cols, num_j=rows, numbering=cell_numbering)
    base_cells.set_all_included()

    grid_builder.is_2d_grid = True
    grid_builder.locations_x = xs
    grid_builder.locations_y = ys
    grid_builder.origin = co_grid.origin
    grid_builder.angle = co_grid.angle
    grid_builder.numbering = cell_numbering
    grid_builder.orientation = (Orientation.x_increase, Orientation.y_increase)
    grid_builder.quad_grid_cells = base_cells
    converted = grid_builder.build_grid()

    # Carry the cell elevations over when the source grid defines them.
    elevations = co_grid.cell_elevations
    if elevations is not None:
        converted.cell_elevations = elevations

    # Carry the model on/off flags over only when there is exactly one per base cell.
    activity_flags = co_grid.model_on_off_cells
    if activity_flags and len(activity_flags) == cols * rows:
        converted.model_on_off_cells = activity_flags
    return converted


class CentroidThread(threading.Thread):
    """Worker thread that fills a slice of a shared list with cell centroids."""
    def __init__(self, ugrid, centroid_list, start_idx, end_idx):
        """Set up a thread that stores the centroids of cells ``start_idx`` to ``end_idx - 1``.

        Args:
            ugrid (XmUgrid): The geometry to get cell centers from.
            centroid_list (list): The list of centroid locations to fill in; slot ``i`` receives cell ``i``.
            start_idx (int): The first cell index to fill in.
            end_idx (int): One past the last cell index to fill in (exclusive).
        """
        super().__init__()
        self.ugrid = ugrid
        self.centroid_list = centroid_list
        self.start_idx = start_idx
        self.end_idx = end_idx

    def run(self):
        """Store element ``[1]`` of ``ugrid.get_cell_centroid(i)`` into ``centroid_list[i]`` for each index."""
        grid = self.ugrid
        targets = self.centroid_list
        for idx in range(self.start_idx, self.end_idx):
            targets[idx] = grid.get_cell_centroid(idx)[1]


class CMSFlowTelExporter:
    """Exporter for the CMS-Flow telescoping grid (.tel) file."""
    def __init__(self):
        """Constructs class for tel file export."""
        self.ss = StringIO()  # In-memory buffer for the file contents; dumped to disk at the end of the export.
        self.proj_name = ""  # Project file name without extension; used as the output file's base name.
        self.cell_activity = None  # Per-cell activity flags (truthy = active), indexed by cell.
        self.angle = 0.0  # NOTE(review): not updated here; the export reads the angle from the grid instead.
        self.origin_x = 0.0  # Grid origin x, used to convert cell centers to local offsets.
        self.origin_y = 0.0  # Grid origin y, used to convert cell centers to local offsets.
        self.ctheta = 0.0  # cos(360 - grid angle), cached for the coordinate transform.
        self.stheta = 0.0  # sin(360 - grid angle), cached for the coordinate transform.

    def convert_coord_to_local(self, coord):
        """Converts a point in world coordinates to an offset relative to the grid origin.

        The offset is rotated using the cached ``ctheta``/``stheta`` values.

        Args:
             coord (sequence of float): World coordinates of a cell center; only the first two entries (x, y)
                 are used.

        Returns:
            (:obj:`tuple` of :obj:`float`): The x and y offsets of the cell center from the grid origin
        """
        x_coord = coord[0]
        y_coord = coord[1]
        x_offset = (x_coord - self.origin_x) * self.ctheta - (y_coord - self.origin_y) * self.stheta
        y_offset = (x_coord - self.origin_x) * self.stheta + (y_coord - self.origin_y) * self.ctheta
        return x_offset, y_offset

    @staticmethod
    def remove_noise(value):
        """Remove numerical round off noise in a floating point value.

        The value is rounded to successively more decimal places, up to 11. Precision stops growing when the
        string form of the rounded value jumps by more than one character, which is taken as the onset of
        round-off noise.

        Args:
            value (float): The value being cleaned up

        Returns:
            float: ``value`` rounded to the last precision before noise appeared (at most 11 decimal places).
        """
        number_digits = 0
        value_digits = round(value, number_digits)
        string_a = str(value_digits)
        number_digits = 1
        value_digits = round(value, number_digits)
        string_b = str(value_digits)
        # increase precision until max of 11 or noise encountered
        while number_digits <= 11 and len(string_b) <= len(string_a) + 1:
            number_digits += 1
            value_digits = round(value, number_digits)
            string_a = string_b
            string_b = str(value_digits)
        return round(value, number_digits - 1)

    @staticmethod
    def _collapse_side(first, second):
        """Normalize the two neighbor slots of one cell side (negative id = no neighbor).

        A lone neighbor is always moved into the first slot. A neighbor that occupies both slots is written once.

        Args:
            first (int): 0-based id in the side's first output slot, or negative if none.
            second (int): 0-based id in the side's second output slot, or negative if none.

        Returns:
            tuple: The normalized ``(first, second)`` pair.
        """
        if second >= 0 and first < 0:
            first, second = second, -1
        if first >= 0 and first == second:
            second = -1
        return first, second

    def _format_value(self, value, width):
        """Return ``value`` with round-off noise removed, right-justified to ``width`` characters."""
        return str(self.remove_noise(value)).rjust(width)

    def write_cell_lines(self, co_grid, ugrid):
        """Write the cell definition lines, one per cell, to the in-memory buffer.

        Each line holds, in order:
        - cell id (1-based);
        - center x and center y, as offsets from the grid origin;
        - i-size and j-size;
        - neighbor ids, each 1-based and 0 when absent: top-left, top-right, right-top, right-bottom,
          bottom-right, bottom-left, left-bottom, left-top;
        - center z;
        - refinement level.

        Args:
            co_grid (CoGrid): The CMS-Flow Quadtree.
            ugrid (XmUGrid): The Quadtree geometry
        """
        base_size = co_grid.get_base_cell_dimensions(0)
        cell_getter = co_grid.quad_grid_cells
        num_cells = ugrid.cell_count
        center_list = [None] * num_cells
        half_list = num_cells // 2
        cell_z = co_grid.cell_elevations

        # Compute the cell centroids in two threads, each filling a disjoint half of center_list.
        threads = (
            CentroidThread(ugrid, center_list, 0, half_list),
            CentroidThread(ugrid, center_list, half_list, num_cells),
        )
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Neighbor positions returned by get_smoothed_neighbors():
        #     -------------------------
        #     |  9  |  8  |  7  |  6  |
        #     -------------------------
        #     | 10  |           |  5  |
        #     -------   cell    -------
        #     | 11  |           |  4  |
        #     -------------------------
        #     |  0  |  1  |  2  |  3  |
        #     -------------------------
        for cell_idx in range(num_cells):
            cell_id = cell_idx + 1  # change to 1 based
            neighbors = cell_getter.get_smoothed_neighbors(cell_idx)
            # If there is only one neighbor on a side, write that id first for the side regardless of its location.
            top_left, top_right = self._collapse_side(neighbors[8], neighbors[7])
            right_top, right_bottom = self._collapse_side(neighbors[5], neighbors[4])
            bottom_right, bottom_left = self._collapse_side(neighbors[2], neighbors[1])
            left_bottom, left_top = self._collapse_side(neighbors[11], neighbors[10])

            offset_x, offset_y = self.convert_coord_to_local(center_list[cell_idx])
            cell_refinement = cell_getter.get_cell_refinement(cell_idx).to_integer()
            cell_refinement_power = float(2**cell_refinement)
            size_i = base_size[0] / cell_refinement_power
            size_j = base_size[1] / cell_refinement_power

            if self.cell_activity[cell_idx] and not math.isclose(cell_z[cell_idx], -999.0):
                z_value = cell_z[cell_idx]
                if z_value != 0.0:
                    # Invert elevation, depths always negative in SMS. 0.0 is skipped so "-0.0" is never written.
                    z_value *= -1
            else:
                z_value = -999.0  # If a cell is inactive, output -999.0 for center Z.

            off_x_str = self._format_value(offset_x, 13)
            off_y_str = self._format_value(offset_y, 13)
            size_i_str = self._format_value(size_i, 13)
            size_j_str = self._format_value(size_j, 13)
            z_str = self._format_value(z_value, 14)
            self.ss.write(
                f"{cell_id:>7d} {off_x_str} {off_y_str} "  # Output center location as relative from origin
                f"{size_i_str} {size_j_str} "
                f"{top_left + 1:>7d} {top_right + 1:>7d} "  # If a neighbor does not exist for a location, output 0.
                f"{right_top + 1:>7d} {right_bottom + 1:>7d} "
                f"{bottom_right + 1:>7d} {bottom_left + 1:>7d} "
                f"{left_bottom + 1:>7d} {left_top + 1:>7d} "
                f"{z_str} "
                f"{cell_refinement}\n"
            )

    def write_tel_file(self, query=None):
        """Write the CMS-Flow telescoping grid file.

        Cell activity comes from the simulation's activity classification coverage when one is linked. Otherwise
        it comes from the grid's model on/off cells, when the grid defines them. If neither applies, every cell
        is active.

        Args:
            query (Query): The interprocess communication object. If not supplied, will create connection.

        Raises:
            RuntimeError: If no Quadtree is linked to the simulation, or the Quadtree has no cells.
        """
        query = query if query else Query()
        self.proj_name = os.path.splitext(os.path.basename(query.xms_project_path))[0]

        # get the Quadtree geometry.
        sim_item = tree_util.find_tree_node_by_uuid(query.project_tree, query.current_item_uuid())
        quad_item = tree_util.descendants_of_type(
            sim_item, xms_types=['TI_UGRID_PTR', 'TI_CGRID2D_PTR'], recurse=False, allow_pointers=True, only_first=True
        )
        if not quad_item:
            raise RuntimeError('Error: Unable to retrieve Quadtree from SMS!')

        quad = query.item_with_uuid(quad_item.uuid)
        co_grid = read_grid_from_file(quad.cogrid_file)
        ugrid = co_grid.ugrid
        cell_count = ugrid.cell_count
        if cell_count < 1:
            raise RuntimeError('Error: No cells in the Quadtree!')
        self.cell_activity = [True] * cell_count

        # Get the activity coverage, if there is one.
        activity_item = tree_util.descendants_of_type(
            sim_item,
            xms_types=['TI_COVER_PTR'],
            recurse=False,
            allow_pointers=True,
            only_first=True,
            coverage_type='ACTIVITY_CLASSIFICATION'
        )
        if activity_item:
            activity_cov = query.item_with_uuid(activity_item.uuid, generic_coverage=True)
            if activity_cov:
                snap = SnapPolygon()
                snap.set_grid(co_grid, True)
                snap.add_polygons(activity_cov.m_cov.polygons)
                # Every cell inside a polygon flagged inactive is marked inactive.
                for idx, is_active in enumerate(activity_cov.m_activity):
                    if not is_active:
                        for cell in snap.get_cells_in_polygon(idx):
                            self.cell_activity[cell] = False  # mark as inactive
        elif co_grid.model_on_off_cells:  # If activity defined on the UGrid, use that
            self.cell_activity = co_grid.model_on_off_cells

        self._write_tel_file(cell_count, co_grid, ugrid)

    def _write_tel_file(self, cell_count, co_grid, ugrid):
        """Writes the tel file ``<proj_name>.tel`` in the current working directory.

        The first line is always "CMS-Telescoping". The angle/origin/cell-count line and the cell lines follow
        only when the (converted) grid exists and has an angle.

        Args:
            cell_count (int): The number of cells in the grid.
            co_grid (CoGrid): The constrained quadtree grid.
            ugrid (XmUGrid): The Quadtree geometry
        """
        co_grid = convert_constraint(co_grid)
        angle = co_grid.angle if co_grid is not None else None

        # Start from an empty buffer so a reused exporter does not append to a previous export.
        self.ss.seek(0)
        self.ss.truncate()
        self.ss.write("CMS-Telescoping\n")
        # These checks must precede any use of the angle or origin.
        if angle is not None and co_grid is not None:
            # Store cos and sin of angle for cell center coordinate transformations.
            rad_angle = math.radians(360.0 - angle)
            self.ctheta = math.cos(rad_angle)
            self.stheta = math.sin(rad_angle)
            # Store the origin's coordinates for cell center coordinate transformations.
            origin = co_grid.origin
            self.origin_x = origin[0]
            self.origin_y = origin[1]

            angle_string = str(angle)
            origin_x_string = str(self.origin_x)
            origin_y_string = str(self.origin_y)
            self.ss.write(f"{angle_string}  {origin_x_string}  {origin_y_string}  {cell_count}\n")

            # Write the cell definitions.
            self.write_cell_lines(co_grid, ugrid)

        # Dump the in-memory stream to file on disk. The context manager closes the file even if the copy fails.
        with open(self.proj_name + ".tel", "w", encoding="utf-8") as out:
            self.ss.seek(0)
            shutil.copyfileobj(self.ss, out)
