"""Map Material coverage and Sediment Material locations and attributes to the AdH domain."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import shutil
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.components.display.display_options_io import (
    write_display_option_polygon_locations, write_display_options_to_json
)
from xms.constraint.ugrid_builder import UGridBuilder
from xms.data_objects.parameters import Component
from xms.grid.ugrid import UGrid
from xms.guipy.data.category_display_option import CategoryDisplayOption
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.polygon_texture import PolygonOptions
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules
from xms.adh.components.mapped_component import MappedComponent
from xms.adh.data.materials_io import MaterialsIO
from xms.adh.data.sediment_materials_io import SedimentMaterialsIO
from xms.adh.gui.widgets.color_list import ColorList


class MaterialMapper:
    """Class for mapping material coverage to a mesh for AdH."""
    def __init__(self, coverage_mapper, generate_snap):
        """Constructor."""
        self._logger = coverage_mapper._logger
        self._wkt = coverage_mapper._wkt
        self._generate_snap = generate_snap
        self._co_grid = coverage_mapper._mesh
        self._active_mesh = None
        self._mat_cov = coverage_mapper._mat_cov
        self._sed_mat_cov = coverage_mapper._sed_mat_cov
        self._new_comp_unique_name = 'Mapped_Component'
        self.mapped_comp_uuid = None
        self.mapped_comp_display_uuid = None
        self._comp_uuid = ''
        self._mat_comp = coverage_mapper._mat_comp
        self._sed_mat_comp = coverage_mapper._sed_mat_comp
        self._component_file = coverage_mapper._mat_comp.main_file
        self._poly_files = []
        self._mat_combo_cell = None
        self._has_sed = False

    def do_map(self):
        """Map the coverages onto the mesh and, if requested, build the snapped material display component.

        Returns:
            tuple: ``(do_comp, comp, active_mesh)``. ``do_comp`` is the data_objects Component and ``comp`` the
            MappedComponent for the snapped display; both are None when snap generation is off. ``active_mesh``
            is the value of ``self._active_mesh`` after the cells have been mapped.
        """
        self._get_polygon_cells()
        if not self._generate_snap:
            return None, None, self._active_mesh  # pragma: no cover

        # The folder must exist before the polygon files are written, and the display options are built from
        # the polygon files recorded while drawing.
        self._create_component_folder()
        self._create_drawing()
        self._create_display_options()

        main_file = self._comp_main_file
        # The component's UUID is taken from the name of the folder holding its main file.
        folder_uuid = os.path.basename(os.path.dirname(main_file))
        do_comp = Component(
            main_file=main_file,
            name='Snapped Material Display',
            unique_name=self._new_comp_unique_name,
            model_name='AdH',
            comp_uuid=folder_uuid,
        )
        return do_comp, MappedComponent(main_file), self._active_mesh

    def _create_component_folder(self):
        """Creates the folder for the mapped materials component."""
        if self.mapped_comp_uuid is None:
            self._comp_uuid = str(uuid.uuid4())  # pragma: no cover
        else:
            self._comp_uuid = self.mapped_comp_uuid
        self._logger.info('Creating component folder')
        mat_comp_path = os.path.dirname(self._component_file)
        self._comp_path = os.path.join(os.path.dirname(mat_comp_path), self._comp_uuid)

        if os.path.exists(self._comp_path):
            shutil.rmtree(self._comp_path)  # pragma: no cover
        os.mkdir(self._comp_path)

    def _create_display_options(self):
        """Write the display options JSON file for the mapped materials component.

        One polygon category is added per entry of ``self._poly_files``, ordered by material id. If any
        material/sediment pairing was found (``self._has_sed``), each category gets a color and texture from
        ColorList based on its position; otherwise it uses the display options stored for its material in the
        material component. The path of the written file is stored in ``self._comp_main_file``.
        """
        json_path = os.path.join(self._comp_path, 'mapped_display_options.json')

        display_list = CategoryDisplayOptionList()  # Generates a random UUID key for the display list
        if self.mapped_comp_display_uuid is not None:
            # Reuse the caller-provided UUID instead of the generated one
            display_list.uuid = self.mapped_comp_display_uuid
        display_list.comp_uuid = self._comp_uuid
        display_list.is_ids = False
        display_list.target_type = TargetType.polygon

        # Entries are (polygon file, description, material id); order the categories by material id.
        self._poly_files.sort(key=lambda entry: entry[2])
        for position, (poly_file, description, mat_id) in enumerate(self._poly_files):
            category = CategoryDisplayOption()
            category.file = os.path.basename(poly_file)
            category.description = description
            category.is_unassigned_category = False
            category.label_on = False
            category.options = PolygonOptions()
            if not self._has_sed:
                category.options = self._mat_comp.data.materials.material_display[mat_id]
            else:
                ColorList.get_next_color_and_texture(position, category.options)
            display_list.categories.append(category)
        display_list.projection = {'wkt': self._wkt}

        write_display_options_to_json(json_path, display_list)
        self._comp_main_file = json_path

    def _create_drawing(self):
        """Write one polygon locations file per (material, sediment material) pair assigned to mesh cells.

        Cells are grouped by their id pair from ``self._mat_combo_cell``, and each cell is drawn as a closed
        polygon through its point locations. For every group a ``(file path, display name, material id)`` tuple is
        appended to ``self._poly_files``. ``self._has_sed`` ends up True if any group was named with a sediment
        material.
        """
        ugrid = self._co_grid.ugrid
        unassigned_key = (MaterialsIO.UNASSIGNED_MAT, SedimentMaterialsIO.UNASSIGNED_MAT)
        cells_by_pair = {}
        for cell_id, pair in enumerate(self._mat_combo_cell):
            # Cells without a material are all lumped into one group, whatever their sediment material is.
            key = unassigned_key if pair[0] == MaterialsIO.UNASSIGNED_MAT else tuple(pair)
            cells_by_pair.setdefault(key, []).append(cell_id)

        self._has_sed = False
        for (mat_id, sed_mat_id), cell_ids in cells_by_pair.items():
            polygons = []
            for cell_id in cell_ids:
                # Flatten the cell's point locations into a single coordinate list
                coords = [value for location in ugrid.get_cell_locations(cell_id) for value in location]
                if len(coords) >= 9:  # Fewer than nine values cannot describe three points
                    # Close the ring by repeating the first point (its first three values)
                    polygons.append({'outer': coords + coords[:3]})

            filename = os.path.join(self._comp_path, f'mat_{mat_id}_sed_{sed_mat_id}.materials')
            mat_name = self._mat_comp.data.materials.material_properties[mat_id].material_name
            if self._sed_mat_comp and mat_id != MaterialsIO.UNASSIGNED_MAT:
                sed_mat_name = self._sed_mat_comp.data.materials[sed_mat_id].name
                name = f'Material ({mat_name}, {sed_mat_name})'
                self._has_sed = True
            else:
                name = f'{mat_name}'
            self._poly_files.append((filename, name, mat_id))
            write_display_option_polygon_locations(filename, polygons)

    def _get_polygon_cells(self):
        """Assign a material id and a sediment material id to every mesh cell using xmssnap."""
        self._logger.info('Mapping materials to mesh.')
        # One [material id, sediment material id] entry per cell, both starting out unassigned. Each cell gets
        # its own list because the entries are edited in place by the two mapping steps below.
        cell_total = self._co_grid.ugrid.cell_count
        combos = []
        for _ in range(cell_total):
            combos.append([MaterialsIO.UNASSIGNED_MAT, SedimentMaterialsIO.UNASSIGNED_MAT])
        self._mat_combo_cell = combos
        self._get_material_cells()
        self._get_sediment_material_cells()

    def _get_material_cells(self):
        """Gets the cells for each hydrodynamic material. Also builds the active cell grid.

        Snaps the material coverage polygons to the mesh. For every polygon with an assigned material, the
        material id is stored in slot 0 of each snapped cell's ``self._mat_combo_cell`` entry and the cell is
        flagged active. The result of ``_build_active_grid()`` is stored in ``self._active_mesh``.
        """
        cell_count = self._co_grid.ugrid.cell_count
        # Guess everyone is inactive. The list is always assigned, even for a mesh without cells, so that
        # _build_active_grid() never reads a missing attribute.
        self._cell_activity = [False] * cell_count
        snap = SnapPolygon()
        snap.set_grid(self._co_grid, True)
        polygons = self._mat_cov.polygons
        snap.add_polygons(polygons)
        for polygon in polygons:
            polygon_id = polygon.id
            mat_id = self._mat_comp.get_comp_id(TargetType.polygon, polygon_id)
            # If a polygon has no component id, or is assigned the 'OFF' material, its snapped cells stay
            # inactive and unassigned, so there is no need to look them up.
            if mat_id is None or mat_id == MaterialsIO.UNASSIGNED_MAT:
                continue
            for cell in snap.get_cells_in_polygon(polygon_id):
                self._mat_combo_cell[cell][0] = mat_id
                self._cell_activity[cell] = True
        self._active_mesh = self._build_active_grid()

    def _get_sediment_material_cells(self):
        """Record the sediment material id of every cell covered by a sediment material polygon.

        Does nothing when there is no sediment material coverage. Polygons without a component id are skipped,
        leaving their cells' sediment slot untouched. The id is written to slot 1 of the cell's
        ``self._mat_combo_cell`` entry.
        """
        sed_cov = self._sed_mat_cov
        if sed_cov is None:
            return
        snapper = SnapPolygon()
        snapper.set_grid(self._co_grid, True)
        sed_polygons = sed_cov.polygons
        snapper.add_polygons(sed_polygons)
        for sed_polygon in sed_polygons:
            poly_id = sed_polygon.id
            sed_mat_id = self._sed_mat_comp.get_comp_id(TargetType.polygon, poly_id)
            if sed_mat_id is None:
                continue
            for cell in snapper.get_cells_in_polygon(poly_id):
                self._mat_combo_cell[cell][1] = sed_mat_id

    def _build_active_grid(self):
        """Creates a copy of the grid with only the active cells.

        Returns:
            A CoGrid with only active cells in it.
        """
        # If there is no activity, then return the whole mesh unchanged.
        if not self._cell_activity or False not in self._cell_activity:
            return self._co_grid
        cell_stream = self._co_grid.ugrid.cellstream
        # All elements should have a cell type, a cell point count, and 4-8 point ids
        split_cell_stream = []
        idx = 0
        while idx < len(cell_stream):
            # Add one for the cell point count and then the cell points
            # Add one more to get past the last point
            new_idx = idx + cell_stream[idx + 1] + 2
            split_cell_stream.append(cell_stream[idx:new_idx])
            idx = new_idx
        if len(split_cell_stream) != len(self._cell_activity):
            raise Exception('Unable to create a grid of active cells.')
        active_cell_stream = []
        # Only add back in cells that are active.
        for idx, cell in enumerate(split_cell_stream):
            if self._cell_activity[idx]:
                active_cell_stream.extend(cell)
        # Build the new grid.
        xmugrid = UGrid(self._co_grid.ugrid.locations, active_cell_stream)
        co_builder = UGridBuilder()
        co_builder.set_is_2d()
        co_builder.set_ugrid(xmugrid)
        active_grid = co_builder.build_grid()
        return active_grid
