"""Exports SRH simulation."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import json
import logging
import os
from pathlib import Path
import shutil
import typing

# 2. Third party modules
import orjson

# 3. Aquaveo modules
from xms.api.dmi import Query  # noqa: I202
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.core.filesystem import filesystem

# 4. Local modules
from xms.srh.components.check_runner import CheckRunner
from xms.srh.components.parameters.parameters_manager import ParametersManager
from xms.srh.components.sim_query_helper import SimQueryHelper
from xms.srh.file_io.geom_writer import GeomWriter
from xms.srh.file_io.hydro_writer import HydroWriter
from xms.srh.file_io.material_writer import MaterialWriter
from xms.srh.file_io.parameters_exporter import ParametersExporter
from xms.srh.file_io.pest_exporter import PestExporter
from xms.srh.mapping.coverage_mapper import CoverageMapper


def write(
    query: Query,
    out_dir: typing.Optional[str | Path] = None,
    logger: typing.Optional[logging.Logger] = None,
    sim_uuid: typing.Optional[str] = None,
):
    """Writes the SRH-2D simulation files.

    Args:
        query: Interprocess communication object. If None, ``FileWriter`` creates its own ``Query``.
        out_dir: The output directory. If None, the current working directory is used.
        logger: The logger. If None, this module's logger is used.
        sim_uuid: The UUID of the simulation.

    Returns:
        Whatever ``FileWriter.write()`` returns.
    """
    # Passing None through would make FileWriter write into a relative directory literally named 'None'.
    if out_dir is None:
        out_dir = os.getcwd()
    # FileWriter dereferences the logger unconditionally (including inside its error handler).
    if logger is None:
        logger = logging.getLogger(__name__)
    writer = FileWriter(query, out_dir, logger, sim_uuid)
    return writer.write()


class FileWriter:
    """Class for exporting SRH-2D."""
    def __init__(
        self, query: Query, out_dir: str | Path | None, logger: logging.Logger | None, sim_uuid: str | None
    ):
        """Constructor.

        Args:
            query (Query): The XMS interprocess communication object or None if one does not exist
            out_dir (:obj:`str` | :obj:`Path`): output directory. If None, the current working directory is used.
            logger (logging.Logger): The logger. If None, this module's logger is used.
            sim_uuid (str): The UUID of the simulation
        """
        super().__init__()
        # str(None) would produce the bogus relative directory 'None'.
        self.out_dir = os.getcwd() if out_dir is None else str(out_dir)
        self.query = query or Query()
        self.sim_query_helper = None
        self.coverage_mapper = None
        self._exporter = None
        self.project_name = None
        self.sim_component = None
        self.wse_data = None
        self.ceiling_file = None
        self.vwse_dsets = None
        # Every export step (and the error handler in write()) logs, so a logger must always exist.
        self._logger = logger if logger is not None else logging.getLogger(__name__)
        self.files_exported = []
        self._sim_uuid = sim_uuid
        self._simulation_name = None  # set in _setup_query()

    def write(self):
        """Exports the SRH-2D simulation files to the output directory.

        Any exception is caught and logged rather than propagated.
        """
        try:
            self._setup_query()
            if self._using_pest() or self._using_parameters():
                if not self._check_srh_license():  # pragma: no cover
                    return

            self.coverage_mapper.do_map()
            if self._using_parameters():
                self._do_export_with_parameters()
            else:
                self.do_export()
        except Exception as error:
            err_str = str(error)
            # Message-less exceptions (e.g. the bare RuntimeError() raised right after logging an error in
            # _setup_query / _hydro_external_files) are not reported a second time.
            if err_str:
                self._logger.exception(f'Error exporting simulation: {err_str}')

    def get_sediment_enabled(self) -> bool:
        """Returns whether sediment is enabled."""
        return self.sim_component.data.enable_sediment

    def _check_srh_license(self):  # pragma: no cover
        """Check for the srh license.

        Returns:
            (:obj:`bool`): True if no problems with the license
        """
        if not hasattr(self.query, 'cp'):
            return True  # pragma no cover - no good way to reach

        if not self.query.cp.get('SRH Hydro Interface', False):
            msg = 'Unable to save simulation. The simulation is an "Advanced Simulation" which ' \
                  'requires the "SRH Hydro Interface" in the SMS license. Aborting.'
            self._logger.error(msg)
            return False
        return True

    def _setup_query(self):
        """Queries XMS for the simulation data, resolves the ceiling file, and builds the coverage mapper.

        Raises:
            RuntimeError: If the simulation has no mesh or UGrid (the error is logged before raising).
        """
        self.project_name = os.path.splitext(os.path.basename(self.query.xms_project_path))[0]
        self.sim_query_helper = SimQueryHelper(self.query, at_sim=True, sim_uuid=self._sim_uuid)
        self.sim_query_helper.get_sim_data()
        self._simulation_name = self.sim_query_helper.sim_tree_item.name
        self.sim_component = self.sim_query_helper.sim_component
        self.wse_data = self.sim_query_helper.wse_dataset

        sim_data = self.sim_component.data
        if sim_data.hydro.initial_condition == 'Water Surface Elevation Dataset' and self.wse_data is None:
            msg = 'No data set specified for initial condition "Water Surface Elevation Dataset". ' \
                  'A data set must be selected in the model control dialog.'
            self._logger.error(msg)

        ceiling_file = os.path.join(os.path.dirname(self.sim_component.main_file), 'ceiling.srhceiling')
        if self.sim_component.data.hydro.use_pressure_ceiling:
            if self.sim_query_helper.structures_3d:  # pragma: no cover
                # 3D structures generate their own ceiling file, so the user-defined one is turned off.
                msg = '3D Structures are included in the simulation. A ceiling file will be generated by ' \
                      'the 3D Structures and will override the defined ceiling file.'
                self._logger.warning(msg)
                self.sim_component.data.hydro.use_pressure_ceiling = False
                self.sim_component.save_sim_data_to_main_file()
            elif not os.path.isfile(ceiling_file):
                msg = 'Error: model control set to use pressure ceiling but no ceiling file defined. ' \
                      'Create the pressure ceiling file in the model control dialog.'
                self._logger.error(msg)
            else:
                self.ceiling_file = ceiling_file

        if self.sim_query_helper.structures_3d:
            self.sim_query_helper.get_3d_structure_data()
            if os.path.isfile(ceiling_file):
                self.ceiling_file = ceiling_file

        self.coverage_mapper = CoverageMapper(self.sim_query_helper, generate_snap=False)
        self.coverage_mapper.ceiling_file = self.ceiling_file
        check_sim = CheckRunner(self.sim_query_helper)
        check_sim.run_check()
        if self.sim_query_helper.co_grid is None:
            self._logger.error('No mesh or ugrid in the SRH simulation.')
            raise RuntimeError()

    def do_export(self):
        """Export the simulation."""
        if XmEnv.xms_environ_running_tests() != 'TRUE':  # can't reach during testing
            self._logger.info(f'Writing SRH files to this directory: \n\n{self.out_dir}\n')  # pragma no cover
        self.files_exported = []
        self.export_geom()
        self.export_materials()
        self.export_sed_materials()
        self.export_monitor_points()
        self.export_wse_dataset()
        self.export_pressure_ceiling()
        self.export_hydro()
        self.export_srh_post(using_parameters=False)
        self._export_pest_files()

    def _do_export_with_parameters(self):
        """Exports the simulation(s) when parameters are in use."""
        param_exporter = ParametersExporter(self.vwse_dsets)
        param_exporter.export(self)

    def _export_pest_files(self):
        """Exports the PEST files when parameters are used with a run type other than 'Scenarios'."""
        if self._using_pest():
            pest_exporter = PestExporter()
            pest_exporter.export(self)

    def _read_param_data(self):
        """Reads the parameter data for this simulation (re-read on every call; nothing is cached).

        Returns:
            The result of ``ParametersManager.read_parameter_file`` for the simulation's main file.
        """
        return ParametersManager.read_parameter_file(self.sim_component.main_file)

    def _using_parameters(self):
        """Returns True if parameters are enabled with the 'Scenarios' run type.

        Returns:
            (:obj:`bool`): See description.
        """
        param_data = self._read_param_data()
        return bool(param_data and param_data['use_parameters'] and param_data['run_type'] == 'Scenarios')

    def _using_pest(self):
        """Returns True if parameters are enabled with any run type other than 'Scenarios' (i.e. PEST).

        Returns:
            (:obj:`bool`): See description.
        """
        param_data = self._read_param_data()
        return bool(param_data and param_data['use_parameters'] and param_data['run_type'] != 'Scenarios')

    def export_geom(self):
        """Exports the srhgeom file.

        Node strings are numbered in this order: monitor lines, boundary condition arcs, then an
        upstream/downstream pair per 3D structure. The ids given to each 3D structure are stored back on its
        dict as 'up_nodestring_id' / 'down_nodestring_id'.
        """
        self._logger.info('Writing SRH-2D geometry file.')
        co_grid = self.coverage_mapper.co_grid
        base_name = f'{self.project_name}.srhgeom'
        self.files_exported.append(f'Grid "{base_name}"')
        file_name = os.path.join(self.out_dir, base_name)

        node_strings = []
        if self.coverage_mapper.monitor_arc_ids_to_grid_cell_ids:
            node_strings = sorted(self.coverage_mapper.monitor_arc_ids_to_grid_cell_ids.items())
        if self.coverage_mapper.bc_arc_id_to_grid_ids:
            bc_node_strings = sorted(self.coverage_mapper.bc_arc_id_to_grid_ids.items())
            if node_strings:
                # Renumber BC node strings so they continue after the monitor lines.
                offset = len(node_strings)
                node_strings.extend(
                    (offset + idx, grid_ids) for idx, (_, grid_ids) in enumerate(bc_node_strings, start=1)
                )
            else:
                # NOTE(review): arc ids are used as node string ids here; this assumes they run 1..n — confirm.
                node_strings = bc_node_strings

        cnt = len(node_strings)
        for struct in self.coverage_mapper.bc_3d_structures:
            cnt += 1
            node_strings.append((cnt, struct['up_snap_ids']))
            struct['up_nodestring_id'] = cnt
            cnt += 1
            node_strings.append((cnt, struct['down_snap_ids']))
            struct['down_nodestring_id'] = cnt

        if len(node_strings) > 100:  # pragma: no cover
            msg = 'SRH supports a maximum of 100 node strings (monitor lines, BC lines). ' \
                  f'This simulation has {len(node_strings)} node strings defined.'
            self._logger.warning(msg)

        writer = GeomWriter(
            file_name=file_name,
            ugrid=co_grid.ugrid,
            grid_name=self.coverage_mapper.grid_name,
            grid_units=self.coverage_mapper.grid_units,
            node_strings=node_strings
        )
        writer.write()
        self._logger.info('Success writing SRH-2D geometry file.')

    @staticmethod
    def _index_grid_cells(comp_ids, comp_id_to_grid_cell_ids):
        """Re-keys a component id -> grid cell ids mapping by each material's position in ``comp_ids``.

        Args:
            comp_ids: Material component ids, in export order.
            comp_id_to_grid_cell_ids (dict): Grid cell ids keyed by material component id.

        Returns:
            (dict): Grid cell ids keyed by material index; materials with no cells map to an empty list.
        """
        return {idx: comp_id_to_grid_cell_ids.get(comp_id, []) for idx, comp_id in enumerate(comp_ids)}

    def export_materials(self):
        """Exports the SRH material file."""
        if self.coverage_mapper.material_names is None:
            return  # pragma no cover
        mat_names = self.coverage_mapper.material_names
        # change from component ids to indexes for export
        mat_grid_cells = self._index_grid_cells(
            self.coverage_mapper.material_comp_ids, self.coverage_mapper.material_comp_id_to_grid_cell_ids
        )
        self._logger.info('Writing SRH-2D material file.')
        base_name = f'{self.project_name}.srhmat'
        file_name = os.path.join(self.out_dir, base_name)
        self.files_exported.append(f'HydroMat "{base_name}"')
        writer = MaterialWriter(file_name=file_name, mat_names=mat_names, mat_grid_cells=mat_grid_cells)
        writer.write()
        self._logger.info('Success writing SRH-2D material file.')

    def export_sed_materials(self):
        """Exports the SRH sediment material file."""
        mat_names = self.coverage_mapper.sed_material_names
        if not mat_names:
            return
        # change from component ids to indexes for export
        mat_grid_cells = self._index_grid_cells(
            self.coverage_mapper.sed_material_comp_ids, self.coverage_mapper.sed_material_comp_id_to_grid_cell_ids
        )
        self._logger.info('Writing SRH-2D sediment material file.')
        base_name = f'{self.project_name}.srhsedmat'
        file_name = os.path.join(self.out_dir, base_name)
        self.files_exported.append(f'SubsurfaceBedFile "{base_name}" SRHMAT')
        writer = MaterialWriter(file_name=file_name, mat_names=mat_names, mat_grid_cells=mat_grid_cells)
        writer.write()
        self._logger.info('Success writing SRH-2D sediment material file.')

    def export_monitor_points(self):
        """Exports the SRH monitor points."""
        points = self.coverage_mapper.monitor_points
        if not points:
            return
        self._logger.info('Writing SRH-2D monitor point file.')
        base_name = f'{self.project_name}.srhmpoint'
        file_name = os.path.join(self.out_dir, base_name)
        self.files_exported.append(f'MonitorPtFile "{base_name}"')
        with open(file_name, 'w') as file:
            file.write('SRHMON 30\n')
            file.write(f'NUMMONITORPTS {len(points)}\n')
            file.writelines(f'monitorpt {pt[0]} {pt[1]} {pt[2]}\n' for pt in points)
        self._logger.info('Success writing SRH-2D monitor point file.')

    def export_wse_dataset(self):
        """Exports the Water Surface Elevation (WSE) data set as 'wse.2dm' (one ND line per value, 1-based ids)."""
        if self.wse_data is None:
            return
        self._logger.info('Writing SRH-2D water surface elevation data set file.')
        filename = os.path.join(self.out_dir, 'wse.2dm')
        with open(filename, 'w') as file:
            file.write('MESH2D\n')
            grid_name = self.coverage_mapper.grid_name
            file.write(f'MESHNAME "{grid_name}"\n')
            file.writelines(f'ND {node_id} 0.0 0.0 {wse}\n' for node_id, wse in enumerate(self.wse_data, start=1))
        self._logger.info('Success writing SRH-2D water surface elevation data set file.')

    def export_pressure_ceiling(self):
        """Copies the pressure ceiling file into the output directory.

        Without 3D structures, the mesh signature saved next to the ceiling file is first compared with the
        current mesh; a mismatch is logged as an error but the file is still copied.
        """
        if not self.ceiling_file:
            return
        if not self.sim_query_helper.structures_3d:
            self._check_ceiling_mesh()
        out_file = os.path.join(self.out_dir, 'ceiling.srhceiling')
        filesystem.copyfile(self.ceiling_file, out_file)
        self.files_exported.append('PressureDatasetFile "ceiling.srhceiling"')

    def _check_ceiling_mesh(self):
        """Logs an error if 'ceiling_mesh.json' (when present) does not match the current simulation mesh."""
        ceiling_mesh = os.path.join(os.path.dirname(self.ceiling_file), 'ceiling_mesh.json')
        if not os.path.isfile(ceiling_mesh):
            return
        with open(ceiling_mesh, 'r') as f:
            saved_mesh_data = json.loads(f.readline())
        ug = self.sim_query_helper.co_grid.ugrid
        current_mesh_data = {
            'mesh_uuid': self.sim_query_helper.grid_uuid,
            'num_nodes': ug.point_count,
            'num_cells': ug.cell_count,
            'mesh_crc32': self.sim_query_helper.co_grid_file_crc32
        }
        if saved_mesh_data != current_mesh_data:
            msg = 'Pressure ceiling file was generated with a different mesh than the current simulation ' \
                'mesh. The ceiling file must be regenerated in the Model Control dialog.'
            self._logger.error(msg)

    def export_hydro(self, is_template=False, run_name=None):
        """Exports the SRH hydro file.

        Args:
            is_template (:obj:`bool`): flag used with PEST runs; writes '<project>.srhhydro.tpl' instead
            run_name (:obj:`str`): name used with scenario runs
        """
        template_note = ' PEST template' if is_template else ''
        self._logger.info(f'Writing SRH-2D hydro file{template_note}.')
        extension = '.tpl' if is_template else ''
        file_name = os.path.join(self.out_dir, f'{self.project_name}.srhhydro{extension}')
        writer = HydroWriter(
            file_name=file_name,
            model_control=self.sim_component.data,
            file_list=self.files_exported,
            logger=self._logger,
            main_file=self.sim_component.main_file,
            is_template=is_template,
            grid_units=self.coverage_mapper.grid_units
        )
        writer.materials_manning = self.coverage_mapper.material_mannings
        writer.materials_bed_shear = self._materials_bed_shear()
        writer.sed_material_data = self.coverage_mapper.sed_material_data
        if self.get_sediment_enabled() and writer.sed_material_data is None:
            msg = 'Sediment is enabled but no "Sediment Materials" coverage is included. SRH will not run.'
            self._logger.error(msg)
        writer.obstruction_decks = self.coverage_mapper.obstructions_decks
        writer.obstruction_piers = self.coverage_mapper.obstructions_piers
        writer.bc_data = self.coverage_mapper.bc_arc_id_to_bc_param
        writer.bc_arc_id_to_comp_id = self.coverage_mapper.bc_arc_id_to_comp_id
        writer.bc_bc_id_to_structure = self.coverage_mapper.bc_id_to_structures
        writer.bc_arc_id_to_bc_id = self.coverage_mapper.bc_arc_id_to_bc_id
        writer.bc_arc_id_to_grid_pts = self.coverage_mapper.bc_arc_id_to_grid_pts
        writer.bc_arc_id_to_node_string_length = self.coverage_mapper.bc_arc_id_to_node_string_lengths
        writer.bc_3d_structures = self.coverage_mapper.bc_3d_structures
        writer.monitor_3d_structures = self.coverage_mapper.structures_3d_monitor
        writer.bc_comp_file = self.coverage_mapper.bc_component_file
        if self.coverage_mapper.monitor_arc_ids_to_grid_cell_ids:
            writer.num_monitor_lines = len(self.coverage_mapper.monitor_arc_ids_to_grid_cell_ids)
        self._hydro_external_files(writer)
        writer.write(run_name)
        self._logger.info(f'Success writing SRH-2D hydro file{template_note}.')

    def _materials_bed_shear(self):
        """Get list of material bed shear values.

        Returns:
            (:obj:`list` | None): None when sediment is disabled; otherwise one value per material, taken from
            the sediment shear partitioning percentage or scaled D90 factor depending on the selected option.
        """
        if not self.get_sediment_enabled():
            return None

        mannings = self.coverage_mapper.material_mannings
        # Avoid a cryptic TypeError from len(None); treat missing mannings as zero materials.
        num_materials = 0 if mannings is None else len(mannings)
        sediment = self.sim_component.data.sediment
        if sediment.shear_partitioning_option == 'Percentage':
            val = sediment.shear_partitioning_percentage
        else:
            val = sediment.shear_partitioning_scaled_d90_factor
        return [val] * num_materials

    @staticmethod
    def _is_within_folder(path, folder):
        """Returns True if ``path`` is ``folder`` or lies inside it (case/separator-insensitive per the OS).

        Args:
            path (:obj:`str`): The path to test.
            folder (:obj:`str`): The candidate parent folder.

        Returns:
            (:obj:`bool`): See description.
        """
        norm_path = os.path.normcase(os.path.abspath(path))
        norm_folder = os.path.normcase(os.path.abspath(folder))
        try:
            return os.path.commonpath([norm_path, norm_folder]) == norm_folder
        except ValueError:  # e.g. paths on different drives on Windows
            return False

    def _get_hy8_dest_path(self):
        """Gets the folder the HY-8 culvert file is copied to.

        Returns:
            (:obj:`str`): The project's '<project>_models/SRH-2D' folder when the output directory is inside it,
            otherwise the output directory itself.
        """
        proj_folder = os.path.normpath(f'{os.path.splitext(XmEnv.xms_environ_project_path())[0]}_models')
        srh_folder = os.path.join(proj_folder, 'SRH-2D')
        # A real path-containment test; a substring test would also match siblings such as 'SRH-2D_old'.
        if not self._is_within_folder(self.out_dir, srh_folder):
            return self.out_dir
        return srh_folder

    def _hydro_external_files(self, writer):
        """Copy or write files that the hydro file references.

        Args:
            writer (:obj:`HydroWriter`): the srh hydro writer class

        Raises:
            RuntimeError: If the initial condition is 'Restart File' but no restart file is set (logged first).
        """
        bc_params = self.coverage_mapper.bc_arc_id_to_bc_param
        if bc_params and any(bc.bc_type == 'Culvert HY-8' for bc in bc_params.values()):
            new_hy8_file = os.path.join(self._get_hy8_dest_path(), 'culvert.hy8')
            # use copy2 to preserve file attributes like 'modification time'
            shutil.copy2(self.coverage_mapper.bc_hy8_file, new_hy8_file)
            writer.bc_hy8_file = new_hy8_file

        hydro = self.sim_component.data.hydro
        if hydro.initial_condition == 'Restart File':
            sim_dir = os.path.dirname(self.sim_component.main_file)
            if not hydro.restart_file:
                self._logger.error('Missing restart file. Specify a restart file and try again.')
                raise RuntimeError()
            rst_file = os.path.join(sim_dir, hydro.restart_file)
            out_rst_file = os.path.join(self.out_dir, f'{hydro.case_name}_RST1.dat')
            filesystem.copyfile(rst_file, out_rst_file)
            self._restart_with_dip()

    def _restart_with_dip(self):
        """Writes '<case>_DIP.dat' (setting irest = 1) and an empty '<case>_RES.dat' for a restart run."""
        case = self.sim_component.data.hydro.case_name
        dip_file = os.path.join(self.out_dir, f'{case}_DIP.dat')
        with open(dip_file, 'w') as f:
            f.write('$DATAC\n')
            f.write(' irest = 1\n')
            f.write('$ENDC\n')
        res_file = os.path.join(self.out_dir, f'{case}_RES.dat')
        # Opening in 'w' mode creates (or truncates) the file.
        with open(res_file, 'w'):
            pass

    def export_srh_post(self, using_parameters):
        """Exports SRH post settings to 'srh_post.json'.

        Args:
            using_parameters (:obj:`bool`): True if simulation uses parameters.
        """
        self._logger.info('Writing srh_post file.')
        # Node string ids of BC arcs that are not walls. BC node strings are numbered after the monitor lines,
        # in sorted arc id order (matching the ordering used in export_geom).
        arc_ids = []
        first_id = 1
        if self.coverage_mapper.monitor_arc_ids_to_grid_cell_ids:
            first_id += len(self.coverage_mapper.monitor_arc_ids_to_grid_cell_ids)
        if self.coverage_mapper.bc_arc_id_to_grid_ids is not None:
            bc_params = self.coverage_mapper.bc_arc_id_to_bc_param
            for node_string_id, arc_id in enumerate(
                sorted(self.coverage_mapper.bc_arc_id_to_grid_ids), start=first_id
            ):
                if bc_params[arc_id].bc_type != 'Wall (no-slip boundary)':
                    arc_ids.append(node_string_id)

        # Save some useful things to a srh_post.json file so we have them later
        if self.sim_component.data.output.output_units == 'English':
            length, force = 'ft', 'lb'
        else:
            length, force = 'm', 'N'

        dsnames = [
            f'B_Stress_{force}_p_{length}2', 'Froude', f'Vel_Mag_{length}_p_s', 'Velocity', f'Water_Depth_{length}',
            f'Water_Elev_{length}'
        ]
        # NOTE: a f'Water_Elev+Pressure_{length}' dataset could be added for 'Pressure' BCs (not enabled).
        if self.get_sediment_enabled():
            dsnames.extend([f'Bed_Elev_{length}', 'CONC_T_ppm', 'D50_mm', f'ERO_DEP_{length}'])

        case_name = self.sim_component.data.hydro.case_name
        json_dict = {
            'OUT_FILE': f'{case_name}_XMDFC.h5',
            'SRH_POST_OUT_FILE': f'{case_name}_XMDF.h5',
            'GEOM_FILE': f'{self.project_name}.srhgeom',
            'ARC_IDS': arc_ids,
            'GRID_UUID': self.coverage_mapper.grid_uuid,
            'LENGTH_UNITS': self.coverage_mapper.grid_units,
            'DATASET_NAMES': dsnames
        }
        file_name = os.path.join(self.out_dir, 'srh_post.json')
        with open(file_name, 'wb') as f:
            f.write(orjson.dumps(json_dict))

        self._logger.info('Success writing srh_post file.')
