"""This module checks the simulation for data that could cause problems with running the AdH model."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from collections import defaultdict, deque
import os
from pathlib import Path
import traceback

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import ModelCheckError
from xms.data_objects.parameters import Arc
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.adh.data.adh_data import AdhData
from xms.adh.data.model import get_model
from xms.adh.data.sediment_materials_io import SedimentMaterialsIO


class AdHModelChecker:
    """Model Checker for AdH.

    Each ``_check_*`` method inspects one part of the simulation held by the ``AdhData`` object and appends a
    ``ModelCheckError`` to ``self.errors`` for every problem it finds. ``run_checks`` runs all of the checks.
    """
    def __init__(self, adh_data: AdhData):
        """Construct the model checker.

        Args:
            adh_data: The AdH data object.
        """
        self._adh_data = adh_data
        self.errors = []  # ModelCheckError objects accumulated by the checks

    def _add_error(self, problem, description='', fix=''):
        """Append a model check error to the list we will send back to SMS.

        Args:
            problem (str): The problem text
            description (str): The problem description text
            fix (str): The problem fix text
        """
        error = ModelCheckError(problem=problem, description=description, fix=fix)
        self.errors.append(error)

    def _check_model_control(self):
        """Warn when wet/dry stabilization is enabled but its length is 0.

        Does nothing when the simulation has no model control.
        """
        model_control = self._adh_data.model_control
        if model_control is None:
            return

        model_constants = model_control.param_control.model_constants
        if model_constants.enable_wet_dry_stabilization and model_constants.wet_dry_stabilization_length == 0:
            problem = 'Warning! Wet dry stabilization length is 0.'
            description = 'Wet dry stabilization is enabled, but the length is 0.'
            fix = 'Disable wet dry stabilization length or set it to a non-zero length.'
            self._add_error(problem, description, fix)

    def _check_mesh(self):
        """Make sure the simulation has a mesh and that every element of it is a triangle."""
        co_grid = self._adh_data.co_grid
        if not co_grid:
            problem = 'STOP! Simulation requires an unstructured mesh.'
            description = 'An unstructured mesh is required for this simulation.'
            fix = 'Add an unstructured mesh to the simulation.'
            self._add_error(problem, description, fix)
            return

        if not co_grid.check_all_cells_are_of_type(UGrid.cell_type_enum.TRIANGLE):
            problem = 'STOP! Not all elements are triangles.'
            description = 'An unstructured 2D mesh where all elements are triangles is required for this simulation.'
            fix = 'Split the non-triangular elements into triangles.'
            self._add_error(problem, description, fix)

    def _check_materials(self):
        """Make sure that we have valid hydrodynamic materials.

        Reports each of these:
        - a missing material coverage;
        - a coverage with no active (non-OFF) material;
        - every material whose Manning's N friction value is 0.0. This is a STOP error when sediment transport is
          used and a warning otherwise.
        """
        materials_io = self._adh_data.materials_io
        if materials_io is None:
            problem = 'STOP! Simulation requires an AdH material coverage.'
            description = 'An AdH material coverage is required for this simulation.'
            fix = 'Add an AdH material coverage to the simulation.'
            self._add_error(problem, description, fix)
            return

        materials = materials_io.materials
        # A list with at most one entry holds nothing besides the OFF material.
        # NOTE(review): assumes the OFF material is always listed -- confirm.
        if len(materials.material_properties) <= 1:
            problem = 'STOP! Simulation requires an active material.'
            description = 'A material that is not OFF is required for this simulation.'
            fix = 'Add a material in the material list of the material coverage.'
            self._add_error(problem, description, fix)
            return

        # Rows of the friction table that are Manning's N cards with a value of exactly 0.0.
        fric = materials.friction
        mannings = fric.loc[fric['CARD_2'] == 'MNG']
        mannings = mannings.loc[mannings['REAL_01'] == 0.0]

        # The severity depends only on whether sediment transport is used, so decide it once.
        if self._adh_data.sediment_materials_io is not None:
            problem = "STOP! Manning's N cannot be 0.0 when sediment transport is used."
        else:
            problem = "Warning! Manning's N = 0.0 is only valid for analytical purposes."
        for row_data in mannings.itertuples():
            mat_name = materials.material_properties[row_data.STRING_ID].material_name
            description = f"Material '{mat_name}' has a Manning's N value of 0.0."
            fix = f"Change the Manning's N value for material {mat_name} to a value greater than 0.0."
            self._add_error(problem, description, fix)

    def _check_constituents(self):
        """Make sure that we are using transport constituents correctly.

        Reports conflicting transport constituent components between the material and boundary condition
        coverages, and a transport constituents component that defines no constituents.
        """
        transport_io = self._adh_data.transport_constituents_io
        if transport_io is None:
            return

        materials_io = self._adh_data.materials_io
        bc_io = self._adh_data.bc_io
        # Either coverage may be missing. _check_materials and _check_boundary_coverage report that case, so
        # the uuids are compared only when both coverages exist.
        if materials_io is not None and bc_io is not None:
            mat_uuid = materials_io.info.attrs['transport_uuid']
            bc_uuid = bc_io.info.attrs['transport_uuid']
            # An empty uuid means the coverage does not reference a component, so only two set uuids can conflict.
            if mat_uuid and bc_uuid and mat_uuid != bc_uuid:
                problem = 'STOP! Conflicting transport constituents used.'
                description = 'The materials coverage and the boundary conditions coverage refer to different ' \
                              'transport constituent components for this simulation.'
                fix = 'Change either the material coverage or the boundary conditions coverage to use the same ' \
                      'transport constituent component.'
                self._add_error(problem, description, fix)
                return

        param_control = transport_io.param_control
        if not param_control.salinity and \
                not param_control.temperature and \
                not param_control.vorticity and \
                len(transport_io.user_constituents.index) == 0:
            problem = 'Warning! A transport constituents component is used without any constituents specified.'
            description = 'The transport constituents component does not have any user defined constituents, ' \
                          'nor does it use salinity, temperature, or vorticity.'
            fix = 'Add a constituent in the transport constituents component.'
            self._add_error(problem, description, fix)

    def _check_sediment_constituents(self):
        """Make sure that we are using sediment transport constituents correctly.

        Reports each of these:
        - a missing sediment material coverage;
        - conflicting sediment transport components between the sediment material and boundary condition
          coverages;
        - a sediment transport constituents component with no sand or clay constituents.
        """
        sediment_constituents_io = self._adh_data.sediment_constituents_io
        if sediment_constituents_io is None:
            return

        sediment_materials_io = self._adh_data.sediment_materials_io
        if sediment_materials_io is None:
            problem = 'STOP! Sediment transport constituents used without a sediment material coverage.'
            description = 'The sediment materials coverage is necessary when using a ' \
                          'sediment transport constituent component.'
            fix = 'Add a sediment material coverage to this simulation.'
            self._add_error(problem, description, fix)
            return

        bc_io = self._adh_data.bc_io
        # Without a boundary conditions coverage (reported by _check_boundary_coverage) there is no uuid to compare
        # against. The empty-constituent check below is still performed.
        if bc_io is not None:
            mat_uuid = sediment_materials_io.info.attrs['sediment_transport_uuid']
            bc_uuid = bc_io.info.attrs['sediment_uuid']
            if mat_uuid and bc_uuid and mat_uuid != bc_uuid:
                problem = 'STOP! Conflicting sediment transport constituents used.'
                description = 'The sediment materials coverage and the boundary conditions coverage refer to ' \
                              'different sediment transport constituent components for this simulation.'
                fix = 'Change either the sediment material coverage or the boundary conditions coverage to use the ' \
                      'same sediment transport constituent component.'
                self._add_error(problem, description, fix)
                return

        param_control = sediment_constituents_io.param_control
        if len(param_control.sand.index) == 0 and len(param_control.clay.index) == 0:
            problem = 'Warning! A sediment transport constituents component is used without any constituents specified.'
            description = 'The sediment transport constituents component does not have any constituents.'
            fix = 'Add a constituent in the sediment transport constituents component.'
            self._add_error(problem, description, fix)

    def _check_sediment(self):
        """Make sure that we are using sediment materials and global options correctly.

        Reports each of these:
        - a missing sediment transport constituents component;
        - global sediment properties with no bed layers;
        - a gravity value that does not look like MKS units (checked only when a model control exists).
        """
        sediment_materials_io = self._adh_data.sediment_materials_io
        if sediment_materials_io is None:
            return

        if self._adh_data.sediment_constituents_io is None:
            problem = 'Error! No sediment transport constituents are specified.'
            description = 'Sediment transport requires a sediment transport constituents component.'
            fix = 'Specify the sediment transport constituents component in the Global Sediment Properties dialog ' \
                  'of the sediment materials coverage.'
            self._add_error(problem, description, fix)

        # The global sediment properties (including bed layers) are stored on the unassigned material.
        global_sed = sediment_materials_io.materials[SedimentMaterialsIO.UNASSIGNED_MAT]
        if len(global_sed.bed_layers.index) == 0:
            problem = 'Error! No bed layers are specified.'
            description = 'Sediment transport requires at least 1 bed layer.'
            fix = 'Add a bed layer in the Global Sediment Properties dialog of the sediment materials coverage.'
            self._add_error(problem, description, fix)

        # The unit system is inferred from gravity, which only the model control provides.
        model_control = self._adh_data.model_control
        if model_control is None:
            return
        gravity = model_control.param_control.model_constants.gravity
        # Treat gravity within 0.75 of 9.81 (m/s^2) as MKS units.
        if gravity < 9.81 - 0.75 or gravity > 9.81 + 0.75:
            problem = 'Error! Must use MKS units when using sediment transport.'
            description = 'The units detected from gravity are not correct.'
            fix = 'Change all values to use MKS units.'
            self._add_error(problem, description, fix)

    def _check_boundary_conditions(self):
        """Make sure that we are using boundary conditions correctly."""
        self._check_boundary_coverage()
        self._check_time_series()

    def _check_boundary_coverage(self):
        """Warn when there is no boundary condition coverage."""
        if self._adh_data.bc_io is None:
            problem = 'Warning! Simulation does not have an AdH boundary conditions coverage.'
            description = 'An AdH boundary conditions coverage is not present for this simulation.'
            fix = 'Add an AdH boundary conditions coverage to the simulation.'
            self._add_error(problem, description, fix)

    def _check_time_series(self):
        """Warn once for every boundary condition time series that repeats a time value."""
        bc_io = self._adh_data.bc_io
        if bc_io is None:
            return

        for ts in bc_io.bc.time_series.values():
            # Column 'X' holds the time values of the series.
            if ts.time_series['X'].duplicated().any():
                problem = 'Warning! Duplicate time series values detected.'
                description = 'The time series contains duplicate time values.'
                fix = 'Remove the duplicate time values.'
                self._add_error(problem, description, fix)

    def _check_vessel(self, coverage_arcs: list[Arc], first_segment: Arc):
        """Ensure the vessel arcs form one unbranched path with 2 ends, and that `first_segment` is one of the ends.

        Nodes are matched by their exact (x, y) coordinates. Checking stops at the first failure, so at most one
        error is added.

        Args:
            coverage_arcs: The arcs in the vessel coverage other than the first segment.
            first_segment: The user-defined segment at the start of the path.
        """
        all_arcs = [first_segment] + coverage_arcs

        # Map each node coordinate to the arcs touching it; the length of each list is the node's degree.
        node_to_arcs = defaultdict(list)
        for arc in all_arcs:
            for node in (arc.start_node, arc.end_node):
                node_to_arcs[(node.x, node.y)].append(arc)

        # Degree 1 marks a path end, degree 2 an interior node, and anything higher a branch.
        endpoint_nodes = []
        invalid_nodes = []
        for coord, arcs_at_node in node_to_arcs.items():
            deg = len(arcs_at_node)
            if deg == 1:
                endpoint_nodes.append(coord)
            elif deg != 2:
                invalid_nodes.append((coord, deg))

        if invalid_nodes:
            problem = "Error! Vessel path is branched."
            description = (
                "Each interior node must connect exactly two arcs, and only the two endpoints may connect one arc."
            )
            fix = "Inspect nodes and remove branches in connected segments."
            self._add_error(problem, description, fix)
            return

        # An unbranched path has exactly two ends; zero ends means the arcs form a loop.
        if len(endpoint_nodes) != 2:
            problem = "Error! Vessel path must have exactly two endpoints."
            description = (
                f"Found {len(endpoint_nodes)} endpoints, but a valid path must have exactly two."
            )
            fix = "Ensure the vessel arcs do not form a loop or contain branches."
            self._add_error(problem, description, fix)
            return

        # The first segment must touch one of the two path ends.
        fs_ends = {
            (first_segment.start_node.x, first_segment.start_node.y),
            (first_segment.end_node.x, first_segment.end_node.y),
        }
        if not (fs_ends & set(endpoint_nodes)):
            problem = "Error! First segment is not at a path endpoint."
            description = (
                "The designated first segment must be one of the two end arcs."
            )
            fix = "Choose a first segment at the start or end of the connected arcs."
            self._add_error(problem, description, fix)
            return

        # Breadth-first walk from one endpoint. Every arc must be reachable, otherwise the arcs are split into
        # disconnected pieces (for example a path plus a separate loop).
        start_coord = endpoint_nodes[0]
        visited_arcs = set()
        visited_nodes = {start_coord}
        queue = deque([start_coord])
        while queue:
            coord = queue.popleft()
            for arc in node_to_arcs[coord]:
                if arc in visited_arcs:
                    continue
                visited_arcs.add(arc)
                # Enqueue the node at the other end of this arc.
                n1 = (arc.start_node.x, arc.start_node.y)
                n2 = (arc.end_node.x, arc.end_node.y)
                other = n2 if coord == n1 else n1
                if other not in visited_nodes:
                    visited_nodes.add(other)
                    queue.append(other)

        if len(visited_arcs) != len(all_arcs):
            problem = "Error! Vessel arcs are not all connected."
            description = (
                f"Only {len(visited_arcs)} out of {len(all_arcs)} arcs "
                "were reachable from one endpoint."
            )
            fix = "Check for disconnected arcs or gaps in the vessel chain."
            self._add_error(problem, description, fix)

    def _check_vessels(self):
        """Make sure we are using the vessel coverages correctly.

        Only vessel coverages selected in the model control are checked. Each one must contain exactly one arc
        flagged as the first segment. Its arcs are then validated by `_check_vessel`.
        """
        model_control = self._adh_data.model_control
        if model_control is None:
            return
        # The key is absent when the user never opened the model control.
        if 'uuids' not in model_control.vessel_uuids:
            return
        # Get selected vessel coverages from the model control
        vessel_uuids = model_control.vessel_uuids['uuids'].values.tolist()
        if not vessel_uuids:
            return
        for coverage, component in zip(self._adh_data.vessel_coverages, self._adh_data.vessel_components):
            # Skip coverages that are not enabled in the model control.
            if coverage.uuid not in vessel_uuids:
                continue
            data = component.data
            coverage_arcs = []
            first_segment = None
            first_segment_count = 0
            for arc in coverage.arcs:
                values = data.feature_values(TargetType.arc, feature_id=arc.id)

                # A parameters object is fetched per arc and its values are restored from this arc's feature values.
                arc_parameters = get_model().arc_parameters
                arc_parameters.restore_values(values)

                # Split the arcs into the first segment and everything else.
                if arc_parameters.group('first_segment').is_active:
                    first_segment = arc
                    first_segment_count += 1
                else:
                    coverage_arcs.append(arc)
            # The path check needs exactly one starting segment.
            if first_segment_count != 1:
                problem = f'Incorrect number of first segments found in coverage "{coverage.name}".'
                description = f'Expected exactly 1 first segment per coverage, but found {first_segment_count}.'
                fix = f'Modify coverage "{coverage.name}" to ensure it contains exactly one first segment.'
                self._add_error(problem, description, fix)
                continue
            self._check_vessel(coverage_arcs, first_segment)

    def get_problems(self):
        """Get a list of the problem strings after running model checks."""
        return [check.problem_text for check in self.errors]

    def get_descriptions(self):
        """Get a list of the description strings after running model checks."""
        return [check.description_text for check in self.errors]

    def get_fixes(self):
        """Get a list of the fix strings after running model checks."""
        return [check.fix_text for check in self.errors]

    def run_checks(self):
        """Run every model check, appending any problems found to `self.errors`.

        If a check itself raises, the traceback is appended to `adh_model_check_<pid>.log` in the current working
        directory. A generic error that points to the log is added, and the exception is re-raised.
        """
        try:
            self._check_mesh()
            self._check_model_control()
            self._check_materials()
            self._check_constituents()
            self._check_sediment_constituents()
            self._check_sediment()
            self._check_boundary_conditions()
            self._check_vessels()
        except Exception as ex:
            error_file = Path(f'adh_model_check_{os.getpid()}.log').resolve()
            # Explicit UTF-8 so a traceback containing non-ASCII paths cannot fail to encode and mask `ex`.
            with open(error_file, 'a', encoding='utf-8') as file:
                traceback.print_exception(type(ex), ex, ex.__traceback__, file=file)
            self._add_error(problem='Error!', description="Model check program error.",
                            fix=f'See {error_file} for details.')
            raise
