"""Class for running the insert ewn feature command in a feedback dialog."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import os
import shutil
import uuid

# 2. Third party modules
from PySide2.QtCore import QThread, Signal
import xarray as xr

# 3. Aquaveo modules
from xms.adcirc.components.bc_component import BcComponent
from xms.adcirc.components.bc_component_display import BcComponentDisplay
from xms.adcirc.data import bc_data as bcd
import xms.adcirc.mapping.mapping_util as map_util
from xms.components.display.display_options_io import write_display_option_ids
from xms.data_objects.parameters import Arc, Coverage, FilterLocation, Point
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.process_feedback_dlg import LogEchoQSignalStream, ProcessFeedbackDlg

# 4. Local modules
from xms.ewn.dmi.tool_runner_input_queries import ToolRunnerInputQueries
from xms.ewn.tools.insert_ewn_features import InsertEwnFeatures
from xms.ewn.tools.runners import runner_util


class InsertLeveeRunner(QThread):
    """Worker thread that inserts EWN levee features into a mesh and a copy of the ADCIRC BC coverage.

    Results are exposed through ``tool`` (the InsertEwnFeatures instance) and ``out_bc_coverage`` (the copied BC
    coverage geometry and component). ``processing_finished`` is emitted when ``run`` completes, whether or not an
    error occurred.
    """
    processing_finished = Signal()

    def __init__(self, query, grid_uuid, adcirc_coverage, coverage_uuid, bias):
        """Constructor for class.

        Args:
            query (:obj:`xms.api.dmi.Query`): query for xms communication
            grid_uuid (:obj:`str`): UUID of the input geometry
            adcirc_coverage (:obj:`str`): UUID of the adcirc bc coverage
            coverage_uuid (:obj:`str`): UUID of the input EWN coverage
            bias (:obj:`float`): bias for xms.mesher (the insertion tool receives ``1 - bias``)
        """
        super().__init__()
        self._adcirc_coverage = adcirc_coverage
        self._coverage_uuid = coverage_uuid
        self._grid_uuid = grid_uuid
        self._is_geographic = False
        self._bias = bias
        self._logger = logging.getLogger('xms.ewn')
        self._queries = ToolRunnerInputQueries(query)
        self._grid = None
        self._poly_data_list = []
        self._new_cov_geom = None
        self._new_bc_comp = None
        self._levee_comp_ids = None
        self.tool = None
        self.out_bc_coverage = None
        self.projection = None
        self._adcirc_bc_cov_comp = None
        self._adcirc_bc_cov_geom = None
        # Optional fixed UUIDs for the copied BC component and coverage (lets tests produce deterministic output).
        # When left as None, a random UUID is generated for each.
        self._new_comp_uuid_testing = None
        self._new_cov_uuid_testing = None

    def run(self):
        """Insert the levee features into the existing mesh and the ADCIRC BC coverage.

        Any exception is logged with its traceback. ``processing_finished`` is always emitted so the feedback
        dialog can finish.
        """
        try:
            self._retrieve_xms_data()
            self._run_tool()
        except Exception:
            self._logger.exception('Error inserting EWN features.')
        finally:
            self.processing_finished.emit()

    def _run_tool(self):
        """Run the insert levee process and add the resulting levees to a copy of the BC coverage."""
        self.tool = InsertEwnFeatures(
            self._poly_data_list, self._grid, self._is_geographic, is_levee=True, bias=1 - self._bias
        )
        self.tool.insert_features()
        self._add_to_bc_coverage()

    def _retrieve_xms_data(self):
        """Get all input data needed for the insertion operation.

        Raises:
            RuntimeError: the grid is cartesian, no EWN feature polygons were found, or the ADCIRC BC coverage
                could not be retrieved (an error is logged first in each case).
        """
        out_list = self._queries.get_coverage_data([self._adcirc_coverage], 'ADCIRC_BC')
        grid, coverage_data, is_cartesian = self._queries.get_insert_features_input(
            self._grid_uuid, [self._coverage_uuid], coverage_type='EWN'
        )
        self._grid = grid.ugrid
        if is_cartesian:
            self._logger.error('Cartesian grid detected. Cannot insert levee in cartesian grid.')
            raise RuntimeError
        self._is_geographic = self._queries.is_geographic
        self.projection = self._queries.projection
        for coverage in coverage_data:
            self._poly_data_list.extend(runner_util.get_levee_polygon_input(coverage[0], coverage[1], False))
        self._ensure_ewn_feature_exists()
        # Previously an empty result surfaced as an unexplained IndexError; log a clear message instead.
        if not out_list:
            self._logger.error('Unable to retrieve the ADCIRC BC coverage.')
            raise RuntimeError
        self._adcirc_bc_cov_geom, self._adcirc_bc_cov_comp = out_list[0][0], out_list[0][1]

    def _ensure_ewn_feature_exists(self):
        """Log an error and raise RuntimeError if there are no input EWN feature polygons."""
        if not self._poly_data_list:
            self._logger.error('At least one EWN feature polygon must be defined.')
            raise RuntimeError

    def _add_to_bc_coverage(self):
        """Add the levees to a copy of the ADCIRC BC coverage and store the result in ``out_bc_coverage``."""
        self._copy_bc_component(self._adcirc_bc_cov_comp)
        self._add_levee_arcs_to_coverage(self._adcirc_bc_cov_geom)

        # update the levee display id file so every levee arc (existing and new) uses the levee display options
        levee_disp_id_file = BcComponentDisplay.get_display_id_file(
            bcd.LEVEE_INDEX, os.path.dirname(self._new_bc_comp.main_file)
        )
        write_display_option_ids(levee_disp_id_file, self._levee_comp_ids)

        self.out_bc_coverage = {
            'cov_geom': self._new_cov_geom,
            'cov_comp': self._new_bc_comp,
        }

    def _copy_bc_component(self, cov_comp):
        """Copy the adcirc bc coverage component and add one levee BC attribute set per input polygon.

        The component directory is copied to a sibling directory named by a new UUID. For every polygon item the
        levee table uses the average of the two levee arcs' parametric lengths and point z values as the crest
        elevation, with default flow coefficients of 1.0. The new comp id is stored in ``item['comp_id']``.

        Args:
            cov_comp (:obj:`xmsadcirc.components.BcComponent`): adcirc bc component
        """
        old_comp_dir = os.path.dirname(cov_comp.main_file)
        components_dir = os.path.dirname(old_comp_dir)
        new_comp_uuid = self._new_comp_uuid_testing
        if new_comp_uuid is None:
            new_comp_uuid = str(uuid.uuid4())
        new_dir = os.path.join(components_dir, new_comp_uuid)
        shutil.copytree(old_comp_dir, new_dir)
        cov_comp.save_to_location(new_dir, 'DUPLICATE')
        new_main_file = os.path.join(new_dir, os.path.basename(cov_comp.main_file))
        self._new_bc_comp = BcComponent(new_main_file)
        self._new_bc_comp.comp_to_xms = cov_comp.comp_to_xms

        bc_data = self._new_bc_comp.data
        for item in self._poly_data_list:
            arc_locs = item['levee_arcs']
            # Pair the two sides of the levee point by point; zip truncates to the shorter side.
            p_len = [map_util.get_parametric_lengths(arc_locs[0]), map_util.get_parametric_lengths(arc_locs[1])]
            ave_p_len = [(p0 + p1) / 2 for p0, p1 in zip(p_len[0], p_len[1])]
            ave_z = [(p0[2] + p1[2]) / 2 for p0, p1 in zip(arc_locs[0], arc_locs[1])]

            comp_id = bc_data.add_bc_atts()
            item['comp_id'] = comp_id
            bc_atts = bc_data.arcs.where(bc_data.arcs.comp_id == comp_id, drop=True)
            bc_atts['type'] = bcd.LEVEE_INDEX
            bc_data.update_bc(comp_id, bc_atts)

            # levee atts
            coords = {'comp_id': [comp_id] * len(ave_p_len)}
            default_flow = [1.0] * len(ave_p_len)
            levee_table = {
                'Parametric __new_line__ Length': ('comp_id', ave_p_len),
                'Zcrest (m)': ('comp_id', ave_z),
                'Subcritical __new_line__ Flow Coef': ('comp_id', default_flow),
                'Supercritical __new_line__ Flow Coef': ('comp_id', default_flow),
            }

            levee_atts = xr.Dataset(data_vars=levee_table, coords=coords)
            bc_data.add_levee_atts(levee_atts)
        bc_data.commit()

    def _map_existing_att_ids(self, old_cov_uuid, new_cov_uuid):
        """Copy the existing att-id -> comp-id mapping of the old coverage onto the new coverage UUID.

        Also resets ``self._levee_comp_ids`` and fills it with the comp ids of existing levee features.

        Args:
            old_cov_uuid (:obj:`str`): UUID of the coverage being copied
            new_cov_uuid (:obj:`str`): UUID of the new coverage

        Returns:
            (:obj:`dict`): the arc att-id -> comp-id dict for the new coverage (created empty if missing)
        """
        bc_data = self._new_bc_comp.data
        levee_rows = bc_data.arcs.where(bc_data.arcs.type == bcd.LEVEE_INDEX, drop=True)
        levee_comp_ids = set(levee_rows.comp_id.data.tolist())  # set for O(1) membership in the loop below
        cov_dict = {}
        if old_cov_uuid in self._new_bc_comp.comp_to_xms:
            cov_dict = self._new_bc_comp.comp_to_xms[old_cov_uuid]
        self._levee_comp_ids = []
        self._new_bc_comp.update_ids[new_cov_uuid] = {}
        new_update_ids = self._new_bc_comp.update_ids[new_cov_uuid]
        for target_type, comp_id_dict in cov_dict.items():
            new_update_ids[target_type] = {}
            target_ids = new_update_ids[target_type]
            for comp_id, att_ids in comp_id_dict.items():
                for att_id in att_ids:
                    target_ids[att_id] = comp_id
                    if comp_id in levee_comp_ids:
                        self._levee_comp_ids.append(comp_id)
        if TargetType.arc not in new_update_ids:
            new_update_ids[TargetType.arc] = {}
        return new_update_ids[TargetType.arc]

    def _add_levee_arcs_to_coverage(self, cov_geom):
        """Build a copy of the BC coverage geometry with the new levee arcs and store it in ``self._new_cov_geom``.

        Each polygon item gets ``item['new_arc_ids']``, the feature ids of the arcs created for it.
        ``self._levee_comp_ids`` ends up holding the comp ids of every levee arc (existing and new).

        Args:
            cov_geom (:obj:`data_objects.parameters.Coverage`): the coverage that is copied
        """
        #  copy the coverage
        new_cov_uuid = self._new_cov_uuid_testing
        if new_cov_uuid is None:
            new_cov_uuid = str(uuid.uuid4())
        new_cov = Coverage(name=f'{cov_geom.name}', projection=cov_geom.projection, uuid=new_cov_uuid)
        self._new_bc_comp.cov_uuid = new_cov_uuid
        self._new_bc_comp.data.info.attrs['cov_uuid'] = new_cov_uuid
        self._new_bc_comp.data.commit()  # save the coverage uuid

        att_id_to_comp_id = self._map_existing_att_ids(cov_geom.uuid, new_cov_uuid)

        new_cov.set_points(cov_geom.get_points(FilterLocation.PT_LOC_DISJOINT))
        new_cov.arcs = cov_geom.arcs
        new_cov.polygons = cov_geom.polygons
        # New feature ids continue after the current maximum arc / point ids (0 if there are none).
        max_arc_id = max([0, *(arc.id for arc in new_cov.arcs)])
        pts = new_cov.get_points(FilterLocation.PT_LOC_DISJOINT | FilterLocation.PT_LOC_CORNER)
        max_pt_id = max([0, *(pt.id for pt in pts)])
        # build the levee arcs
        arc_list = []
        for item in self._poly_data_list:
            comp_id = item['comp_id']
            new_arc_ids = []
            for levee_arc in item['levee_arcs']:
                points = [Point(pt[0], pt[1], pt[2]) for pt in levee_arc]
                max_pt_id += 1
                points[0].id = max_pt_id
                max_pt_id += 1
                points[-1].id = max_pt_id

                max_arc_id += 1
                arc = Arc(feature_id=max_arc_id, start_node=points[0], end_node=points[-1], vertices=points[1:-1])
                arc_list.append(arc)
                att_id_to_comp_id[max_arc_id] = comp_id
                self._levee_comp_ids.append(comp_id)
                new_arc_ids.append(max_arc_id)
            item['new_arc_ids'] = new_arc_ids
        # NOTE(review): new_cov.arcs was already assigned from cov_geom.arcs above. If the Coverage.arcs setter
        # replaces rather than appends, the original BC arcs are dropped here -- confirm against data_objects.
        new_cov.arcs = arc_list
        new_cov.complete()
        self._new_cov_geom = new_cov


def insert_levee_with_feedback(target_geometry, adcirc_coverage, levee_coverage, bias, query, parent):
    """Run the Insert Levee command in a modal feedback dialog.

    Args:
        target_geometry (:obj:`str`): UUID of the target geometry
        adcirc_coverage (:obj:`str`): UUID of the adcirc bc coverage
        levee_coverage (:obj:`str`): UUID of the input Levee coverage
        bias (:obj:`float`): bias for xms.mesher
        query (:obj:`xms.api.dmi.Query`): XMS interprocess communication object
        parent (:obj:`PySide2.QtWidgets.QWidget`): The Qt parent window container

    Returns:
        (:obj:`tuple (bool, InsertLeveeRunner)`):
            the value of ``LogEchoQSignalStream.logged_error`` after the dialog closes (presumably True if an error
            level message was logged during the operation -- confirm against xms.guipy), and the runner object for
            the operation (contains output data such as ``out_bc_coverage``)
    """
    worker = InsertLeveeRunner(query, target_geometry, adcirc_coverage, levee_coverage, bias)
    error_str = 'Error(s) encountered inserting Levee features. Review log output for more details.'
    warning_str = 'Warning(s) encountered inserting Levee features. Review log output for more details.'
    display_text = {
        'title': 'Insert Levee',
        'working_prompt': 'Applying Levee Features to mesh. Please wait...',
        'error_prompt': error_str,
        'warning_prompt': warning_str,
        'success_prompt': 'Successfully inserted Levee Features.',
        'note': '',
        'auto_load': 'Close this dialog automatically when insertion of features is finished.',
    }
    feedback_dlg = ProcessFeedbackDlg(display_text, 'xms.ewn', worker, parent)
    feedback_dlg.exec()
    return LogEchoQSignalStream.logged_error, worker
