"""This module checks the simulation for data that could cause problems with running the CMS-Flow model."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
import os
from typing import cast, Optional

# 2. Third party modules
import xarray as xr

# 3. Aquaveo modules
from xms.api.dmi import ModelCheckError
from xms.data_objects.parameters import FilterLocation
from xms.grid.ugrid import UGrid
from xms.snap.snap_point import SnapPoint

# 4. Local modules
from xms.cmsflow import _cmsflow_cmcards_export as card_export
from xms.cmsflow.data.bc_data import BCData
from xms.cmsflow.dmi.xms_data import XmsData
from xms.cmsflow.mapping.coverage_mapper import CoverageMapper


class CMSFlowModelChecker:
    """Model Checker for CMS-Flow."""
    def __init__(self, xms_data: XmsData):
        """Construct the model checker.

        Args:
            xms_data: Data retrieval object.
        """
        self._data = xms_data
        self._errors = []

    def _add_error(self, problem, description='', fix=''):
        """Record one model check problem in the list that is sent back to SMS.

        Args:
            problem (str): The problem text.
            description (str): Longer description of the problem. Empty by default.
            fix (str): Text telling the user how to resolve the problem. Empty by default.
        """
        error = ModelCheckError(problem=problem, description=description, fix=fix)
        self._errors.append(error)

    def _check_grid(self):
        """Make sure we have a Quadtree, and its native projection matches the display projection."""
        if not self._data.ugrid:  # No linked mesh
            self._add_error('Error: No Quadtree has been linked to the simulation.')
            return

        # Ensure that the current display projection matches the CMS-Flow mesh.
        native_proj = self._data.native_ugrid_projection
        display_proj = self._data.ugrid_projection
        horiz_units_mismatch = native_proj.horizontal_units and \
            native_proj.horizontal_units != display_proj.horizontal_units
        vert_units_mismatch = native_proj.vertical_units and native_proj.vertical_units != display_proj.vertical_units
        native_proj_is_local = native_proj.coordinate_system.upper() in ['LOCAL', 'NONE']
        display_proj_is_local = display_proj.coordinate_system.upper() in ['LOCAL', 'NONE']
        coord_sys_match = (native_proj_is_local and display_proj_is_local) or \
                          (native_proj.coordinate_system == display_proj.coordinate_system)
        if not coord_sys_match or horiz_units_mismatch or vert_units_mismatch \
                or native_proj.coordinate_zone != display_proj.coordinate_zone \
                or native_proj.vertical_datum != display_proj.vertical_datum:
            self._add_error(
                'Error: Display projection does not match the CMS-Flow Quadtree. Change the display projection to '
                'match the Quadtree\'s object projection before export, or regenerate the Quadtree in the display '
                'projection.'
            )
        if native_proj.coordinate_system == 'GEOGRAPHIC':
            self._add_error(
                'Error: The CMS-Flow Quadtree is currently in a geographic projection. This is not recommended for '
                'CMS-Flow. Please regenerate in a rectangular projection.'
            )

    def _check_boundaries(self):
        """Make sure we have a boundary conditions coverage linked."""
        coverage, component = self._data.bc_coverage
        if not coverage:  # No linked BC overage
            self._add_error('Error: No boundary conditions coverage has been linked to the simulation.')
            return

        attributes = self._data.bc_arc_attributes
        if not attributes:
            self._add_error('Error: No assigned arcs found in the linked boundary conditions coverage.')
            return

        for arc_id, arc_attributes in attributes.items():
            self._check_boundary(arc_id, arc_attributes)

    def _check_boundary(self, arc_id: int, attributes: xr.Dataset):
        self._check_parent_cmsflow(arc_id, attributes)
        self._check_parent_adcirc(arc_id, attributes)

    def _check_parent_cmsflow(self, arc_id: int, attributes: xr.Dataset):
        if attributes['bc_type'].item() != 'WSE-forcing' or attributes['wse_source'].item() != 'Parent CMS-Flow':
            return

        parent_file = attributes['parent_cmsflow'].item()
        self._check_bc_file(parent_file, 'Parent .cmcards file', arc_id)

    def _check_parent_adcirc(self, arc_id: int, attributes: xr.Dataset):
        if attributes['bc_type'].item() != 'WSE-forcing' or attributes['wse_source'].item() != 'Parent ADCIRC':
            return

        if attributes['parent_adcirc_solution_type'] == 'ASCII':
            file = attributes['parent_adcirc_14'].item()
            self._check_bc_file(file, 'Parent grid (fort.14)', arc_id)
            file = attributes['parent_adcirc_63'].item()
            self._check_bc_file(file, 'Parent solution (fort.63)', arc_id)
        if attributes['parent_adcirc_solution_type'] == 'ASCII' and attributes['use_velocity'].item() != 0:
            file = attributes['parent_adcirc_64'].item()
            self._check_bc_file(file, 'Parent solution (fort.64)', arc_id)
        if attributes['parent_adcirc_solution_type'] == 'XMDF':
            adcirc_mesh = self._data.item_exists(attributes['parent_adcirc_14_uuid'].item())
            if not adcirc_mesh:
                self._add_error(
                    'Error: No referenced parent ADCIRC grid solution was found for a Parent ADCIRC boundary arc.'
                )
                return

            adcirc_wse = self._data.item_exists(attributes['parent_adcirc_solution_wse'].item())
            if not adcirc_wse:
                self._add_error('Error: No referenced parent ADCIRC WSE solution found.')
            if attributes['use_velocity'].item() != 0:
                adcirc_velocity = self._data.item_exists(attributes['parent_adcirc_solution'].item())
                if not adcirc_velocity:
                    self._add_error('Error: No referenced parent ADCIRC velocity solution found.')

    def _check_sim(self):
        """Check miscellaneous simulation inputs."""
        # Ensure a valid bottom roughness dataset has been selected.
        sim_data = self._data.sim_data
        dset_uuid = sim_data.flow.attrs['BOTTOM_ROUGHNESS_DSET']
        rough_type = sim_data.flow.attrs['BOTTOM_ROUGHNESS']
        rough_source = sim_data.flow.attrs['ROUGHNESS_SOURCE']
        rough_value = sim_data.flow.attrs['ROUGHNESS_CONSTANT']
        if rough_source == 'Constant' and rough_value < 0.0:
            self._add_error(
                f'Error: A positive constant value for {rough_type} must be selected '
                'in the "Flow" tab of the Model Control dialog.'
            )
        if rough_source == 'Constant' and rough_value >= 1.0:
            self._add_error(
                f'Error: A constant value of {rough_value} is too large for {rough_type}. '
                'Please select a smaller constant value in the "Flow" tab of the Model Control dialog.'
            )
        elif rough_source == 'Dataset' and not self._data.item_exists(dset_uuid):
            self._add_error(
                'Error: No bottom roughness dataset selected. Please select a '
                f'{rough_type} dataset in the "Flow" tab of the Model Control dialog.'
            )
        # If the C2SHORE sediment transport formula is being used, ensure wave forcing is enabled.
        if sim_data.sediment.attrs['TRANSPORT_FORMULA'] == 'C2SHORE' and sim_data.wave.attrs['WAVE_INFO'] == 'None':
            self._add_error(
                'Error: Wave forcing must be enabled to use the C2SHORE sediment transport formula. Enable wave '
                'forcing in the "Wave" tab of the Model Control dialog, or change the "Transport formula" option in '
                'the "Sediment" tab.'
            )

    def _check_save_points(self):
        """Make sure the closest cell to each save point is active."""
        save_points_cov, save_points_comp = self._data.save_points_coverage
        if not save_points_cov or not self._data.ugrid:
            return  # Save points not required.

        points = save_points_cov.get_points(FilterLocation.PT_LOC_DISJOINT)
        if not points:
            return  # No points in the save coverage
        point_ids = [point.id for point in points]
        # Build a mapper for the save points, which will snap to the closest active point
        cov_mapper = CoverageMapper(generate_snap=False)
        cov_mapper.set_quadtree(quadtree=self._data.ugrid, wkt=self._data.ugrid_projection.well_known_text)
        cov_mapper.set_save_points(coverage=save_points_cov, component=save_points_comp)
        small_grid = cov_mapper.set_activity(activity_cov=self._data.activity_coverage)
        small_ugrid = small_grid.ugrid if small_grid else None
        big_ugrid = self._data.ugrid.ugrid
        if not small_ugrid or small_ugrid.cell_count == big_ugrid.cell_count:
            return  # All cells are active, nothing to check.

        small_snapper = cov_mapper.get_save_points_snapper()
        small_snap_output = small_snapper.get_snapped_points(points)
        small_id_to_big_id = cov_mapper.get_snap_id_to_original_id()

        # Build a snapper for the entire grid, ignoring activity.
        big_snapper = SnapPoint()
        grid = cast(UGrid, self._data.ugrid)
        big_snapper.set_grid(grid=grid, target_cells=True)
        big_snap_output = big_snapper.get_snapped_points(points)
        for point_idx, (small_id, big_id) in enumerate(zip(small_snap_output['id'], big_snap_output['id'])):
            snap_id = small_id_to_big_id[small_id]
            if big_id != snap_id:
                self._add_error(
                    f'Warning: The closest cell to save point with feature id {point_ids[point_idx]} is cell '
                    f'{big_id + 1}, which is in an inactive region. The save point will be snapped to cell '
                    f'{snap_id + 1} on export. If this is not desired, delete the feature point or move it to an '
                    'active region of the grid.'
                )

    def _check_sediment(self):
        """Make sure we have valid datasets for sediment transport."""
        sim_data = self._data.sim_data
        transport_enabled = sim_data.sediment.attrs['CALCULATE_SEDIMENT']
        table_input_required = sim_data.sediment.attrs['FORMULATION_UNITS'] == 'Nonequilibrium total load'
        if not transport_enabled or not table_input_required:
            return

        # Make sure there is at least one row in the  grain size table if using nonequilibrium total load formulation.
        # Only one dimension, don't care what it is named.
        num_grain_sizes = list(sim_data.advanced_sediment_diameters_table.sizes.values())[0]
        if sim_data.sediment.attrs['FORMULATION_UNITS'] == 'Nonequilibrium total load' and num_grain_sizes <= 0:
            self._add_error('Error: A sediment size class diameter is required.')

        # If the simplified multi-grain size option is not enabled, check the bed layers table.
        simple_layers = sim_data.sediment.attrs['ENABLE_SIMPLIFIED_MULTI_GRAIN_SIZE'] == 1
        if not simple_layers:
            num_bed_layers = list(sim_data.bed_layer_table.sizes.values())[0]
            if num_bed_layers > 0:
                # Multiple bed layers are not allowed unless multiple sediment size classes.
                if num_grain_sizes <= 1 and num_bed_layers > 1:
                    self._add_error(
                        'Multiple bed layers can only be defined if '
                        'there are multiple sediment size classes.'
                    )
                elif num_grain_sizes > 1 and num_bed_layers == 1:
                    self._add_error(
                        'Multiple sediment size classes can only be defined if '
                        'there are multiple bed layers.'
                    )
                elif num_grain_sizes == 0 and num_bed_layers == 1:
                    self._add_error('A sediment size class is required.')
                else:
                    bl_table = sim_data.bed_layer_table
                    for i in range(num_bed_layers):
                        # d50 column required on every row.
                        d50_exists = self._data.item_exists(bl_table.d50.data[i])
                        if not d50_exists:
                            self._add_error(f'Error: A d50 dataset is required for row {i + 1}.')

                        has_lower = False
                        has_upper = False

                        # Ensure that there is at least one d05, d10, d16, d20, d30, or d35 column on each row.
                        for ds_name in ['d05', 'd10', 'd16', 'd20', 'd30', 'd35']:
                            node_uuid = bl_table[ds_name].data[i]
                            item_exists = self._data.item_exists(node_uuid)
                            if item_exists:
                                has_lower = True
                                break

                        for ds_name in ['d65', 'd84', 'd90', 'd95']:
                            node_uuid = bl_table[ds_name].data[i]
                            item_exists = self._data.item_exists(node_uuid)
                            if item_exists:
                                has_upper = True
                                break

                        if has_lower and not has_upper:
                            self._add_error(f'Error: At least 1 dataset between d65-d95 is required for row {i + 1}.')
                        elif not has_lower and has_upper:
                            self._add_error(f'Error: At least 1 dataset between d05-d35 is required for row {i + 1}.')
        else:  # Simplified multi grain size
            pass

    def _check_for_missing_simulation_files(self):
        """Check every file referenced by the simulation's model control that the current options require.

        Covers the initial conditions file, the wind files for the selected wind file type, and the wave simulation
        file.
        """
        sim_data = self._data.sim_data

        # Initial conditions file
        if sim_data.general.attrs['USE_INIT_CONDITIONS_FILE'] == 1:
            self._check_file('INIT_CONDITIONS_FILE', sim_data.general, True)

        # Wind files
        wind = sim_data.wind
        if card_export.WIND_TYPES[wind.attrs['WIND_TYPE']] == 'File':
            wind_file_type = card_export.WIND_FILE_TYPES[wind.attrs['WIND_FILE_TYPE']]
            if wind_file_type == 'OWI':
                for card in ('OCEAN_WIND_FILE', 'OCEAN_PRESSURE_FILE', 'OCEAN_XY_FILE'):
                    self._check_file(card, wind, True)
            elif wind_file_type in ('Fleet', 'ASCII'):
                self._check_file('WIND_FILE', wind, True)
                # A separate grid file is only referenced for the XYFile grid type.
                if card_export.WIND_GRID_TYPES[wind.attrs['WIND_GRID_TYPE']] == 'XYFile':
                    self._check_file('WIND_GRID_FILE', wind, True)

        # Wave files
        wave = sim_data.wave
        if card_export.WAVE_INFO_TYPES[wave.attrs['WAVE_INFO']] == 'Inline':
            self._check_file('FILE_WAVE_SIM', wave, True)

    def _check_file(self, card, data, is_attr):
        """Add an error if a file referenced by the simulation is not set or does not exist.

        Args:
            card (str): Name of the file attribute
            data (xarray.Dataset): The data containing the file reference attribute
            is_attr (bool): True if filename is stored in attrs dict, False if in a xarray.DataArray
        """
        if is_attr:
            reference = data.attrs[card]
        else:
            reference = data[card].item()

        filename = card_export.CMSFlowCmcardsExporter.get_file_string_without_report(reference)
        if not filename:
            self._add_error(f'Error: Referenced {card} file not set.')
            return
        if os.path.exists(filename):
            return

        # Report the absolute path so the user can see exactly where the file was expected.
        self._add_error(f'Error: Unable to find referenced {card} file: {os.path.abspath(filename)}')

    def _check_bc_file(self, file: str, label: str, arc_id: int):
        """
        Add an error if a file required by an arc boundary condition is not set or does not exist.

        Args:
            file: File reference read from the arc's attributes.
            label: Label of the attribute. Used for reporting errors to the user.
            arc_id: Feature ID of the arc being checked. Used for reporting errors to the user.
        """
        filename = card_export.CMSFlowCmcardsExporter.get_file_string_without_report(file)
        if not filename:
            self._add_error(f'Error: Arc {arc_id}, parameter {label} not set.')
            return
        if os.path.exists(filename):
            return

        # Report the absolute path so the user can see exactly where the file was expected.
        missing_path = os.path.abspath(filename)
        self._add_error(f'Error: Arc {arc_id}, parameter {label}, refers to nonexistent file: {missing_path}')

    def _check_wse_extraction(self):
        """Add any errors found for arcs that use extracted WSE forcing."""
        # An empty result adds nothing.
        self._errors += check_forcing_errors(self._data) or []

    def get_problems(self):
        """Get a list of the problem strings after running model checks."""
        return [check.problem_text for check in self._errors]

    def get_descriptions(self):
        """Get a list of the description strings after running model checks."""
        return [check.description_text for check in self._errors]

    def get_fixes(self):
        """Get a list of the fix strings after running model checks."""
        return [check.fix_text for check in self._errors]

    def run_checks(self):
        """Run every model check in order, then pass the collected errors to the data object."""
        checks = (
            self._check_grid,
            self._check_sim,
            self._check_for_missing_simulation_files,
            self._check_save_points,
            self._check_sediment,
            self._check_boundaries,
            self._check_wse_extraction,
        )
        for check in checks:
            check()
        self._data.add_model_check_errors(self._errors)


# Shared opening of the long descriptions for Extracted WSE forcing errors. Each check appends its own requirement.
wse_requires = 'One or more arcs use a WSE extraction boundary condition, which requires that'


def check_forcing_errors(xms_data: XmsData) -> list[ModelCheckError]:
    """
    Check for elevation and velocity forcing errors.

    Is safe to call when forcing is not needed: an empty list is returned when there is no boundary conditions
    component, no assigned arcs, or no arc that uses Extracted WSE forcing.

    Args:
        xms_data: Interprocess communication object.

    Returns:
        List of found errors. This may be empty if no errors. When it is not empty, the first entry is a summary that
        lists the IDs of the arcs using Extracted WSE forcing.
    """
    sim_data = xms_data.sim_data
    _coverage, bc_component = xms_data.bc_coverage
    if not bc_component:
        return []

    bc_data: BCData = bc_component.data
    # The arc attributes may be empty or unset when no arcs are assigned; both mean there is nothing to check.
    arc_attributes = xms_data.bc_arc_attributes or {}
    extracted_features = [
        arc_id
        for arc_id, attributes in arc_attributes.items()
        if attributes['bc_type'].item() == 'WSE-forcing' and attributes['wse_source'].item() == 'Extracted'
    ]
    if not extracted_features:
        return []

    errors = []
    sim_date_ok = True

    if not sim_data.simulation_start or not sim_data.simulation_end:
        short = 'The simulation start time was not set.'
        long = f'{wse_requires} the simulation start time and end time are set.'
        fix = (
            "Open the simulation's model control and set the 'Start date/time' and 'Simulation duration' parameters on "
            'the General tab.'
        )
        errors.append(ModelCheckError(short, long, fix))
        sim_date_ok = False

    if not bc_data.wse_forcing_geometry or not xms_data.get_ugrid(bc_data.wse_forcing_geometry):
        short = 'WSE forcing geometry was not set or no longer exists.'
        long = f'{wse_requires} the BC coverage has a parent grid.'
        fix = (
            'Select any arc that uses an Extracted WSE source and click the buttons to select a parent grid '
            'and WSE dataset.'
        )
        errors.append(ModelCheckError(short, long, fix))

    elevation_error = _check_elevation_forcing(xms_data)
    if elevation_error is not None:
        errors.append(elevation_error)

    velocity_error = _check_velocity_forcing(xms_data)
    if velocity_error is not None:
        errors.append(velocity_error)

    # Elevation and velocity errors likely mean the time steps are bad and the combined checks will crash.
    if elevation_error is None and velocity_error is None and sim_date_ok:
        combined_error = _check_elevation_and_velocity_combined(xms_data)
        if combined_error is not None:
            errors.append(combined_error)

    if errors:
        # Lead with a summary naming the arcs involved, so the user knows which arcs the other errors refer to.
        short = 'One or more errors occurred due to Extracted WSE forcing arcs.'
        joined = ', '.join(str(feature) for feature in extracted_features)
        description = f'Arcs with the following IDs use Extracted WSE forcing: {joined}'
        fix = (
            'If Extracted WSE forcing is undesired, reassign the boundary conditions for all the arcs mentioned in the '
            'description so they no longer use Extracted WSE forcing. Otherwise, resolve the other issues instead to '
            'allow Extracted WSE forcing.'
        )
        errors.insert(0, ModelCheckError(short, description, fix))

    return errors


def _check_elevation_forcing(xms_data: XmsData) -> Optional[ModelCheckError]:
    """
    Check for errors due to a bad elevation-forcing dataset.

    Args:
        xms_data: Interprocess communication object. Its BC coverage component must exist.

    Returns:
        The first error found, or None if the elevation dataset is set, exists, has a reference time, and does not
        start after the simulation start.
    """
    _coverage, component = xms_data.bc_coverage
    bc_data: BCData = component.data

    source_uuid = bc_data.wse_forcing_wse_source
    if not source_uuid:
        short = 'WSE forcing elevation source not set.'
        long = f'{wse_requires} the BC coverage has an elevation source.'
        fix = 'Select any arc that uses an Extracted WSE source and click the button to select a WSE dataset.'
        return ModelCheckError(short, long, fix)

    # CMS used to write its datasets with the wrong era in the ref-time, so dates were always BC instead of AD.
    # Nobody noticed because XMS silently discards the era, but the Python side of things used to discard the whole
    # ref-time instead, which was ambiguous with having no date set. It now throws an exception instead, which we
    # catch here to notify the user that the date might *look* valid, but it's actually wrong.
    # The dataset is fetched once and reused below rather than being fetched a second time.
    try:
        elevation_dataset = xms_data.get_dataset(source_uuid)
    except ValueError:
        short = 'WSE forcing dataset has invalid reference time.'
        long = (
            f'{wse_requires} the referenced datasets have a reference time that is at least 1 C.E., but one or both '
            'referenced datasets appear to have a date that is before this. This was a common bug with datasets '
            'produced by the CMS model executables up to version 5.4.3.'
        )
        fix = 'Click File|Save to save the project. SMS will rewrite the datasets with the correct era.'
        return ModelCheckError(short, long, fix)

    if elevation_dataset is None:
        short = 'WSE forcing elevation source no longer exists.'
        long = f'{wse_requires} the BC coverage has an elevation source.'
        fix = 'Select any arc that uses an Extracted WSE source and click the button to select a WSE dataset.'
        return ModelCheckError(short, long, fix)

    if elevation_dataset.ref_time is None:
        path = xms_data.tree_path(elevation_dataset.uuid)
        short = 'WSE forcing elevation dataset has invalid reference time.'
        long = f'{wse_requires} the referenced datasets have a reference time assigned.'
        fix = f'Assign a reference time on the dataset at "{path}".'
        return ModelCheckError(short, long, fix)

    # A missing simulation start is reported by the caller, so it is not treated as an error here.
    sim_start = xms_data.sim_data.simulation_start
    if sim_start is not None and sim_start < elevation_dataset.ref_time:
        path = xms_data.tree_path(elevation_dataset.uuid)
        short = 'Simulation starts before WSE forcing elevation dataset.'
        long = f'{wse_requires} the elevation dataset starts before or at the start of the simulation.'
        fix = (
            'Assign a different elevation dataset on any Extracted WSE arc, assign a new reference time on the dataset '
            f'at "{path}", or use the Sample Time Steps tool to extrapolate the dataset to the start of the simulation.'
        )
        return ModelCheckError(short, long, fix)

    return None


def _check_velocity_forcing(xms_data: XmsData) -> Optional[ModelCheckError]:
    """
    Check for errors due to a bad velocity forcing dataset.

    Is safe to call when the velocity dataset is not necessary.

    Args:
        xms_data: Interprocess communication object. Its BC coverage component must exist.

    Returns:
        The first error found, or None if no velocity source is set or the velocity dataset exists, has a reference
        time, and does not start after the simulation start.
    """
    _coverage, component = xms_data.bc_coverage
    bc_data: BCData = component.data

    source_uuid = bc_data.wse_forcing_velocity_source
    if not source_uuid:
        return None  # Velocity forcing is optional.

    # CMS used to write its datasets with the wrong era in the ref-time, so dates were always BC instead of AD.
    # Nobody noticed because XMS silently discards the era, but the Python side of things used to discard the whole
    # ref-time instead, which was ambiguous with having no date set. It now throws an exception instead, which we
    # catch here to notify the user that the date might *look* valid, but it's actually wrong.
    # The dataset is fetched once and reused below rather than being fetched a second time.
    try:
        velocity_dataset = xms_data.get_dataset(source_uuid)
    except ValueError:
        short = 'WSE forcing velocity source has invalid reference time.'
        long = (
            f'{wse_requires} the referenced datasets have a reference time that is at least 1 C.E., but one or both '
            'referenced datasets appear to have a date that is before this. This was a common bug with datasets '
            'produced by the CMS model executables up to version 5.4.3.'
        )
        fix = 'Click File|Save to save the project. SMS will rewrite the datasets with the correct era.'
        return ModelCheckError(short, long, fix)

    if velocity_dataset is None:
        short = 'WSE forcing velocity source no longer exists.'
        long = f'{wse_requires} the referenced velocity dataset exists.'
        fix = (
            'Select any arc that uses an Extracted WSE source and click the button to select a velocity dataset, or '
            'uncheck the box that forces velocity.'
        )
        return ModelCheckError(short, long, fix)

    if velocity_dataset.ref_time is None:
        path = xms_data.tree_path(velocity_dataset.uuid)
        short = 'WSE forcing velocity source has invalid reference time.'
        long = f'{wse_requires} the referenced datasets have a reference time assigned.'
        fix = f'Assign a reference time on the dataset at "{path}".'
        return ModelCheckError(short, long, fix)

    # A missing simulation start is reported by the caller, so it is not treated as an error here.
    sim_start = xms_data.sim_data.simulation_start
    if sim_start is not None and sim_start < velocity_dataset.ref_time:
        path = xms_data.tree_path(velocity_dataset.uuid)
        short = 'Simulation starts before WSE forcing velocity dataset.'
        long = f'{wse_requires} the velocity dataset starts before or at the start of the simulation.'
        fix = (
            'Assign a different velocity dataset on any Extracted WSE arc, assign a new reference time on the dataset '
            f'at "{path}", disable velocity forcing, or use the Sample Time Steps tool to extrapolate the dataset to '
            'the start of the simulation.'
        )
        return ModelCheckError(short, long, fix)

    return None


def _check_elevation_and_velocity_combined(xms_data: XmsData) -> Optional[ModelCheckError]:
    """
    Check that the elevation and velocity forcing datasets have matching time steps during the simulation.

    Assumes all the checks for the datasets separately have passed. Is safe to call when the velocity dataset is
    unnecessary.

    Args:
        xms_data: Interprocess communication object. Its BC coverage component must exist.

    Returns:
        An error if the filtered time steps of the two datasets differ, otherwise None. None is also returned when no
        velocity source is set.
    """
    _coverage, component = xms_data.bc_coverage
    bc_data: BCData = component.data
    if not bc_data.wse_forcing_velocity_source:
        return None

    start = xms_data.sim_data.simulation_start
    end = xms_data.sim_data.simulation_end

    def simulation_times(source_uuid):
        """Return the dataset's time steps, truncated to whole seconds and filtered to the simulation."""
        dataset = xms_data.get_dataset(source_uuid)
        times = []
        for index in range(dataset.num_times):
            moment = dataset.ref_time + dataset.timestep_offset(index)
            times.append(moment.replace(microsecond=0))
        return list(_filter_times(times, start, end))

    elevation_times = simulation_times(bc_data.wse_forcing_wse_source)
    velocity_times = simulation_times(bc_data.wse_forcing_velocity_source)
    if elevation_times == velocity_times:
        return None

    short = 'Elevation and velocity datasets have different time steps.'
    long = (
        f'{wse_requires} the elevation and velocity datasets have the same time steps for the duration of the '
        'simulation. There must be the same number of time steps, at the same times, from the start of the simulation '
        'to the end.'
    )
    fix = (
        'Check that the correct datasets have been chosen, and their reference times are correct. If the datasets are '
        'expected to have differing time steps (e.g. because one is higher resolution than the other), then the Sample '
        'Time Steps tool can be used to derive a new dataset from one of the existing ones with time steps that match '
        'the other, including interpolating between times.'
    )
    return ModelCheckError(short, long, fix)


def _filter_times(times: list[datetime], start: datetime, end: datetime) -> list[datetime]:
    """
    Reduce a dataset's time steps to the ones that matter for a simulation running from start to end.

    The first time is always kept. Each later time at or before start replaces that first entry, so the list ends up
    beginning with the last time that is not after start. Later times after start and up to and including end are
    appended. Later times after end are dropped.

    Args:
        times: Time steps to filter. Presumably in ascending order - verify against callers.
        start: Start of the simulation.
        end: End of the simulation.

    Returns:
        The filtered times. Empty if times is empty.
    """
    remaining = iter(times)
    try:
        kept = [next(remaining)]  # The first time is kept unconditionally.
    except StopIteration:
        return []

    for moment in remaining:
        if moment <= start:
            kept[0] = moment  # Keep only the one just before the simulation starts.
        elif start < moment <= end:
            kept.append(moment)
    return kept
