"""Imports SRH simulation."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import os
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.core.filesystem import filesystem
from xms.coverage.grid.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder
from xms.data_objects.parameters import Arc, Component, Coverage, Point, Simulation, UGrid
from xms.guipy.param import param_h5_io

# 4. Local modules
from xms.srh.components.coverage_arc_builder import CoverageArcBuilder
from xms.srh.file_io.bc_component_builder import build_bc_component
from xms.srh.file_io.geom_reader import GeomReader
from xms.srh.file_io.hydro_reader import HydroReader
from xms.srh.file_io.material_component_builder import MaterialComponentBuilder
from xms.srh.file_io.material_reader import MaterialReader
from xms.srh.file_io.monitor_component_builder import build_monitor_component
from xms.srh.file_io.monitor_point_reader import MonitorPointReader
from xms.srh.file_io.obstructions_component_builder import build_obstructions_component


class SRHSimImporter:
    """Read an SRH-2D simulation when a *.srhhydro file is opened in SMS."""
    def __init__(self, xms_data=None):
        """Construct the Importer.

        Args:
            xms_data (:obj:`dict`): XMS data dictionary. Useful for testing because it will avoid any Query calls.
                ::
                    {
                        'filename': '',  # Path to the *.srhhydro file to read
                        'comp_dir': '',  # Path to the XMS "Components" temp folder
                    }

        """
        self._logger = logging.getLogger('xms.srh')
        self._xms_data = xms_data
        self._query = None

        self._hydro_reader = None
        self._geom_reader = None
        self._mpoint_reader = None

        # Stuff we will be sending back to XMS
        self._sim_comp = None
        self._mesh = None
        self._bc_cov = None
        self._bc_do_comp = None  # data_object for BC coverage component
        self._bc_id_to_att_id = {}  # Need mapping of SRH nodestring id to XMS map coverage ids since monitors mixed in
        self._mat_cov = None
        self._mat_do_comp = None  # data_object for Material coverage component
        self._sed_mat_cov = None
        self._sed_mat_do_comp = None  # data_object for Sediment Material coverage component
        self._monitor_cov = None
        self._monitor_do_comp = None  # data_object for Monitor coverage
        self._obs_cov = None
        self._obs_do_comp = None  # data_object for Obstructions coverage component

        if not self._xms_data:
            self._get_xms_data()

    @staticmethod
    def _card_argument(card):
        """Get the argument of a *.srhhydro file card, e.g. the path in 'Grid "mesh.srhgeom"'.

        The card keyword is everything up to the first run of whitespace. The rest of the line is returned with
        surrounding whitespace and then surrounding single/double quotes removed, so quoted paths and paths that
        contain spaces are both supported.

        Args:
            card (:obj:`str`): The card line (keyword followed by its argument)

        Returns:
            (:obj:`str`): The card's argument, or an empty string if the card has no argument.
        """
        parts = card.split(maxsplit=1)
        if len(parts) < 2:
            return ''
        return parts[1].strip().strip('"\'')

    def _get_xms_data(self):
        """Get all data from XMS needed to import the SRH-2D simulation."""
        self._logger.info('Retrieving data from XMS...')
        self._xms_data = {
            'filename': '',
            'comp_dir': '',
        }
        try:
            self._query = Query()
            self._xms_data['filename'] = self._query.read_file
            # Get the SMS temp components directory
            self._xms_data['comp_dir'] = os.path.join(self._query.xms_temp_directory, 'Components')
        except Exception:  # noqa - # pragma no cover - hard to test Exceptions using QueryPlayback
            # Log (with traceback) and continue with the empty defaults; interpreter exits are not swallowed.
            self._logger.exception('Unable to retrieve data from SMS needed to import SRH-2D simulation')

    def _add_xms_data(self):
        """Send imported data to SMS."""
        self._logger.info('Preparing to send imported data to SMS...')

        # Add a new simulation
        sim = Simulation(model='SRH-2D', name='Sim', sim_uuid=str(uuid.uuid4()))
        self._query.add_simulation(sim, components=[self._sim_comp])

        # Add the mesh geometry
        self._query.add_ugrid(self._mesh)
        self._query.link_item(taker_uuid=sim.uuid, taken_uuid=self._mesh.uuid)

        # Add the BC coverage geometry and its hidden component.
        self._query.add_coverage(
            self._bc_cov, model_name='SRH-2D', coverage_type='Boundary Conditions', components=[self._bc_do_comp]
        )
        self._query.link_item(taker_uuid=sim.uuid, taken_uuid=self._bc_cov.uuid)

        # Add the material coverage geometry and its hidden component.
        if self._mat_cov:
            self._query.add_coverage(
                self._mat_cov, model_name='SRH-2D', coverage_type='Materials', components=[self._mat_do_comp]
            )
            self._query.link_item(taker_uuid=sim.uuid, taken_uuid=self._mat_cov.uuid)

        # Add the sediment material coverage geometry and its hidden component.
        if self._sed_mat_cov:
            self._query.add_coverage(
                self._sed_mat_cov,
                model_name='SRH-2D',
                coverage_type='Sediment Materials',
                components=[self._sed_mat_do_comp]
            )
            self._query.link_item(taker_uuid=sim.uuid, taken_uuid=self._sed_mat_cov.uuid)

        # Add the monitor coverage and its hidden component if we have one
        if self._monitor_cov:
            self._query.add_coverage(
                self._monitor_cov, model_name='SRH-2D', coverage_type='Monitor', components=[self._monitor_do_comp]
            )
            self._query.link_item(taker_uuid=sim.uuid, taken_uuid=self._monitor_cov.uuid)

        # Add the obstructions coverage geometry and its hidden component.
        if self._obs_cov:
            self._query.add_coverage(
                self._obs_cov, model_name='SRH-2D', coverage_type='Obstructions', components=[self._obs_do_comp]
            )
            self._query.link_item(taker_uuid=sim.uuid, taken_uuid=self._obs_cov.uuid)

    def _copy_restart_file(self, sim_comp_dir):
        """Copy the restart file to the simulation component directory if using an initial conditions file.

        Args:
            sim_comp_dir (:obj:`str`): Path to the simulation component's data folder
        """
        hydro_data = self._hydro_reader.model_control.hydro
        if hydro_data.initial_condition == 'Restart File':
            import_path = os.path.dirname(self._xms_data['filename'])
            orig_restart = ''
            if hydro_data.restart_file:
                orig_restart = filesystem.resolve_relative_path(import_path, hydro_data.restart_file)
            if os.path.isfile(orig_restart):
                # Store the basename. Always in the same directory as simulation component data.
                hydro_data.restart_file = os.path.basename(orig_restart)
                comp_restart = os.path.join(sim_comp_dir, hydro_data.restart_file)
                filesystem.copyfile(orig_restart, comp_restart)
            else:
                hydro_data.restart_file = ''  # Invalid restart file

    def _copy_ceiling_file(self, sim_comp_dir):
        """Copy the pressure ceiling file to the simulation component directory if it exists.

        The file named by a 'PressureDatasetFile' card (keyword matched case-insensitively) is resolved relative to
        the *.srhhydro file's directory and copied to 'ceiling.shrceiling' in the component folder. Quoted paths and
        paths containing spaces are supported; a card without a filename is ignored.

        Args:
            sim_comp_dir (:obj:`str`): Path to the simulation component's data folder
        """
        import_path = os.path.dirname(self._xms_data['filename'])
        for card in self._hydro_reader.file_list:
            if not card.lower().startswith('pressuredatasetfile'):
                continue
            ceiling_file = self._card_argument(card)
            if not ceiling_file:
                continue  # Malformed card with no filename argument
            orig_ceiling = filesystem.resolve_relative_path(import_path, ceiling_file)
            if os.path.isfile(orig_ceiling):
                comp_ceiling = os.path.join(sim_comp_dir, 'ceiling.shrceiling')
                filesystem.copyfile(orig_ceiling, comp_ceiling)

    def _read_hydro(self):
        """Read parameters from a *.srhhydro file."""
        self._hydro_reader = HydroReader()
        self._hydro_reader.read(self._xms_data['filename'])

        comp_uuid = str(uuid.uuid4())
        sim_comp_dir = os.path.join(self._xms_data['comp_dir'], comp_uuid)
        os.makedirs(sim_comp_dir, exist_ok=True)
        self._copy_restart_file(sim_comp_dir)  # Copy restart file to simulation component folder if there is one.
        self._copy_ceiling_file(sim_comp_dir)

        sim_main_file = os.path.join(sim_comp_dir, 'sim_comp.nc')
        param_h5_io.write_to_h5_file(sim_main_file, self._hydro_reader.model_control)
        self._sim_comp = Component(
            main_file=sim_main_file, comp_uuid=comp_uuid, model_name='SRH-2D', unique_name='Sim_Manager'
        )

    def _read_geom(self, filename):
        """Read mesh geometry and BC arc geometry from a *.srhgeom file.

        Args:
            filename (:obj:`str`): Filepath of the *.srhgeom file

        """
        self._geom_reader = GeomReader()
        self._geom_reader.read(filename)

        self._mesh = UGrid(
            self._geom_reader.temp_mesh_file,
            name=self._geom_reader.data['name'],
            uuid=str(uuid.uuid4()),
            projection=self._geom_reader.geom_projection_from_grid_units()
        )

        if self._geom_reader.monitor_line_ids:
            self._hydro_reader.monitor_line_ids = self._geom_reader.monitor_line_ids

    def _read_hydromat(self, filename, is_sediment):
        """Read material assignments from a *.srhmat file. Build a polygon coverage from common assignment areas.

        Args:
            filename (:obj:`str`): Filepath of the *.srhmat file
            is_sediment (:obj:`bool`): True if this is the sediment material file
        """
        mat_reader = MaterialReader(is_sediment)
        mat_reader.read(filename)

        # Create a dataset of materials (size of cells). Cells without an assignment keep material 0.
        cell_materials = [0] * self._geom_reader.cogrid.ugrid.cell_count
        for material, cells in mat_reader.material_cells.items():
            for cell in cells:
                cell_materials[cell] = material

        cov_name = 'Sediment Materials' if is_sediment else 'Materials'
        cov_builder = GridCellToPolygonCoverageBuilder(
            self._geom_reader.cogrid, cell_materials, self._mesh.projection, cov_name
        )
        new_cov_geom = cov_builder.create_polygons_and_build_coverage()
        comp_builder = MaterialComponentBuilder(
            self._xms_data['comp_dir'], mat_reader.material_names, cov_builder.dataset_polygon_ids
        )
        new_cov_uuid = new_cov_geom.uuid
        if is_sediment:
            self._sed_mat_cov = new_cov_geom
            self._sed_mat_do_comp = comp_builder.build_sed_material_component(
                new_cov_uuid, self._hydro_reader.sed_material_data
            )
        else:
            self._mat_cov = new_cov_geom
            self._mat_do_comp = comp_builder.build_material_component(
                new_cov_uuid, self._hydro_reader.materials_manning
            )

    def _build_bc_coverage(self):
        """Create the data_objects BC Coverage from data imported from the *.srhgeom file."""
        self._logger.info('Building BC coverage geometry...')

        arc_builder = CoverageArcBuilder(self._geom_reader.data['nodes'])
        for nodestring_id, nodestring in self._geom_reader.data['node_strings'].items():
            if nodestring_id in self._hydro_reader.monitor_line_ids:  # Do not put monitor lines in the BC coverage
                continue
            self._bc_id_to_att_id[nodestring_id] = arc_builder.add_arc(nodestring[0], nodestring[-1], nodestring[1:-1])

        self._bc_cov = Coverage(name='Boundary Conditions', uuid=str(uuid.uuid4()), projection=self._mesh.projection)
        self._bc_cov.arcs = arc_builder.arcs
        self._bc_cov.complete()

    def _build_obs_coverage(self):
        """Build the obstructions coverage geometry and component data.

        Pier obstructions become disjoint points and deck obstructions become arcs. Component ids are numbered
        piers first, then decks. A warning is logged when an obstruction's units differ from the mesh units.

        Raises:
            RuntimeError: If a pier line has fewer than 10 words or a deck line has fewer than 13 words.
        """
        if not self._hydro_reader.obstruction_decks and not self._hydro_reader.obstruction_piers:
            return  # No obstructions

        grid_units = 'en' if 'foot' in self._geom_reader.data['units'].lower() else 'si'
        self._logger.info('Building obstructions coverage...')
        self._obs_cov = Coverage(name='Obstructions', uuid=str(uuid.uuid4()), projection=self._mesh.projection)
        ids = []
        obs_list = []
        pier_pts = []
        pier_ids = {}
        deck_ids = {}
        if self._hydro_reader.obstruction_piers:  # Add disjoint pier points if we have any
            for pier in self._hydro_reader.obstruction_piers:
                pier_words = pier.split()
                if len(pier_words) < 10:
                    raise RuntimeError(f'Invalid Pier obstruction line: {pier}')
                comp_id = len(pier_pts) + 1
                att_id = len(pier_pts) + 1
                pier_ids[att_id] = comp_id
                ids.append(comp_id)
                obs_list.append(self._obs_param_from_line(pier_words))
                if pier_words[6].lower() != grid_units:
                    msg = f'The units specified for the obstruction on pt id: {att_id} do not match the units of ' \
                          f'the mesh. It is recommended that the obstruction units be consistent with the mesh units.'
                    self._logger.warning(msg)
                pt = Point(x=float(pier_words[7]), y=float(pier_words[8]), z=float(pier_words[9]), feature_id=att_id)
                pier_pts.append(pt)
            self._obs_cov.set_points(pier_pts)

        if self._hydro_reader.obstruction_decks:  # Add obstruction deck arcs if we have any
            next_pt_id = len(pier_pts) + 1  # Arc node ids continue after the pier point ids
            deck_arcs = []
            for deck in self._hydro_reader.obstruction_decks:
                deck_words = deck.split()
                if len(deck_words) < 13:
                    raise RuntimeError(f'Invalid Deck obstruction line: {deck}')
                comp_id = len(pier_pts) + len(deck_arcs) + 1
                att_id = len(deck_arcs) + 1
                deck_ids[att_id] = comp_id
                ids.append(comp_id)
                obs_list.append(self._obs_param_from_line(deck_words))
                if deck_words[6].lower() != grid_units:
                    msg = f'The units specified for the obstruction on arc id: {att_id} do not match the units of ' \
                          f'the mesh. It is recommended that the obstruction units be consistent with the mesh units.'
                    self._logger.warning(msg)
                # Coordinates start at word 7: first xyz is the start node, last xyz the end node, the rest vertices.
                start_node = Point(
                    x=float(deck_words[7]), y=float(deck_words[8]), z=float(deck_words[9]), feature_id=next_pt_id
                )
                next_pt_id += 1
                end_node = Point(
                    x=float(deck_words[-3]), y=float(deck_words[-2]), z=float(deck_words[-1]), feature_id=next_pt_id
                )
                next_pt_id += 1
                verts = []
                for i in range(10, len(deck_words) - 3, 3):
                    verts.append(Point(float(deck_words[i]), float(deck_words[i + 1]), float(deck_words[i + 2])))
                arc = Arc(start_node=start_node, end_node=end_node, vertices=verts, feature_id=att_id)
                deck_arcs.append(arc)
            self._obs_cov.arcs = deck_arcs
        self._obs_cov.complete()

        from xms.srh.data.obstruction_data import ObstructionData
        obs_data = ObstructionData('dummy_file.nc')
        for obs_param, comp_id in zip(obs_list, ids):
            obs_data.append_obstruction_data_with_id(obs_param, comp_id)

        self._obs_do_comp = build_obstructions_component(
            obs_data.obstruction_data, self._obs_cov.uuid, self._xms_data['comp_dir'], pier_ids, deck_ids
        )

    def _obs_param_from_line(self, words):
        """Create ObstructionParam class from file data.

        Args:
            words (:obj:`list`): line split from file

        Returns:
            ObstructionParam
        """
        from xms.srh.data.par.obstruction_param import ObstructionParam
        obs = ObstructionParam()
        obs.width = float(words[2])
        obs.thickness = float(words[3])
        obs.drag = float(words[4])
        obs.porosity = float(words[5])
        obs.units = 'Feet' if words[6].lower() == 'en' else 'Meters'
        return obs

    def _build_monitor_coverage(self):
        """Build a monitor coverage if we have monitor points and/or arcs."""
        if not self._mpoint_reader and not self._hydro_reader.monitor_line_ids:
            return
        self._monitor_cov = Coverage(name='Monitor', uuid=str(uuid.uuid4()), projection=self._mesh.projection)
        next_pt_id = 1
        if self._mpoint_reader:  # Add disjoint monitor points if we have any
            monitor_pts = self._mpoint_reader.get_do_points()
            self._monitor_cov.set_points(monitor_pts)
            next_pt_id = len(monitor_pts) + 1
        arc_builder = CoverageArcBuilder(self._geom_reader.data['nodes'], next_point_id=next_pt_id)
        for monitor_line_id in self._hydro_reader.monitor_line_ids:
            monitor_line_def = self._geom_reader.data['node_strings'][monitor_line_id]
            arc_builder.add_arc(monitor_line_def[0], monitor_line_def[-1], monitor_line_def[1:-1])
        self._monitor_cov.arcs = arc_builder.arcs
        self._monitor_cov.complete()
        self._monitor_do_comp = build_monitor_component(self._monitor_cov.uuid, self._xms_data['comp_dir'])

    def _get_hy8_filename(self):
        """Gets the path to the hy8 file if it is in the simulation.

        Returns:
            (:obj:`str`): hy8 file

        """
        hy8_filename = ''
        if self._hydro_reader.bc_hy8_file:
            hydro_file_dir = os.path.dirname(self._hydro_reader.filename)
            hy8_filename = filesystem.resolve_relative_path(hydro_file_dir, self._hydro_reader.bc_hy8_file)
        return hy8_filename

    def read(self):
        """Trigger the read of the SRH-2D simulation.

        Files named by cards in the *.srhhydro file are joined onto the *.srhhydro file's directory.
        """
        self._read_hydro()
        import_dir = os.path.dirname(self._xms_data['filename'])
        for card in self._hydro_reader.file_list:
            if card.startswith('Grid'):
                geom_filename = os.path.join(import_dir, self._card_argument(card))
                self._read_geom(geom_filename)
                self._build_bc_coverage()
                att_ids = [self._bc_id_to_att_id[bc_id] for bc_id in self._hydro_reader.bcs]
                hy8_file = self._get_hy8_filename()
                self._bc_do_comp = build_bc_component(
                    self._hydro_reader.bcs, self._bc_cov.uuid, self._xms_data['comp_dir'], att_ids, hy8_file
                )
            elif card.startswith('MonitorPtFile'):  # Read monitor point file. Not dependent on any other files.
                self._mpoint_reader = MonitorPointReader()
                mpoint_filename = os.path.join(import_dir, self._card_argument(card))
                self._mpoint_reader.read(mpoint_filename)

        # Need to read the *.srhhydro, *.srhgeom, and *.srhmpoint files before building monitor and obstructions
        # coverages.
        self._build_obs_coverage()
        self._build_monitor_coverage()

        have_grid = self._geom_reader and self._geom_reader.cogrid
        for card in self._hydro_reader.file_list:  # Ensure *.srhhydro and *.srhgeom get read first
            if card.startswith('HydroMat') and have_grid:
                # Read the material file
                mat_filename = os.path.join(import_dir, self._card_argument(card))
                self._read_hydromat(mat_filename, False)
            elif card.startswith('SubsurfaceBedFile') and have_grid:
                # Read the sediment material file
                sed_mat_filename = os.path.join(import_dir, self._card_argument(card))
                self._read_hydromat(sed_mat_filename, True)

        if self._query:
            self._add_xms_data()

    def send(self):
        """Send imported data to SMS."""
        if self._query:
            self._query.send()
