"""MapSolidsToUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from rtree import index

# 3. Aquaveo modules
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of the tool arguments in the list built by MapSolidsToUGridTool.initial_arguments();
# the order here must match the order of that list.
(
    ARG_INPUT_GRID,
    ARG_INPUT_ALGORITHM,
    ARG_INPUT_SOLIDS,
    ARG_INPUT_DEFAULT,
    ARG_OUTPUT_DATASET_NAME,
    ARG_OUTPUT_DATASET,
) = range(6)


class MapSolidsToUGridTool(Tool):
    """Create a cell dataset for a UGrid containing material ids mapped from solids.

    The solids are read from a solid (*.sol) file. For each cell, the vertical extent of every solid at the
    cell centroid's xy location is computed from the solid's triangles, and a material id is chosen with either
    the 'Centroid' or the 'Predominant material' algorithm. Cells that cannot be mapped get the default id.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Map Solids')
        # solid id -> {'id', 'mat_id', 'vert_locs', 'tris', 'rtree', 'min_max'}
        self._mat_dict = dict()
        # (x, y) -> {solid id -> (min elev, max elev) or None}; cache shared by cells with the same centroid xy
        self._xy_to_z = dict()
        # one material value per cell, in cell order
        self._dset_vals = []
        # value written for cells that no single material could be mapped to
        self._default = 100.0
        self._ugrid = None
        self._cogrid = None

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            self.string_argument(name='algorithm', description='Mapping algorithm', value='Centroid',
                                 choices=['Centroid', 'Predominant material']),
            self.file_argument(name='solid_file', description='Solid file (*sol)'),
            self.integer_argument(name='default_mat', description='Default material id', value=100),
            self.string_argument(name='dataset_name', description='Name of the output dataset', value='materials',
                                 optional=True),
            self.dataset_argument(name='material_dataset', description='Material dataset', hide=True,
                                  io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate that the input grid can be read and has a UGrid
        self._validate_input_grid(errors, arguments[ARG_INPUT_GRID])

        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate the input grid and remember it (and its UGrid) for run().

        Args:
            errors (dict): Dictionary of errors keyed by argument name; an entry is added on failure.
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        self._cogrid = self.get_input_grid(argument.text_value)
        if not self._cogrid or not self._cogrid.ugrid:
            self._ugrid = None  # don't keep a UGrid from an earlier validation
            errors[key] = 'Could not read grid.'
        else:
            self._ugrid = self._cogrid.ugrid

    @staticmethod
    def _advance_to_prefix(f, line, prefix):
        """Return the first line, starting with `line` itself, that starts with `prefix`.

        Args:
            f: open text file positioned just after `line`.
            line (str): the current line; it is checked before any further line is read.
            prefix (str): the card to look for (e.g. 'ID ').

        Returns:
            (str | None): the matching line, or None if the end of the file is reached first.
        """
        while not line.startswith(prefix):
            line = f.readline()
            if not line:  # readline() returns '' only at end of file
                return None
        return line

    def _read_solid_file(self, filename):
        """Read all solids from a solid file into self._mat_dict.

        Solids are read one after another, each starting at an 'ID ' line, until no further 'ID ' line exists.
        A missing or malformed file is reported through self.fail().

        Args:
            filename (str): name of file containing solid information
        """
        self.logger.info('Reading solid file...')
        try:
            with open(filename) as f:
                line = f.readline()
                while True:
                    line = self._advance_to_prefix(f, line, 'ID ')
                    if line is None:
                        break  # no more solids (also handles trailing blank lines after the last ENDS)
                    self._read_solid(f, line)
                    line = f.readline()
        except OSError:
            self.fail(f'Unable to open file: {filename}.')
        except (ValueError, IndexError):
            self.fail(f'Error reading file: {filename}. Check file format.')

    def _read_solid(self, f, line):
        """Read one solid, starting at its 'ID ' line, and store it in self._mat_dict.

        Expected layout: 'ID <id>', then 'MAT <mat id>' and 'VERT <n>' (other lines in between are skipped),
        <n> vertex lines 'x y z', one line with the triangle count, the triangle lines (three 0-based vertex
        indices) and finally an 'ENDS' line.

        Args:
            f: open text file positioned just after `line`.
            line (str): the 'ID ' line of the solid.

        Raises:
            ValueError: if the solid is truncated or a value cannot be parsed.
            IndexError: if a vertex/triangle line has too few values or a triangle references a missing vertex.
        """
        solid_id = int(line[3:])

        line = self._advance_to_prefix(f, line, 'MAT ')
        if line is None:
            raise ValueError(f'Solid {solid_id}: missing MAT line.')
        mat_id = int(line[4:])

        line = self._advance_to_prefix(f, line, 'VERT ')
        if line is None:
            raise ValueError(f'Solid {solid_id}: missing VERT line.')
        num_verts = int(line[5:])

        # get the point locations (whitespace separated; tabs and spaces both accepted)
        vert_locs = []
        for _ in range(num_verts):
            vals = f.readline().split()
            vert_locs.append([float(vals[0]), float(vals[1]), float(vals[2])])

        num_tris = int(f.readline()[4:])
        tris = []
        solid_rtree = index.Index()
        for tri_idx in range(num_tris):
            vals = f.readline().split()
            tri = [int(vals[0]), int(vals[1]), int(vals[2])]
            # reject negative indices too, which would otherwise silently wrap around
            if not all(0 <= v < len(vert_locs) for v in tri):
                raise IndexError(f'Solid {solid_id}: triangle {tri_idx} references a missing vertex.')
            tris.append(tri)
            p0, p1, p2 = (vert_locs[v] for v in tri)
            # triangles that are collinear in plan view (vertical walls) cannot be interpolated; don't index them
            if not self._tri_is_collinear(p0, p1, p2):
                x_list = [p0[0], p1[0], p2[0]]
                y_list = [p0[1], p1[1], p2[1]]
                solid_rtree.insert(tri_idx, (min(x_list), min(y_list), max(x_list), max(y_list)))

        if self._advance_to_prefix(f, f.readline(), 'ENDS') is None:
            raise ValueError(f'Solid {solid_id}: missing ENDS line.')

        self._mat_dict[solid_id] = {
            'id': solid_id,
            'mat_id': mat_id,
            'vert_locs': vert_locs,
            'tris': tris,
            'rtree': solid_rtree,
            'min_max': dict(),
        }

    def _tri_is_collinear(self, p0, p1, p2):
        """Return True if the points are collinear in plan view (xy only).

        Args:
            p0 (list): location of point 1
            p1 (list): location of point 2
            p2 (list): location of point 3

        Return:
            (bool) True if collinear
        """
        x1, y1 = p1[0] - p0[0], p1[1] - p0[1]
        x2, y2 = p2[0] - p0[0], p2[1] - p0[1]
        return abs(x1 * y2 - x2 * y1) < 1e-12

    def _get_elev_at_pt(self, v1, v2, v3, p, tol=1e-9):
        """Using barycentric weights, calculate the elevation at the point in the triangle.

        Args:
            v1 (list): location of point 1
            v2 (list): location of point 2
            v3 (list): location of point 3
            p (list): point in question (only x and y are used)
            tol (float): weights down to -tol still count as inside, so points lying exactly on an edge are not
                lost to floating point round-off

        Return:
            (float): elevation of xy point in the triangle (None if point is not in triangle)
        """
        # https://codeplea.com/triangular-interpolation
        denominator = (v2[1] - v3[1]) * (v1[0] - v3[0]) + (v3[0] - v2[0]) * (v1[1] - v3[1])

        w1 = ((v2[1] - v3[1]) * (p[0] - v3[0]) + (v3[0] - v2[0]) * (p[1] - v3[1])) / denominator
        w2 = ((v3[1] - v1[1]) * (p[0] - v3[0]) + (v1[0] - v3[0]) * (p[1] - v3[1])) / denominator
        w3 = 1 - w1 - w2
        if min(w1, w2, w3) < -tol:
            return None

        return w1 * v1[2] + w2 * v2[2] + w3 * v3[2]

    def _get_solid_min_max(self, xy, solid_id):
        """Return the vertical extent of a solid at an xy location, caching the result in self._xy_to_z.

        Args:
            xy (tuple): (x, y) location; also used as the cache key.
            solid_id (int): id of the solid in self._mat_dict.

        Returns:
            (tuple | None): (min elevation, max elevation) over the solid's triangles containing xy,
            or None if no triangle contains xy.
        """
        cache = self._xy_to_z.setdefault(xy, {})
        if solid_id in cache:  # we have already calculated this
            return cache[solid_id]

        solid = self._mat_dict[solid_id]
        vert_locs = solid['vert_locs']
        tris = solid['tris']
        elevs = []
        for tri_idx in solid['rtree'].intersection((xy[0], xy[1], xy[0], xy[1])):
            a, b, c = tris[tri_idx]
            elev = self._get_elev_at_pt(vert_locs[a], vert_locs[b], vert_locs[c], xy)
            if elev is not None:
                elevs.append(elev)
        min_max = (min(elevs), max(elevs)) if elevs else None
        cache[solid_id] = min_max
        return min_max

    def _map_mat_to_centroid(self):
        """Map the materials to the cell centroids.

        A cell gets a material when its centroid lies within solids of exactly one material; otherwise
        (no solid, or solids of different materials) it gets the default value.
        """
        self.logger.info('Mapping materials to cell centroids...')
        for cell_id in range(self._ugrid.cell_count):
            # first returned value is ignored (presumably a success flag - TODO confirm)
            _, centroid = self._ugrid.get_cell_centroid(cell_id)
            xy = (centroid[0], centroid[1])

            # centroid z is the mean of the corner elevations; dividing by the actual corner count keeps this
            # correct for cells that are not 8-point hexahedra
            cell_locs = self._ugrid.get_cell_locations(cell_id)
            centroid_z = sum(loc[2] for loc in cell_locs) / len(cell_locs) if len(cell_locs) > 0 else 0.0

            mat_ids = set()
            for solid_id, solid in self._mat_dict.items():
                min_max = self._get_solid_min_max(xy, solid_id)
                if min_max is not None and min_max[0] <= centroid_z <= min_max[1]:
                    mat_ids.add(solid['mat_id'])

            # several solids of the same material are not ambiguous; different materials are
            if len(mat_ids) == 1:
                self._dset_vals.append(float(mat_ids.pop()))
            else:
                self._dset_vals.append(self._default)

    def _map_predom_mat(self):
        """Map the predominant materials to the cells.

        For each cell, the vertical overlap between the cell (min to max corner elevation) and every solid at the
        cell centroid's xy location is summed per material; the material with the longest overlap wins.
        """
        self.logger.info('Mapping predominant materials to cells...')
        for cell_id in range(self._ugrid.cell_count):
            _, centroid = self._ugrid.get_cell_centroid(cell_id)
            xy = (centroid[0], centroid[1])

            # the top and bottom of the cell come from its corner elevations
            corner_elevs = [loc[2] for loc in self._ugrid.get_cell_locations(cell_id)]
            if not corner_elevs:
                self._dset_vals.append(self._default)
                continue
            top = max(corner_elevs)
            bot = min(corner_elevs)

            # total overlap length per material
            length_of_mat = {}
            for solid_id, solid in self._mat_dict.items():
                min_max = self._get_solid_min_max(xy, solid_id)
                if min_max is None or min_max[1] <= bot or min_max[0] >= top:
                    continue  # solid absent here or not overlapping the cell vertically
                length = min(top, min_max[1]) - max(bot, min_max[0])
                mat_id = solid['mat_id']
                length_of_mat[mat_id] = length_of_mat.get(mat_id, 0.0) + length

            if length_of_mat:
                # we take the material with the longest length; ties go to the first one found
                self._dset_vals.append(float(max(length_of_mat, key=length_of_mat.get)))
            else:
                self._dset_vals.append(self._default)

    def _create_dataset(self, dataset_name):
        """Create the cell dataset from the values we mapped and set it as the tool output.

        Args:
            dataset_name (str): name of the dataset
        """
        self.logger.info('Creating dataset...')
        builder = self.get_output_dataset_writer(name=dataset_name, geom_uuid=self._cogrid.uuid, location='cells')
        builder.write_xmdf_dataset([0.0], [self._dset_vals])
        self.set_output_dataset(builder)

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        # reset state so a second run neither appends to the previous values nor reuses stale cached elevations
        self._mat_dict = dict()
        self._xy_to_z = dict()
        self._dset_vals = []

        self._read_solid_file(arguments[ARG_INPUT_SOLIDS].text_value)
        self._default = float(arguments[ARG_INPUT_DEFAULT].value)
        if arguments[ARG_INPUT_ALGORITHM].value == 'Centroid':
            self._map_mat_to_centroid()
        else:  # arguments[ARG_INPUT_ALGORITHM].value == 'Predominant material':
            self._map_predom_mat()

        # the dataset name argument is optional; fall back to its default when left empty
        dataset_name = arguments[ARG_OUTPUT_DATASET_NAME].text_value or 'materials'
        self._create_dataset(dataset_name)
