"""Module for importing a SCHISM simulation with feedback."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
import binascii
import os
from typing import Sequence
import uuid

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.constraint import UGrid2d
from xms.data_objects.parameters import Arc, Component, Coverage, Point, Projection, Simulation, UGrid as DoUGrid
from xms.datasets.dataset_writer import DatasetWriter
from xms.gmi.components.utils import new_component_dir
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import ExpectedError, FeedbackThread

# 4. Local modules
from xms.schism.components.mapped_upwind_solver_coverage_component import MappedUpwindSolverCoverageComponent
from xms.schism.data.coverage_data import CoverageData
from xms.schism.data.mapped_bc_data import MappedBcData
from xms.schism.data.mapped_upwind_solver_coverage_data import MappedUpwindSolverCoverageData
from xms.schism.data.model import get_model, needed_files, parameter_for_file
from xms.schism.data.sim_data import SimData
from xms.schism.external.mapped_tidal_data import MappedTidalData
from xms.schism.external.project import GEOGRAPHIC_WKT, looks_geographic
from xms.schism.feedback.display_options_helper import DisplayOptionsHelper
from xms.schism.file_io import read_bctides, read_fort14, read_namelist, read_vgrid


class CoverageBuilder:
    """Class for building a coverage using node IDs in a UGrid rather than locations."""
    def __init__(self, name: str, ugrid: UGrid2d):
        """
        Initialize the builder.

        Args:
            name: Name to assign the coverage. Appears in the UI.
            ugrid: Reference geometry for building arcs from.
        """
        self._ugrid = ugrid
        # Maps UGrid node ID -> Point, so a node used by several arcs is always the same Point (same feature ID).
        self._built_points = {}
        self._arcs = []
        self._coverage = Coverage(name=name, uuid=str(uuid.uuid4()))

    def add_arc(self, node_ids: Sequence[int]) -> int:
        """
        Add an arc to the coverage.

        Points for node IDs already used by a previous arc are reused rather than duplicated.

        Args:
            node_ids: Sequence of 1-based IDs of points in UGrid representing the arc, in order along the arc.
                Must contain at least two IDs (the start and end nodes).

        Returns:
            ID of the new arc.

        Raises:
            ValueError: If fewer than two node IDs were given.
        """
        # Materialize once: the IDs are iterated twice below.
        node_ids = list(node_ids)
        if len(node_ids) < 2:
            raise ValueError(f'An arc needs at least 2 nodes, but got {len(node_ids)}.')

        # UGrid point indices are 0-based, while node IDs are 1-based.
        indices = [node_id - 1 for node_id in node_ids]
        locations = self._ugrid.ugrid.get_points_locations(indices)

        points = []
        for node_id, location in zip(node_ids, locations):
            point = self._built_points.get(node_id)
            if point is None:
                feature_id = len(self._built_points) + 1
                point = Point(x=location[0], y=location[1], z=location[2], feature_id=feature_id)
                self._built_points[node_id] = point
            points.append(point)

        # First and last points become the arc's nodes; everything in between is a vertex.
        start, *vertices, end = points
        arc_id = len(self._arcs) + 1
        arc = Arc(feature_id=arc_id, start_node=start, end_node=end, vertices=vertices)
        self._arcs.append(arc)
        return arc_id

    def complete(self) -> Coverage:
        """
        Mark the coverage as complete.

        Assigns all arcs added so far to the coverage and calls its ``complete()`` method.

        Returns:
            The completed coverage.
        """
        self._coverage.arcs = self._arcs
        self._coverage.complete()
        return self._coverage


class ImportSimulationRunner(FeedbackThread):
    """Class for importing a SCHISM simulation in a worker thread."""
    def __init__(self, query: Query):
        """
        Constructor.

        Args:
            query: Query used to locate the file being imported (``read_file``) and to send built data to XMS.
        """
        super().__init__(query)
        self.sim: Simulation
        self.global_parameters = get_model().global_parameters
        self.ugrid: UGrid2d
        self.ugrid_uuid: str
        self.upwind_solver_uuid: str = ''
        self.projection = Projection()

        # Boundary objects read from hgrid.gr3. Each exposes ``nodes``; closed ones also expose ``boundary_type``.
        self._open_boundaries: Sequence
        self._closed_boundaries: Sequence
        # Node ID lists of the open boundaries, one per boundary, in file order.
        self._open_boundary_strings: Sequence[Sequence[int]]
        self._domain_hash: str
        self._sim_uuid: str

        self.display_text = {
            'title': 'SCHISM Import Simulation',
            'working_prompt': 'Importing SCHISM simulation files. Please wait...',
            'warning_prompt': 'Warning(s) encountered while importing simulation. Review log output for more details.',
            'error_prompt': 'Error(s) encountered while importing simulation. Review log output for more details.',
            'success_prompt': 'Successfully imported simulation',
            'note': '',
            'auto_load': 'Close this dialog automatically when importing is finished.'
        }

    def _read_namelist(self):
        """Read the param.nml file into the global parameters."""
        read_namelist(self._query.read_file, self.global_parameters)

    def _read_hgrid_gr3(self):
        """
        Read hgrid.gr3 and send it to XMS as a UGrid.

        Also chooses the projection, records the open/closed boundaries, and computes the domain hash stored in
        the mapped components.

        Raises:
            ExpectedError: If hgrid.gr3 could not be read.
        """
        self._log.info('Reading hgrid.gr3...')
        try:
            fort14_file = read_fort14('hgrid.gr3', read_boundaries=True)
        except IOError as err:
            raise ExpectedError('Could not read file: hgrid.gr3') from err

        self.ugrid = fort14_file.ugrid
        if looks_geographic(self.ugrid):
            self.projection = Projection(wkt=GEOGRAPHIC_WKT)
        else:
            self.projection = Projection(horizontal_units='METERS')

        # The file's z values are depths; the UGrid gets their negation as point elevations.
        depths = self.ugrid.ugrid.locations[:, 2]
        self.ugrid._instance.SetPointElevations(-depths)  # TODO: Put in xmsconstraint so private access isn't needed
        self.ugrid_uuid = self.ugrid.uuid

        self._open_boundaries = fort14_file.open_boundaries
        self._open_boundary_strings = [open_boundary.nodes for open_boundary in fort14_file.open_boundaries]
        self._closed_boundaries = fort14_file.closed_boundaries

        grid_path = os.path.join(XmEnv.xms_environ_temp_directory(), f'{self.ugrid_uuid}.xmc')
        self.ugrid.write_to_file(grid_path)
        # CRC32 of the written grid file; stored on the mapped components (presumably to detect domain changes).
        with open(grid_path, 'rb') as f:
            self._domain_hash = hex(binascii.crc32(f.read()))

        do_ugrid = DoUGrid(grid_path, name='hgrid', uuid=self.ugrid_uuid, projection=self.projection)
        self._query.add_ugrid(do_ugrid)

    def _read_bctides_in(self):
        """
        Read bctides.in for the open boundaries found in hgrid.gr3.

        Raises:
            ExpectedError: If bctides.in could not be read.
        """
        self._log.info('Reading bctides.in...')

        try:
            self.bctides = read_bctides('bctides.in', self._open_boundary_strings)
        except IOError as err:
            raise ExpectedError('Could not read file: bctides.in') from err

        self.global_parameters.group('other').parameter('tip_dp').value = self.bctides.cutoff_depth

    def _read_vgrid_in(self):
        """
        Read the vgrid.in file into the 'other' parameter group.

        Raises:
            ExpectedError: If vgrid.in could not be read.
        """
        try:
            vgrid = read_vgrid('vgrid.in')
        except IOError as err:
            raise ExpectedError('Could not read file: vgrid.in') from err

        other_group = self.global_parameters.group('other')
        other_group.parameter('hc').value = vgrid.hc
        other_group.parameter('theta_b').value = vgrid.theta_b
        other_group.parameter('theta_f').value = vgrid.theta_f

        # Level parameters are stored as a table with one single-column row per level.
        other_group.parameter('s_levels').value = [[level] for level in vgrid.s_levels]
        other_group.parameter('z_levels').value = [[level] for level in vgrid.z_levels]

    def _read_optional_files(self):
        """
        Read every optional file the current parameters require.

        Raises:
            ExpectedError: If the set of needed files could not be determined.
        """
        try:
            files = needed_files(self.global_parameters)
        except ValueError as err:  # pragma: nocover
            raise ExpectedError(str(err)) from err

        for file in sorted(files):  # Sorting them keeps tests stable.
            self._read_optional_file(file)

    def _read_optional_file(self, file: str):
        """
        Read an optional file.

        Args:
            file: Name of the file to read.
        """
        if file.endswith(('.gr3', '.ic')):
            self._read_gr3(file)
        elif file in ('sflux/sflux_air_1.*.nc', 'sflux/sflux_rad_1.*.nc', 'sflux/sflux_prc_1.*.nc'):
            self._log.warning('Files in sflux/* were not read.')
        elif file == 'tvd.prop':
            self._read_tvd_prop()
        else:  # pragma: nocover
            raise AssertionError(f'Unsupported file: {file}')  # needed_files should only give files this supports

    def _read_tvd_prop(self):
        """
        Read tvd.prop and send it to XMS as a mapped upwind solver coverage component.

        Raises:
            ExpectedError: If tvd.prop could not be read.
        """
        try:
            # Only the second column is used (presumably one solver flag per element — TODO confirm).
            # ndmin=1 keeps a single-line file from producing a 0-d array.
            mapping = np.loadtxt('tvd.prop', dtype=int, usecols=1, ndmin=1)
        except IOError as err:
            raise ExpectedError('Could not read file: tvd.prop') from err

        data = MappedUpwindSolverCoverageData()
        data.solver = mapping
        data.commit()

        helper = DisplayOptionsHelper(data.main_file)
        helper.add_feature_types(TargetType.polygon, ['Upwind', 'Higher order'])
        helper.draw_ugrid(self.ugrid.ugrid, mapping)

        component = MappedUpwindSolverCoverageComponent(data.main_file)
        self._mapped_upwind_solver_coverage = component

        do_component = Component(
            main_file=component.main_file,
            module_name=component.module_name,
            class_name=component.class_name,
            comp_uuid=component.uuid
        )

        self.upwind_solver_uuid = data.uuid
        self._query.add_component(do_component=do_component)

    def _create_sim(self):
        """Create the simulation and its component, and link the mesh (and upwind solver, if any) to it."""
        sim_data = SimData()
        sim_data.global_values = self.global_parameters.extract_values()
        sim_data.commit()

        sim = Simulation(model='SCHISM', sim_uuid=str(uuid.uuid4()), name='SCHISM Simulation')
        do_sim_component = Component(
            name=sim.name,
            comp_uuid=sim_data.uuid,
            main_file=str(sim_data.main_file),
            model_name='SCHISM',
            unique_name='SimComponent'
        )
        self._query.add_simulation(sim, [do_sim_component])

        self._query.link_item(sim.uuid, self.ugrid_uuid)
        if self.upwind_solver_uuid:
            self._query.link_item(sim.uuid, self.upwind_solver_uuid)

        self.sim = sim
        self._sim_uuid = sim.uuid

    def _build_coverage(self):
        """Build the unmapped boundary condition coverage with one 'Open' arc per open boundary."""
        builder = CoverageBuilder(name='SCHISM BCs', ugrid=self.ugrid)
        for open_boundary in self._open_boundary_strings:
            builder.add_arc(open_boundary)
        coverage = builder.complete()

        main_file = new_component_dir() / 'schism_coverage.nc'
        data = CoverageData(main_file)
        data.add_features(TargetType.arc, self.bctides.values, ['open'] * len(self.bctides.values))
        data.commit()

        display_options = DisplayOptionsHelper(data.main_file)
        display_options.add_feature_types(TargetType.arc, ['Open'])

        all_ids = [arc.id for arc in coverage.arcs]
        display_options.add_feature_ids(TargetType.arc, all_ids)
        display_options.map_features_to_type(TargetType.arc, all_ids, 'Open')

        do_comp = Component(
            name='SCHISM BC',
            comp_uuid=data.uuid,
            main_file=str(main_file),
            model_name='SCHISM',
            unique_name='CoverageComponent'
        )

        self._query.add_coverage(coverage, 'SCHISM', 'Boundary Conditions', [do_comp])

    def _build_mapped_coverage(self):
        """Build the mapped boundary condition component and link it to the simulation."""
        main_file = new_component_dir() / 'mapped_bc.nc'
        mapped_data = MappedBcData(main_file)

        for boundary, values in zip(self._open_boundary_strings, self.bctides.values):
            mapped_data.add_open_arc(boundary, values)

        for boundary in self._closed_boundaries:
            mapped_data.add_closed_arc(boundary.nodes, boundary.boundary_type)

        mapped_data.domain_hash = self._domain_hash
        mapped_data.commit()

        # One polyline per open boundary; node IDs are 1-based, UGrid point indices are 0-based.
        lines = []
        for boundary in self._open_boundaries:
            indices = [node_id - 1 for node_id in boundary.nodes]
            lines.append(self.ugrid.ugrid.get_points_locations(indices))

        helper = DisplayOptionsHelper(mapped_data.main_file)
        helper.add_feature_types(TargetType.arc, ['Open'])
        helper.draw_lines('Open', lines)

        do_component = Component(
            main_file=str(mapped_data.main_file),
            name='Mapped coverage',
            model_name='SCHISM',
            unique_name='MappedBcComponent',
            comp_uuid=mapped_data.uuid
        )

        self._query.add_component(do_component=do_component)
        self._query.link_item(taker_uuid=self._sim_uuid, taken_uuid=mapped_data.uuid)

    def _build_tides(self):
        """Build the mapped tides component from bctides.in and link it to the simulation."""
        main_file = new_component_dir() / 'schism_mapped_tides.nc'
        tide_data = MappedTidalData(main_file)
        tide_data.domain_hash = self._domain_hash

        if self.bctides.elevation:
            tide_data.elevation = self.bctides.elevation
        if self.bctides.velocity:
            tide_data.velocity = self.bctides.velocity
        if self.bctides.forcing_frequencies:
            tide_data.properties = self.bctides.forcing_frequencies
        tide_data.commit()

        do_component = Component(
            main_file=str(main_file),
            name='SCHISM tides (applied)',
            model_name='SCHISM',
            unique_name='MappedTidalComponent',
            comp_uuid=tide_data.uuid
        )

        self._query.add_component(do_component=do_component)
        # Link by the data's UUID (the component's comp_uuid), the same way _build_mapped_coverage does.
        self._query.link_item(taker_uuid=self._sim_uuid, taken_uuid=tide_data.uuid)

    def _read_gr3(self, file_name: str):
        """
        Read a .gr3 file as a dataset on the mesh and store its UUID in the matching parameter.

        Args:
            file_name: File to read.

        Raises:
            ExpectedError: If the file could not be read.
        """
        group_name, parameter_name = parameter_for_file(file_name)

        self._log.info(f'Reading {file_name}...')
        # Strip only the final extension, so e.g. 'a.b.gr3' becomes 'a.b'.
        dataset_name = os.path.splitext(file_name)[0]
        dataset_uuid = str(uuid.uuid4())
        try:
            fort14_file = read_fort14(file_name, read_boundaries=False)
        except IOError as err:
            raise ExpectedError(f'Could not read file: {file_name}') from err

        h5_file = os.path.join(XmEnv.xms_environ_temp_directory(), f'{dataset_uuid}.h5')
        writer = DatasetWriter(
            h5_filename=h5_file,
            name=dataset_name,
            geom_uuid=self.ugrid_uuid,
            dset_uuid=dataset_uuid,
        )
        # A single time step at t=0.
        writer.write_xmdf_dataset([0.0], [fort14_file.dataset])
        self._query.add_dataset(writer)

        self.global_parameters.group(group_name).parameter(parameter_name).value = dataset_uuid

    def _run(self):
        """Run the thread."""
        # Most import code assumes we're in the same directory as the param.nml file.
        # SMS ensures this when it runs, but the testing code doesn't.
        # abspath avoids os.chdir('') when read_file is a bare file name.
        os.chdir(os.path.dirname(os.path.abspath(self._query.read_file)))

        self._read_namelist()  # Most read functions need information from the namelist, so this needs to be first.
        self._read_hgrid_gr3()  # Most *.gr3 files add datasets to the mesh this creates
        self._read_bctides_in()
        self._read_vgrid_in()

        self._read_optional_files()

        self._create_sim()
        self._build_coverage()
        self._build_mapped_coverage()
        self._build_tides()

        self._query.send()
