"""Map Material and Sediment Material coverage locations and attributes to the SRH-2D domain."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import shutil
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.components.display.display_options_io import read_display_options_from_json, write_display_options_to_json
from xms.components.display.display_options_io import write_display_option_polygon_locations
from xms.data_objects.parameters import Component
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules
from xms.srh.components.mapped_material_component import MappedMaterialComponent
from xms.srh.components.mapped_sed_material_component import MappedSedMaterialComponent


class MaterialMapper:
    """Class for mapping material coverage to a mesh for SRH-2D.

    The same class handles the regular material coverage and the sediment material coverage; the ``sediment``
    constructor flag selects which coverage/component pair is taken from ``coverage_mapper``.
    """
    def __init__(self, coverage_mapper, wkt, sediment, generate_snap):
        """Constructor.

        Args:
            coverage_mapper: Object supplying the logger, the grid (``co_grid``) and the material and sediment
                material coverages, components and component main files.
            wkt (str): Projection WKT of the grid; written as the projection of the mapped display options.
            sediment (bool): True to map the sediment material coverage, False to map the material coverage.
            generate_snap (bool): True to create the snapped display component in ``do_map``.
        """
        self._generate_snap = generate_snap
        self._logger = coverage_mapper._logger
        self._co_grid = coverage_mapper.co_grid
        self._sediment = sediment
        if not self._sediment:
            self._new_comp_unique_name = 'Mapped_Material_Component'
            self._material_component_file = coverage_mapper.material_component_file
            self._material_coverage = coverage_mapper.material_coverage
            self._material_component = coverage_mapper.material_component
        else:
            self._new_comp_unique_name = 'Mapped_Sed_Material_Component'
            self._material_component_file = coverage_mapper.sed_material_component_file
            self._material_coverage = coverage_mapper.sed_material_coverage
            self._material_component = coverage_mapper.sed_material_component
        # Snapper used to find the grid cells that fall inside each coverage polygon.
        self._snap_poly = SnapPolygon()
        self._snap_poly.set_grid(grid=self._co_grid, target_cells=False)
        self._snap_poly.add_polygons(polygons=self._material_coverage.polygons)
        self._comp_main_file = ''
        # Material component id -> list of grid cell indices assigned to that material.
        self._poly_to_cells = {}
        self._comp_path = ''
        self._mannings_n = None
        self.sediment_lines = None
        # Material table, read once and shared by the mapping and output-building steps.
        self._mat_df = self._material_component.data.materials.to_dataframe()
        self._mat_comp_ids = self._mat_df['id'].to_list()
        self._mat_names = self._mat_df['Name'].to_list()
        self.mapped_comp_uuid = None
        self.mapped_material_display_uuid = None
        self.grid_wkt = wkt

    def do_map(self):
        """Creates the mapped material component.

        Assigns grid cells to materials. Then it fills either ``self._mannings_n`` (material coverage) or
        ``self.sediment_lines`` (sediment material coverage).

        Returns:
            tuple: ``(data_objects Component, mapped component)`` when snapping is enabled,
            otherwise ``(None, None)``.
        """
        self._get_polygon_cells()

        # Reuse the table read in the constructor so ids/names match those used by _get_polygon_cells.
        df = self._mat_df
        if self._sediment:
            self.sediment_lines = self._build_sediment_lines(df)
        else:
            self._mannings_n = self._build_mannings_n(df)

        if not self._generate_snap:
            return None, None  # pragma: no cover
        return self._create_snapped_component()

    def _build_mannings_n(self, df):
        """Returns one Manning's n entry per material row of ``df``.

        Each entry is the row's constant ``Manning's N`` value. If the row's ``Depth Varied Curve`` flag is
        non-zero and the material's depth curve has more than one row, the entry is that curve as a list of row
        tuples instead. A curve with one row or fewer is logged as invalid, and the constant is kept.
        """
        mannings_n = df["Manning's N"].tolist()
        rows = zip(df['Depth Varied Curve'].tolist(), df['id'].tolist(), df['Name'].tolist())
        for i, (use_curve, mat_id, name) in enumerate(rows):
            if use_curve == 0:
                continue
            curve_df = self._material_component.data.depth_curve_from_mat_id(mat_id).to_dataframe()
            crv_vals = [tuple(r) for r in curve_df.to_numpy()]
            if len(crv_vals) > 1:
                mannings_n[i] = crv_vals
            else:
                msg = f'Invalid "Depth Varied Curve" defined for "{name}" material. ' \
                      f"The constant Manning's N value will be used instead: {mannings_n[i]}."
                self._logger.warning(msg)
        return mannings_n

    def _build_sediment_lines(self, df):
        """Returns the sediment material lines for the materials in ``df``.

        The list holds, in order:
            * one ``NumSubSurfaceLayers`` string per material;
            * one ``SubsurfaceThickness`` string per layer;
            * one ``(material index, layer index, gradation curve row tuples, material name)`` tuple for each
              layer whose gradation curve is not empty.

        Material indices are 0-based positions in ``df``; layer indices are 1-based.
        """
        data = self._material_component.data
        count_lines = []
        thick_lines = []
        bed_lines = []
        for mat_idx, (mat_id, mat_name) in enumerate(zip(df['id'].tolist(), df['Name'].tolist())):
            props = data.sediment_properties_from_mat_id(mat_id).to_dataframe()
            count_lines.append(f'NumSubSurfaceLayers {mat_idx} {len(props)}')
            # Layer density is intentionally not written: it is unused as of SRH 3.3.0.
            layers = zip(props['Thickness'].tolist(), props['Units'].tolist())
            for layer, (thickness, units) in enumerate(layers, start=1):
                unit_str = 'EN' if units == 1 else 'SI'
                thick_lines.append(f'SubsurfaceThickness {mat_idx} {layer} {thickness} {unit_str} 1')
                curve = data.gradation_curve_from_mat_id_prop_id(mat_id, layer).to_dataframe()
                if len(curve) > 0:
                    bed_lines.append((mat_idx, layer, [tuple(r) for r in curve.to_numpy()], mat_name))
        return count_lines + thick_lines + bed_lines

    def _create_snapped_component(self):
        """Writes the mapped display files and returns the ``(data_objects Component, mapped component)`` pair."""
        self._create_component_folder_and_copy_display_options()
        self._create_drawing()

        # Create the data_objects component. The component folder is named with its UUID.
        comp_name = f'Snapped {self._material_coverage.name} display'
        comp_uuid = os.path.basename(os.path.dirname(self._comp_main_file))
        do_comp = Component(
            main_file=self._comp_main_file,
            name=comp_name,
            model_name='SRH-2D',
            unique_name=self._new_comp_unique_name,
            comp_uuid=comp_uuid
        )

        comp_class = MappedSedMaterialComponent if self._sediment else MappedMaterialComponent
        comp = comp_class(self._comp_main_file)
        return do_comp, comp

    def _create_drawing(self):
        """Uses cell ids to get cell point coords to draw polygons for materials mapped to cells.

        Writes one ``mat_<comp_id>.matid`` polygon-locations file per material component id into the component
        folder.
        """
        ugrid = self._co_grid.ugrid
        for comp_id, cell_ids in self._poly_to_cells.items():
            poly_list = []
            for cid in cell_ids:
                # Flatten the cell's point locations into one list (three values per point).
                locs_list = [coord for loc in ugrid.get_cell_locations(cid) for coord in loc]
                # Fewer than three points cannot form a polygon.
                if len(locs_list) < 9:
                    continue  # pragma: no cover
                # Close the ring by repeating the first point.
                locs_list.extend(locs_list[:3])
                poly_list.append({'outer': locs_list})
            filename = os.path.join(self._comp_path, f'mat_{comp_id}.matid')
            write_display_option_polygon_locations(filename, poly_list)

    def _get_polygon_cells(self):
        """Uses xmssnap to get the cells for each polygon.

        Fills ``self._poly_to_cells`` (material component id -> list of cell indices). Cells not inside any
        polygon, and polygons without a valid component id, go to component id 0 (the unassigned material).
        Logs a warning for each material after the first row of the material table that has no cells.
        """
        self._logger.info('Mapping material coverage to mesh.')
        unassigned = [True] * self._co_grid.ugrid.cell_count
        for poly in self._material_coverage.polygons:
            cells = self._snap_poly.get_cells_in_polygon(poly.id)
            comp_id = self._material_component.get_comp_id(TargetType.polygon, poly.id)
            if comp_id is None or comp_id < 0:
                comp_id = 0  # pragma: no cover
            self._poly_to_cells.setdefault(comp_id, []).extend(cells)
            for cid in cells:
                unassigned[cid] = False
        # add all unassigned cells to comp_id = 0 (unassigned_material)
        self._poly_to_cells.setdefault(0, []).extend(cid for cid, flag in enumerate(unassigned) if flag)
        # The first material row is skipped (presumably the unassigned material itself).
        for mat_id, mat_name in zip(self._mat_comp_ids[1:], self._mat_names[1:]):
            if not self._poly_to_cells.get(mat_id):
                self._logger.warning(f'Material: {mat_name} was not assigned to any elements.')

    def _create_component_folder_and_copy_display_options(self):
        """Creates a folder for the mapped material component and copies display options from the material coverage.

        The folder is ``<parent of the material component folder>/<comp_uuid>``; any existing folder at that path
        is deleted first. If the material component has a ``material_display_options.json``, the file is copied
        into the new folder and becomes the component main file. The copy is updated with the display UUID, the
        component UUID, ``is_ids = 0`` and the grid projection.
        """
        comp_uuid = str(uuid.uuid4()) if self.mapped_comp_uuid is None else self.mapped_comp_uuid
        self._logger.info('Creating component folder')
        mat_comp_path = os.path.dirname(self._material_component_file)
        self._comp_path = os.path.join(os.path.dirname(mat_comp_path), comp_uuid)

        if os.path.exists(self._comp_path):
            shutil.rmtree(self._comp_path)  # pragma: no cover
        os.mkdir(self._comp_path)

        mat_comp_display_file = os.path.join(mat_comp_path, 'material_display_options.json')
        comp_display_file = os.path.join(self._comp_path, 'material_display_options.json')
        if not os.path.isfile(mat_comp_display_file):
            self._logger.info('Could not find material_display_options.json file')  # pragma: no cover
            return  # pragma: no cover

        shutil.copyfile(mat_comp_display_file, comp_display_file)
        categories = CategoryDisplayOptionList()  # Generates a random UUID key for the display list
        json_dict = read_display_options_from_json(comp_display_file)
        if self.mapped_material_display_uuid is None:
            json_dict['uuid'] = str(uuid.uuid4())  # pragma: no cover
        else:
            json_dict['uuid'] = self.mapped_material_display_uuid
        json_dict['comp_uuid'] = comp_uuid
        json_dict['is_ids'] = 0
        # Set projection of free locations to be that of the mesh/current display
        categories.projection = {'wkt': self.grid_wkt}
        categories.from_dict(json_dict)
        write_display_options_to_json(comp_display_file, categories)
        self._comp_main_file = comp_display_file
