"""Module for the ModelNativeExportThread."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import RectilinearGrid2d
from xms.data_objects.parameters import Arc, Coverage
from xms.gmi.data.generic_model import Section
from xms.gmi.data_bases.coverage_base_data import CoverageBaseData
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import ExpectedError, FeedbackThread

# 4. Local modules
from xms.aeolis.data.model import get_model
from xms.aeolis.dmi.xms_data import XmsData
from xms.aeolis.file_io.simulation_exporter import export_simulation


class ModelNativeExportThread(FeedbackThread):
    """Thread for exporting a simulation to model native files."""
    def __init__(self, data: XmsData):
        """
        Construct the worker.

        Args:
            data: Interprocess communication object.
        """
        # NOTE(review): the leading None is presumably "no parent" for FeedbackThread — confirm.
        super().__init__(None, is_export=True, create_query=False)
        self._data = data

    def _run(self):
        """
        Run the thread.

        Collects the model/process controls, the linked grid, the linked coverage and its
        per-arc parameters, then hands everything to `export_simulation`, using the current
        working directory as the export location.

        Raises:
            ExpectedError: If no grid or coverage is linked, or the linked grid is not a
                RectilinearGrid2d.
        """
        model_control = self._data.model_control
        process_control = self._data.process_control
        ugrid = self._get_ugrid()
        coverage, data = self._get_coverage_and_data()
        arcs = extract_arc_data(coverage, data)
        # The process's current working directory is where the export is directed.
        where = Path(os.getcwd())
        export_simulation(model_control, process_control, ugrid, arcs, self._report_progress, where)

    def _get_ugrid(self) -> RectilinearGrid2d:
        """
        Get the grid linked to the simulation.

        Returns:
            The linked grid.

        Raises:
            ExpectedError: If no grid is linked, or the linked grid is not a RectilinearGrid2d.
        """
        ugrid = self._data.linked_grid
        if not ugrid:
            raise ExpectedError('No grid was linked to the simulation.')

        # Explicit check instead of `assert`: asserts are stripped under `python -O`, and an
        # ExpectedError reports a clean message to the user rather than an AssertionError.
        if not isinstance(ugrid, RectilinearGrid2d):
            raise ExpectedError('The grid linked to the simulation must be a rectilinear grid.')
        return ugrid

    def _get_coverage_and_data(self) -> tuple[Coverage, CoverageBaseData]:
        """
        Get the linked coverage and its data manager.

        Returns:
            A (coverage, data manager) pair, as provided by `XmsData.linked_coverage`.

        Raises:
            ExpectedError: If no coverage is linked.
        """
        # Assumes linked_coverage is always a 2-tuple, even when nothing is linked — confirm.
        coverage, data = self._data.linked_coverage
        if not coverage:
            raise ExpectedError('No coverage was linked to the simulation.')
        return coverage, data

    def _report_progress(self, message: str):
        """
        Report export progress by logging it at INFO level.

        Args:
            message: The message to report to the user.
        """
        self._log.info(message)


def extract_arc_data(coverage: Coverage, data: CoverageBaseData) -> list[tuple[Arc, Section]]:
    """
    Pair every arc in a coverage with its own populated copy of the model's arc parameters.

    Args:
        coverage: Coverage whose arcs are extracted, in iteration order.
        data: Data manager supplying the stored parameter values for each arc.

    Returns:
        One (arc, section) tuple per arc, where each section is an independent copy of the
        model's arc-parameter template with that arc's values restored into it.
    """
    template = get_model().arc_parameters

    def _section_for(arc: Arc) -> Section:
        # Copy first so restoring values never mutates the shared template.
        arc_section = template.copy()
        arc_section.restore_values(data.feature_values(TargetType.arc, feature_id=arc.id))
        return arc_section

    return [(arc, _section_for(arc)) for arc in coverage.arcs]
