"""SimQueryHelper class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import binascii
import json
import logging
import os
import shutil

# 2. Third party modules
import shapely
from shapely.geometry import LineString, Point, Polygon

# 3. Aquaveo modules
import xms.api._xmsapi.dmi as xmd
from xms.api.tree import tree_util
from xms.bridge import structure_util as su
from xms.bridge.calc.footprint_calc import FootprintCalculator
from xms.bridge.dmi.xms_data import XmsData as StructureXmsData
from xms.bridge.structure_component import StructureComponent
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import FilterLocation
from xms.grid.ugrid.ugrid import UGrid
from xms.grid.ugrid.ugrid_utils import read_ugrid_from_ascii_file
from xms.guipy.dialogs import treeitem_selector_datasets

# 4. Local modules
from xms.srh.bridges.ceiling_generator import CeilingGenerator
from xms.srh.components.bc_component import BcComponent
from xms.srh.components.material_component import MaterialComponent
from xms.srh.components.monitor_component import MonitorComponent
from xms.srh.components.obstruction_component import ObstructionComponent
from xms.srh.components.sed_material_component import SedMaterialComponent
from xms.srh.data.par.bc_data_param import BcDataParam


def _swap_structure_upstream_downstream_if_needed(data_dict):
    """Make sure the upstream arc is the one on the right-hand side of the structure centerline.

    The centerline is offset toward its right side by half of its distance to the current upstream arc.
    If that offset line is farther from the upstream arc than from the downstream arc, the 'up_arc' and
    'down_arc' entries are exchanged in place.

    Args:
        data_dict (:obj:`dict`): dictionary of data for the structure; reads 'arc_pts', 'up_arc' and
            'down_arc', and may rewrite 'up_arc' and 'down_arc'
    """
    centerline = LineString(data_dict['arc_pts'])
    upstream_pts = data_dict['up_arc']
    upstream_line = LineString(upstream_pts)
    half_gap = centerline.distance(upstream_line) / 2.0
    # reversing to match shapely 1.8.5 orientation of right-side offsets
    right_side = shapely.reverse(centerline.parallel_offset(half_gap, 'right'))
    downstream_pts = data_dict['down_arc']
    downstream_line = LineString(downstream_pts)
    upstream_is_on_left = right_side.distance(upstream_line) > right_side.distance(downstream_line)
    if upstream_is_on_left:
        data_dict['up_arc'], data_dict['down_arc'] = downstream_pts, upstream_pts


class SimQueryHelper:
    """Class used to get data from XMS related to SRH."""
    def __init__(self, query, at_sim=False, sim_uuid=None):
        """Constructor. Must be constructed with a Context at the simulation Context level.

        Args:
            query (:obj:`Query`): class to communicate with XMS. If None, nothing is read from XMS and every
                member keeps its default value.
            at_sim (:obj:`bool`): True if query Context is at the simulation level, False if it
                is at the simulation component level
            sim_uuid (str): The simulation UUID or None if you want to get the simulation UUID from the current item
        """
        # Create the logger first so it exists before any XMS initialization runs.
        self._logger = logging.getLogger('xms.srh')
        self._query = query
        self.sim_uuid = None
        self.sim_comp_file = ''
        self.component_folder = ''
        self.sim_component = None
        self.sim_tree_item = None
        if self._query is not None:
            self._initialize_from_xms(at_sim, sim_uuid)
        self.mapped_comps = None  # (list(tuple)): (data objects component, display_options_action, component name)
        self.grid_name = ''
        self.grid_units = ''
        self.grid_uuid = ''
        self.grid_wkt = ''
        self.co_grid = None
        self.co_grid_file = ''
        self.co_grid_file_crc32 = ''
        self.existing_mapped_component_uuids = []
        self.coverages = dict()
        self.bc_component = None
        self.material_component = None
        self.sed_material_component = None
        self.obstruction_component = None
        self.monitor_component = None
        self.wse_dataset = None
        self.grid_error = ''
        self.using_ugrid = False
        self.solution_tree_item = None
        self.mesh_link = None
        self.mesh_tree_item = None
        self.structures_3d = []
        self.structures_3d_weirs = []
        self.structures_3d_monitor = []
        self.nlcd_raster_file = ''
        self.nlcd_background_manning_n = 0.03

    def _initialize_from_xms(self, at_sim, sim_uuid):
        """Initialize member variables with data retrieved from XMS.

        Args:
            at_sim (:obj:`bool`): True if query Context is at the simulation level, False if it
                is at the simulation component level
            sim_uuid (str): The simulation UUID or None if you want to get the simulation UUID from the current item
        """
        if at_sim:  # Get simulation data and then move to component level
            if sim_uuid is None:
                self.sim_uuid = self._query.current_item_uuid()
            else:
                self.sim_uuid = sim_uuid
            sim_comp = self._query.item_with_uuid(self.sim_uuid, model_name='SRH-2D', unique_name='Sim_Manager')
            self.sim_comp_file = sim_comp.main_file
        else:
            self.sim_comp_file = self._query.current_item().main_file
            self.sim_uuid = self._query.parent_item_uuid()
        self.component_folder = os.path.dirname(os.path.dirname(self.sim_comp_file))
        from xms.srh.components.sim_component import SimComponent  # avoid circular dependencies
        self.sim_component = SimComponent(self.sim_comp_file)
        self.sim_tree_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, self.sim_uuid)

    def get_geometry_data(self):
        """Get the mesh linked to the simulation."""
        self._get_mesh()

    def get_sim_data(self):
        """Gets the mesh, coverages, components, 3D structures, WSE dataset and NLCD raster of a simulation."""
        self._get_mesh()
        self._get_coverages()
        self._get_uuids_of_existing_mapped_components()
        self._get_coverage_comp_ids()
        self._get_structures()
        self.get_wse_dataset()
        self._get_nlcd_raster()

    def _get_nlcd_raster(self):
        """Get the filename of the nlcd raster and the background Manning's n, if NLCD materials are enabled."""
        mc = self.sim_component.data
        if mc.advanced.specify_materials_nlcd:
            # NOTE(review): stores the value returned by item_with_uuid directly; this assumes the query returns
            # the raster's file path for raster items — confirm.
            raster = self._query.item_with_uuid(mc.advanced.raster_uuid)
            if raster is not None:
                self.nlcd_raster_file = raster
                self.nlcd_background_manning_n = mc.advanced.background_manning_n

    def get_3d_structure_data(self):
        """Gets data associated with 3D Structures and writes the ceiling file when any bridge data exists."""
        bridge_data = []
        for struct_cov, struct_comp in self.structures_3d:
            bot_ug, manning_n = self._regenerate_3d_structure_top_bottom(struct_comp, struct_cov)
            if bot_ug is not None:
                bridge_data.append((bot_ug, manning_n))

        # generate the ceiling file from the structures
        if bridge_data and self.co_grid is not None:
            ug = self.co_grid.ugrid
            grid_units = self.grid_units
            ceiling_file = os.path.join(os.path.dirname(self.sim_component.main_file), 'ceiling.srhceiling')
            gen = CeilingGenerator(ug, bridge_data, ceiling_file, grid_units)
            gen.generate()

    def _generate_structure_ugrids(self, struct_comp, struct_cov):
        """Regenerate the ugrids of a 3D structure using the structure generation utilities.

        The temporary folder created by the structure setup is always removed, even if generation raises.

        Args:
            struct_comp (:obj:`xms.bridge.structure_component`): 3d structure
            struct_cov (:obj:`Coverage`): data_objects coverage

        Returns:
            (:obj:`tuple`): (success flag, message) as returned by su.generate_ugrids_from_structure
        """
        xms_data = StructureXmsData(None, None)
        xms_data.set_query_structure_comp_cov(self._query, struct_comp, struct_cov)
        arc_data = su.setup_structure_generation(xms_data)
        try:
            footprint_calc = FootprintCalculator(
                struct_comp=arc_data['component'],
                coverage=arc_data['coverage'],
                data_folder=arc_data['tmp_dir'],
                cl_arc_pts=arc_data['arc_pts'],
                wkt=arc_data['wkt']
            )
            return su.generate_ugrids_from_structure(struct_comp.data, arc_data, footprint_calc)
        finally:
            shutil.rmtree(arc_data['tmp_dir'], ignore_errors=True)

    @staticmethod
    def _weighted_mid_arc(up_arc, down_arc):
        """Build a 21-point line between the up and down arcs.

        Points are sampled at every 5% of each arc's normalized length. Each result point is
        0.55 * upstream point + 0.45 * downstream point, so the line is slightly biased toward the up arc.

        Args:
            up_arc (:obj:`list`): points of the upstream arc
            down_arc (:obj:`list`): points of the downstream arc

        Returns:
            (:obj:`list[tuple]`): (x, y) locations of the mid arc
        """
        up_ls = LineString(up_arc)
        dn_ls = LineString(down_arc)
        mid_arc = []
        for val in range(0, 101, 5):
            fraction = val / 100.
            up_pt = up_ls.interpolate(fraction, normalized=True)
            dn_pt = dn_ls.interpolate(fraction, normalized=True)
            mid_arc.append((up_pt.x * 0.55 + dn_pt.x * 0.45, up_pt.y * 0.55 + dn_pt.y * 0.45))
        return mid_arc

    @staticmethod
    def _clip_ugrid_to_polygon(ugrid, polygon):
        """Return a new UGrid with only the cells whose centroid is inside polygon.

        All point locations are kept. The cellstream is walked as: one leading value (presumably the cell
        type), then a point count, then that many point indices per cell.

        Args:
            ugrid (:obj:`UGrid`): grid to clip
            polygon (:obj:`Polygon`): shapely polygon used to select cells

        Returns:
            (:obj:`UGrid`): the clipped grid
        """
        locations = ugrid.locations
        cellstream = ugrid.cellstream
        stream_len = len(cellstream)
        kept = []
        pos = 0
        cell_idx = 0
        while pos < stream_len:
            end = pos + 2 + cellstream[pos + 1]
            _, centroid = ugrid.get_cell_centroid(cell_idx)
            if polygon.contains(Point((centroid[0], centroid[1]))):
                kept.extend(cellstream[pos:end])
            pos = end
            cell_idx += 1
        return UGrid(locations, kept)

    def _regenerate_3d_structure_top_bottom(self, struct_comp, struct_cov):
        """Generates the top and bottom of the bridge from the 3d structure.

        On success, the structure's data dict is appended to structures_3d_monitor. It is also appended to
        structures_3d_weirs when it has overtopping BC data.

        Args:
            struct_comp (:obj:`xms.bridge.structure_component`): 3d structure
            struct_cov (:obj:`Coverage`): data_objects coverage

        Returns:
            (:obj:`tuple`): (bottom UGrid, ceiling Manning's n), or (None, None) if the structure was skipped
        """
        # regenerate the bottom of the 3d structure in case we have changed projections
        do_mapping = True
        # json_str = struct_comp.data.data_dict['srh_mapping_info']
        # if json_str:
        #     mapping_info = json.loads(json_str)
        #     if mapping_info['wkt'] == self.grid_wkt:
        #         do_mapping = False
        bot_file = os.path.join(os.path.dirname(struct_comp.main_file), 'bottom2d.xmugrid')
        # if not os.path.isfile(bot_file):
        #     do_mapping = True
        skipped_msg = f'3D Structure coverage "{struct_cov.name}" was skipped.'
        if do_mapping:
            ug_success, msg = self._generate_structure_ugrids(struct_comp, struct_cov)
            if not ug_success:
                self._logger.warning(msg)
                self._logger.warning(skipped_msg)
                # Return a pair so callers that unpack the result do not fail.
                return None, None

        json_str = struct_comp.data.data_dict['srh_mapping_info']
        mapping_info = json.loads(json_str)
        arc, _ = su.bridge_centerline_from_coverage(struct_cov, struct_comp)

        data_dict = {
            'up_stream': struct_comp.data.curves['upstream_profile'],
            'down_stream': struct_comp.data.curves['downstream_profile'],
            'top': struct_comp.data.curves['top_profile'],
            'up_arc': mapping_info['up_arc'],
            'down_arc': mapping_info['down_arc'],
            'wkt': self.grid_wkt,
            'main_file': struct_comp.main_file,
            'bc_data': self._structure_to_bc_data(struct_comp),
            'cov': struct_cov,
            'arc': arc,
            'arc_pts': [(pt.x, pt.y) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]
        }
        # make a line between the up and down arcs (computed before any up/down swap)
        data_dict['mid_arc'] = self._weighted_mid_arc(data_dict['up_arc'], data_dict['down_arc'])
        # up_arc must be on the right side of the arc
        _swap_structure_upstream_downstream_if_needed(data_dict)

        bot_ug = read_ugrid_from_ascii_file(bot_file)
        manning_n = struct_comp.data.data_dict['bridge_ceiling_roughness']
        if bot_ug is None:
            self._logger.warning(skipped_msg)
            # Stop here: the culvert clipping below needs a valid grid.
            return None, None

        self.structures_3d_monitor.append(data_dict)
        if data_dict['bc_data'] is not None:
            self.structures_3d_weirs.append(data_dict)

        if mapping_info['culvert_poly'] is not None:
            bot_ug = self._clip_ugrid_to_polygon(bot_ug, Polygon(mapping_info['culvert_poly']))

        return bot_ug, manning_n

    def _structure_to_bc_data(self, struct_comp):
        """Create bc data from a structure component.

        Args:
            struct_comp (:obj:`xms.bridge.structure_component`): 3d structure

        Returns:
            (:obj:`BcData`): A BcData class, or None when 'srh_overtopping' is 0
        """
        data = struct_comp.data.data_dict
        if int(data['srh_overtopping']) == 0:
            return None
        bc = BcDataParam()
        bc.type = 'Link'
        bc.link.inflow_type = 'Weir'
        bc.link.weir.length = data['srh_weir_length']
        bc.link.weir.crest_elevation = su.compute_weir_elevation(struct_comp)
        bc.link.weir.units = 'Feet' if 'FOOT' in self.grid_units else 'Meters'
        bc.link.weir.type.type = data['srh_weir_type']
        bc.link.weir.type.cw = data['srh_weir_cw']
        bc.link.weir.type.a = data['srh_weir_a']
        bc.link.weir.type.b = data['srh_weir_b']
        bc.link.link_lag_method = 'Specified'
        return bc

    def get_solution_data(self):
        """Get solution datasets for a simulation.

        Returns:
            (:obj:`list`): List of the solution data_object Dataset dumps for this simulation

        """
        dset_dumps = []
        sim_name = self.sim_tree_item.name
        sim_folder = f'{sim_name} (SRH-2D)'

        # Get the mesh tree item.
        self._get_mesh_link()
        if not self.mesh_link:
            self._logger.error('Unable to find SRH-2D solution datasets.')
            return dset_dumps
        mesh_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, self.mesh_link.uuid)

        # Get the simulation solution folder
        solution_folder = tree_util.first_descendant_with_name(mesh_item, sim_folder)
        if not solution_folder:
            self._logger.error('Unable to find SRH-2D solution datasets.')
            return dset_dumps
        self.solution_tree_item = solution_folder
        self.mesh_tree_item = mesh_item

        # Get dumps of all the children datasets
        solution_dsets = tree_util.descendants_of_type(solution_folder, xmd.DatasetItem)
        try:
            for dset in solution_dsets:
                dset_dumps.append(self._query.item_with_uuid(dset.uuid))
        except Exception:  # pragma no cover - hard to test exceptions using QueryPlayback
            self._logger.exception('Error getting solution dataset.')
        return dset_dumps

    def _get_coverages(self):
        """Gets the SRH-2D coverages linked directly under the simulation."""
        covs = [
            ('Materials', 'Material_Component'), ('Sediment Materials', 'SedMaterial_Component'),
            ('Boundary Conditions', 'Bc_Component'), ('Monitor', 'Monitor_Component'),
            ('Obstructions', 'Obstruction_Component')
        ]
        for cov_type, unique_name in covs:
            cov_item = tree_util.descendants_of_type(
                self.sim_tree_item,
                xms_types=['TI_COVER_PTR'],
                allow_pointers=True,
                only_first=True,
                recurse=False,
                coverage_type=cov_type,
                model_name='SRH-2D'
            )
            if cov_item:
                cov_comp = self._query.item_with_uuid(item_uuid=cov_item.uuid, model_name='SRH-2D', unique_name=unique_name)
                if cov_comp:
                    cov_dump = self._query.item_with_uuid(cov_item.uuid)
                    self.coverages[cov_type] = (cov_dump, cov_comp.main_file)

    def _get_mesh_link(self):
        """Gets the mesh associated with a simulation (only when exactly one mesh/ugrid is linked)."""
        # Find the mesh tree item under the simulation.
        mesh_links = tree_util.descendants_of_type(
            self.sim_tree_item, xms_types=['TI_MESH2D_PTR', 'TI_UGRID_PTR'], allow_pointers=True
        )
        if len(mesh_links) != 1:
            return

        if mesh_links[0].item_typename == 'TI_UGRID_PTR':
            self.using_ugrid = True
        self.mesh_link = mesh_links[0]

    def _get_mesh(self):
        """Gets the mesh associated with a simulation.

        Raises:
            RuntimeError: if the mesh horizontal units are not meters or feet
        """
        self._logger.info('Getting mesh from simulation.')
        self._get_mesh_link()
        if self.mesh_link is None:
            return
        mesh_item = self._query.item_with_uuid(self.mesh_link.uuid)
        self.grid_name = mesh_item.name
        proj = mesh_item.projection
        unit_str = proj.horizontal_units
        if unit_str == 'METERS':
            self.grid_units = 'GridUnit "METER"'
        elif unit_str in ['FEET (U.S. SURVEY)', 'FEET (INTERNATIONAL)']:
            self.grid_units = 'GridUnit "FOOT"'
        else:
            err_str = 'Unable to get horizontal units from mesh'
            self._logger.error(err_str)
            self._logger.error(f'unit_str: {unit_str}.')
            msg = 'Units must be one of: "METERS", "FEET (U.S. SURVEY)", "FEET (INTERNATIONAL)"'
            self._logger.error(msg)
            self.grid_error = f'{err_str}. {msg}'
            raise RuntimeError(err_str)
        self.grid_uuid = mesh_item.uuid
        self.grid_wkt = proj.well_known_text
        self.co_grid_file = mesh_item.cogrid_file
        try:  # sometimes this causes a fatal exception when running tests
            with open(self.co_grid_file, 'rb') as f:
                self.co_grid_file_crc32 = hex(binascii.crc32(f.read()) & 0xFFFFFFFF)
        except Exception:  # pragma no cover
            self._logger.error(f'Error reading file: {self.co_grid_file}')
        self.co_grid = read_grid_from_file(self.co_grid_file)
        self._logger.info('Mesh successfully loaded.')

    def _get_uuids_of_existing_mapped_components(self):
        """Gets the uuids of any existing mapped components."""
        # Get children of the simulation that are component tree items
        comp_items = tree_util.descendants_of_type(self.sim_tree_item, xms_types=['TI_COMPONENT'])
        self.existing_mapped_component_uuids = [comp_item.uuid for comp_item in comp_items]

    def _get_coverage_comp_ids(self):
        """Load the component ids for the coverage."""
        self._logger.info('Getting feature ids and component ids for coverages.')
        if 'Boundary Conditions' in self.coverages:
            self.bc_component = BcComponent(self.coverages['Boundary Conditions'][1])
            self._query.load_component_ids(self.bc_component, arcs=True)
        if 'Materials' in self.coverages:
            self.material_component = MaterialComponent(self.coverages['Materials'][1])
            self._query.load_component_ids(self.material_component, polygons=True)
        if 'Sediment Materials' in self.coverages:
            self.sed_material_component = SedMaterialComponent(self.coverages['Sediment Materials'][1])
            self._query.load_component_ids(self.sed_material_component, polygons=True)
        if 'Obstructions' in self.coverages:
            self.obstruction_component = ObstructionComponent(self.coverages['Obstructions'][1])
            self._query.load_component_ids(self.obstruction_component, points=True, arcs=True)
        if 'Monitor' in self.coverages:
            self.monitor_component = MonitorComponent(self.coverages['Monitor'][1])
            self._query.load_component_ids(self.monitor_component, points=True, arcs=True)

    def _get_structures(self):
        """Get the 3d structures that are part of the simulation."""
        covs = tree_util.descendants_of_type(self.sim_tree_item, xms_types=['TI_COVER_PTR'], allow_pointers=True)
        struct_tree_items = [c for c in covs if c.coverage_type == '3D Structure']
        for item in struct_tree_items:
            cov_comp = self._query.item_with_uuid(item_uuid=item.uuid, model_name='3D Bridge', unique_name='Structure')
            if cov_comp:
                cov_dump = self._query.item_with_uuid(item.uuid)
                struct_comp = StructureComponent(cov_comp.main_file)
                self._query.load_component_ids(struct_comp, arcs=True)
                self.structures_3d.append((cov_dump, struct_comp))

    def add_mapped_components_to_xms(self):
        """Add mapped components to the XMS project.

        Does nothing when mapped_comps is None or empty. Otherwise, existing mapped components are deleted
        before the new ones are added and linked to the simulation.
        """
        # Add the mapped component to the Context and send back to XMS.
        # from xms.srh.components import debug_ctx
        self._logger.info('Adding mapped display items.')
        if not self.mapped_comps:  # covers the default of None as well as an empty list
            return

        # delete any existing mapped components
        for str_uuid in self.existing_mapped_component_uuids:
            self._query.delete_item(str_uuid)

        # Add the new mapped components and link them to the simulation
        for mapped_comp in self.mapped_comps:
            self._query.add_component(do_component=mapped_comp[0], actions=mapped_comp[1])
            self._query.link_item(taker_uuid=self.sim_uuid, taken_uuid=mapped_comp[0].uuid)

    def get_wse_dataset(self):
        """Get the water surface elevations data set used as the initial condition, if one is specified."""
        mc = self.sim_component.data
        if mc.hydro.initial_condition != 'Water Surface Elevation Dataset':
            return
        self._logger.info('Processing SRH-2D water surface elevation data set file.')

        # Get uuid and time step index
        dset_uuid, ts_idx = treeitem_selector_datasets.uuid_and_time_step_index_from_string(
            mc.hydro.water_surface_elevation_dataset
        )

        dset = self._query.item_with_uuid(dset_uuid)
        if dset is not None:
            self.wse_dataset = dset.values[ts_idx]
            self._logger.info('Success processing SRH-2D water surface elevation data set file.')
        else:
            self._logger.info('No valid water surface data set specified for initial condition.')
