"""Module for the MapMaterialsRunner."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['MapMaterialsRunner']

# 1. Standard Python modules
import itertools
from typing import Callable, Sequence

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.components.component_builders.main_file_maker import make_main_file
from xms.constraint import UGrid2d
from xms.data_objects.parameters import Coverage, Polygon
from xms.gmi.data.generic_model import UNASSIGNED_MATERIAL_ID
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import ExpectedError, FeedbackThread
from xms.guipy.dialogs.log_timer import Timer
from xms.snap import SnapPolygon

# 4. Local modules
from xms.hydroas.components.mapped_material_component import MappedMaterialComponent
from xms.hydroas.data.material_data import MaterialData
from xms.hydroas.dmi.xms_data import XmsData


class MapMaterialsRunner(FeedbackThread):
    """Feedback thread for mapping materials."""
    def __init__(self, data: XmsData):
        """
        Initialize the class.

        Args:
            data: Interprocess communication object. Supplies the mesh, the material coverage, and the grid hash, and
                is used to add and link the resulting mapped material component.
        """
        super().__init__(create_query=False)
        self._data = data
        self.display_text |= {
            'title': 'Mapping material coverage',
            'working_prompt': 'Mapping material coverage. Please wait...',
        }

    def _run(self):
        """
        Run the feedback thread.

        Calls `unlink_materials()` first, then snaps the material coverage's polygons onto the mesh, builds a mapped
        material component from the result, adds it, and links it.

        Raises:
            ExpectedError: If there is no mesh to map onto.
        """
        self._data.unlink_materials()

        self._log.info('Retrieving mesh...')
        ugrid = self._data.ugrid
        if not ugrid:
            raise ExpectedError('Cannot map materials with no mesh. Link a mesh first, then apply materials again.')

        self._log.info('Retrieving material coverage...')
        coverage, data = self._data.material_coverage

        self._log.info('Identifying materials in coverage...')
        mapping = _identify_materials(data)

        self._log.info('Retrieving material data...')
        _retrieve_material_data(coverage, data, mapping, self._log.info)

        self._log.info('Classifying material polygons...')
        # Only the per-material polygon lists are needed for snapping; the parallel names/values lists are unused here.
        polygons = _classify_material_polygons(data, mapping)[0]

        self._log.info('Snapping material polygons...')
        cell_id_to_material_id = _snap_material_polygons(ugrid, polygons, self._log.info)

        self._log.info('Building component...')
        mapped_component = _build_component(data, cell_id_to_material_id, self._data.grid_hash)

        self._data.add_mapped_materials(mapped_component)
        self._data.link(mapped_component.uuid)


def _identify_materials(data: MaterialData) -> dict[str, list]:
    """
    Build an empty polygon bucket for every assignable material type.

    The material values stored in `data` are restored into the generic model first. Then every material group except
    the unassigned material gets its own key.

    Args:
        data: Data manager for the unmapped material coverage.

    Returns:
        Mapping of material_type -> empty list. material_type is the material's group name in the GenericModel.
        The lists are meant to be filled with polygons later.
    """
    model = data.generic_model
    model.material_parameters.restore_values(data.material_values)
    return {
        group_name: []
        for group_name in model.material_parameters.group_names
        if group_name != UNASSIGNED_MATERIAL_ID
    }


def _retrieve_material_data(
    coverage: Coverage, data: MaterialData, mapping: dict[str, list], log: Callable[[str], None]
):
    """
    Sort the coverage's polygons into `mapping` by their assigned material type.

    Args:
        coverage: Coverage whose polygons are sorted.
        data: Data manager for the unmapped materials. It supplies each polygon's material type.
        mapping: One key per material type, each holding a list. Every polygon whose material type is a key is
            appended to that key's list. Polygons whose type is not a key are skipped.
        log: Callback that receives periodic progress messages.
    """
    progress_timer = Timer()
    for count, coverage_polygon in enumerate(coverage.polygons, start=1):
        if progress_timer.report_due:
            log(f'Retrieved {count} polygons...')
        assigned_type = data.feature_type(TargetType.polygon, feature_id=coverage_polygon.id)
        # `mapping` has no key for the unassigned material or for deleted materials, so skipping unknown types drops
        # polygons assigned to either of those.
        if assigned_type not in mapping:
            continue
        mapping[assigned_type].append(coverage_polygon)


def _classify_material_polygons(data: MaterialData, mapping: dict[str, list]) -> tuple[list, list, list]:
    """
    Build parallel lists of: list of polygons, material name, material values.

    The material values stored in `data` are restored into the generic model before the groups are read.

    Args:
        data: Data manager for the unmapped material coverage.
        mapping: Mapping from material type (group name in GenericModel) to list of polygons for that material.

    Returns:
        Tuple of three parallel lists `(polygons, names, values)`, ordered like `mapping`'s iteration order:
        - polygons: for each material, the list of polygons defining it.
        - names: for each material, its label (the name displayed to the user).
        - values: for each material, its GenericModel Group.
    """
    model = data.generic_model
    model.material_parameters.restore_values(data.material_values)

    material_polygons = []
    material_names = []
    material_values = []
    for material_type, polygons in mapping.items():
        # Look the group up once; both the label and the values come from it.
        group = model.material_parameters.group(material_type)
        material_polygons.append(polygons)
        material_names.append(group.label)
        material_values.append(group)

    return material_polygons, material_names, material_values


def _snap_material_polygons(ugrid: UGrid2d, polygons: list[list], report_progress: Callable) -> Sequence[int]:
    """
    Snap material polygons to a UGrid and assign each cell a material ID.

    The polygons' `id` attributes are temporarily renumbered while snapping. They are always restored before this
    function returns, even if snapping raises.

    Args:
        ugrid: The UGrid to snap to.
        polygons: The polygons to snap. Each element is a list of polygons defining the area covered by one material.
        report_progress: Callable to report progress with.

    Returns:
        Integer array parallel to the cells of `ugrid.ugrid`. An entry is 0 if no polygon covers the cell. Otherwise
        it is the 1-based index into `polygons` of the material covering the cell. Where polygons of different
        materials share a cell, the material appearing later in `polygons` wins.
    """
    cell_id_to_material_id = np.zeros(shape=(ugrid.ugrid.cell_count, ), dtype=int)

    if not polygons:
        return cell_id_to_material_id

    report_progress('Flattening polygons...')
    flattened_polygons = list(itertools.chain.from_iterable(polygons))
    if not flattened_polygons:
        # Every material's polygon list is empty, so no cell can be assigned. Skip building a snapper with nothing in it.
        return cell_id_to_material_id

    # Renumber the polygons 1..n for the snapper (presumably so IDs are unique and sequential). Remember the originals
    # so they can be put back afterward.
    original_ids = [polygon.id for polygon in flattened_polygons]
    for new_id, polygon in enumerate(flattened_polygons, start=1):
        polygon.id = new_id

    try:
        report_progress('Creating snapper...')
        snapper = SnapPolygon()
        report_progress('Adding grid...')
        snapper.set_grid(ugrid, False)
        report_progress('Adding polygons...')
        snapper.add_polygons(flattened_polygons)
        for material_id, polygon_list in enumerate(polygons, start=1):
            for material_polygon in polygon_list:
                # numpy treats a tuple index as one index per dimension, so convert to a list for fancy indexing.
                cells = list(snapper.get_cells_in_polygon(material_polygon.id))
                cell_id_to_material_id[cells] = material_id
    finally:
        # The polygons belong to the caller, so restore their IDs even when snapping fails.
        for polygon, original_id in zip(flattened_polygons, original_ids):
            polygon.id = original_id

    return cell_id_to_material_id


def _build_component(
    unmapped_data: MaterialData, cell_materials: Sequence[int], grid_hash: str
) -> MappedMaterialComponent:
    """
    Make a new mapped material component and fill it from the unmapped material data.

    The generic model and material values are copied from `unmapped_data`. The per-cell materials and grid hash are
    stored alongside them, and the data is committed before returning.

    Args:
        unmapped_data: The unmapped material component's data manager.
        cell_materials: Material index for each cell.
        grid_hash: Grid hash to record on the mapped data.

    Returns:
        The newly created mapped material component.
    """
    new_component = MappedMaterialComponent(make_main_file(MappedMaterialComponent))
    target = new_component.data
    target.generic_model = unmapped_data.generic_model
    target.material_values = unmapped_data.material_values
    target.cell_materials = cell_materials
    target.grid_hash = grid_hash
    target.commit()
    return new_component


def _get_cell_id_to_material_id_map(material_polygons: list[list[Polygon]], mesh):
    """
    Get a mapping from cell ID to material ID.

    Args:
        material_polygons: Polygons defining the areas covered by materials.
                           Each element is a list of all the polygons covering the area of that material.
        mesh: The mesh to make the mapping for.

    Returns:
        A dict-like object mapping cell IDs in the mesh to material ID to apply to the cell.

    NOTE(review): `MapMaterialsRunner._run` does not call this function. It uses `_snap_material_polygons` instead,
    which produces a per-cell array of 1-based material IDs (0 = unassigned). Confirm whether this function is still
    needed.
    """
