"""Class for writing water movers to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import Point

# 3. Aquaveo modules
from xms.data_objects.parameters import FilterLocation
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.rsm.data import water_mover_data_def as wmdd
from xms.rsm.file_io import util
from xms.rsm.file_io.bc_val_writer import BcValWriter
from xms.rsm.file_io.water_body_info import WaterBodyInfo
from xms.rsm.file_io.water_mover_monitor_writer import WaterMoverMonitorWriter


class WaterMoverWriter:
    """Writer class for water movers in the RSM control file."""
    def __init__(self, writer_data):
        """Set up the writer state and the table of group writers.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        # canal boundary condition elements built by _canal_bc are collected here
        self.canal_bcs = []
        # IDs of the two water bodies joined by the water mover currently being written
        self._wb_id_1 = self._wb_id_2 = -1
        self._data = _WriterData(writer_data)
        # maps a parameter group name to the method that writes that kind of water mover
        self._write_methods = dict(
            canal_bc=self._canal_bc,
            single_control=self._single_control,
            genxweir=self._genxweir,
            setflow=self._setflow,
        )

    def write(self):
        """Write the water mover data of every water mover coverage to the xml.

        Each (coverage, component) pair from the XMS data is made the current coverage on the
        shared writer data before its features are written.
        """
        state = self._data
        for coverage, component in state.xms_data.water_mover_coverages:
            state.cov = coverage
            state.comp = component
            state.cov_name = coverage.name
            self._write_coverage()

    def write_monitor(self, xml_parent):
        """Write the monitor data collected while writing the water movers to the xml.

        Args:
            xml_parent (xml.etree.cElementTree.Element): xml parent element
        """
        collected = self._data.monitor_data
        WaterMoverMonitorWriter(xml_parent, collected).write()

    def _write_coverage(self):
        """Write the current water mover coverage to the xml: point features first, then arcs."""
        for target in (TargetType.point, TargetType.arc):
            self._data.set_feature_type(target)
            self._write_coverage_features()

    def _write_coverage_features(self):
        """Write the features of the current feature type (points or arcs) to the xml.

        Features without an initialized component ID, or without active parameter groups, are
        skipped. A ValueError raised while writing a group is logged as a warning and only that
        group is skipped.
        """
        data = self._data
        for feature in data.features:
            data.feature = feature
            data.feature_id = feature.id
            # context text used as the prefix of any warning about this feature
            data.msg = f'{data.ftype_str} id: "{feature.id}" in water mover coverage "{data.cov_name}"'
            comp_id = data.comp.get_comp_id(data.feature_type, feature.id)
            if comp_id is None or comp_id < 0 or comp_id == util.UNINITIALIZED_COMP_ID:
                continue
            _, values = data.comp.data.feature_type_values(data.feature_type, comp_id)
            data.gm_par.restore_values(values)
            active = data.gm_par.active_group_names
            if not active:
                continue

            for grp_name in active:
                if grp_name not in self._write_methods:
                    continue  # not a water mover group (e.g. a monitor group)
                data.cur_grp = data.gm_par.group(grp_name)
                try:
                    self._get_water_body_ids()
                    self._write_methods[grp_name]()
                    self._store_monitor_info()
                    # the ID is only consumed when the water mover was written without error
                    data.wmID += 1
                except ValueError as err:
                    data.logger.warning(f'{data.msg} was skipped because {err}')

    def _get_water_body_ids(self):
        """Resolve the IDs of the two water bodies joined by the current feature.

        Nothing is done when the current group has no 'id1' parameter. For arcs the IDs come from
        the arc nodes. For points the 'id1' and 'id2' labels are looked up, falling back to a
        'CELLID' mesh cell reference when a label is not found.

        Raises:
            ValueError: If a label is empty, both labels are equal, or a water body is not found.
        """
        data = self._data
        group = data.cur_grp
        if 'id1' not in group.parameter_names:
            return
        if data.feature_type == TargetType.arc:
            self._wb_id_1, self._wb_id_2 = data.arc_water_body_ids()
            return

        first_label = group.parameter('id1').value
        second_label = group.parameter('id2').value
        if not first_label:
            raise ValueError('label of water body 1 may not be empty.')
        if not second_label:
            raise ValueError('label of water body 2 may not be empty.')
        if first_label == second_label:
            raise ValueError('water body labels must be different.')

        self._wb_id_1, self._wb_id_2 = data.water_body_ids(first_label, second_label)

        # a label missing from the lookup may still be a 'CELLID' reference
        if self._wb_id_1 < 0:
            self._wb_id_1 = self._cell_id_from_label(first_label)
        if self._wb_id_2 < 0:
            self._wb_id_2 = self._cell_id_from_label(second_label)
        for wb_id, label in ((self._wb_id_1, first_label), (self._wb_id_2, second_label)):
            if wb_id < 0:
                raise ValueError(f'water body with label "{label}" was not found.')

    def _cell_id_from_label(self, label):
        """Get a cell ID from a 'CELLID' water body label.

        Args:
            label (str): The water body label.

        Returns:
            int: The number that follows 'CELLID' in the label; when no number follows, the ID of
            the cell at the current feature's location; -1 when the label does not start with
            'CELLID', the number can not be parsed, or no cell is found.
        """
        if not label.startswith('CELLID'):
            return -1
        suffix = label.replace('CELLID', '').strip()
        if not suffix:
            # bare 'CELLID': use the cell under the current feature
            location = (self._data.feature.x, self._data.feature.y)
            found = self._data.cell_id_from_pt(location)
            return found if found > 0 else -1
        try:
            return int(suffix)
        except ValueError:
            return -1

    def _canal_bc(self):
        """Build the canal boundary condition element and append it to ``canal_bcs``.

        The element is a <uniformflow> (with a slope attribute) or a <junctionblock>, depending
        on the 'bc_type' parameter of the current group.
        """
        # the two canals joined by the water mover
        first_canal, second_canal = self._data.canal_ids_for_bc()
        group = self._data.cur_grp
        attributes = {'id1': str(first_canal), 'id2': str(second_canal)}
        label = group.parameter('label').value
        if label:
            attributes['label'] = label
        bc_id = group.parameter('bc_id').value
        if bc_id > 0:
            attributes['bcid'] = str(bc_id)
        tag = 'junctionblock'  # bc_type == 'Structure at junction (<junctionblock>)'
        if group.parameter('bc_type').value == 'Uniform flow (<uniformflow>)':
            tag = 'uniformflow'
            attributes['slope'] = str(group.parameter('slope').value)
        self.canal_bcs.append(Et.Element(tag, attrib=attributes))

    def _single_control(self):
        """Write the single control water mover to the xml.

        Raises:
            ValueError: If the head-discharge table of the current group has no rows.
        """
        data = self._data
        group = data.cur_grp
        rows = group.parameter('table').value
        if len(rows) == 0:
            raise ValueError('Head-discharge table is not defined.')
        # the controlling water body is the first one unless the group selects the second
        use_second = group.parameter('controlled_by').value == 'waterbody 2'
        control_id = self._wb_id_2 if use_second else self._wb_id_1
        attributes = {
            'id1': str(self._wb_id_1),
            'id2': str(self._wb_id_2),
            'wmID': str(data.wmID),
            'control': str(control_id),
            'cutoff': str(group.parameter('cutoff').value),
            'gravflow': 'yes' if group.parameter('gravflow').value else 'no',
            'revflow': 'yes' if group.parameter('revflow').value else 'no',
        }
        label = group.parameter('label').value
        if label:
            attributes['label'] = label
        element = Et.SubElement(data.water_mover_xml, 'single_control', attrib=attributes)
        # one child element per table row, holding the first two columns separated by a space
        for row in rows:
            Et.SubElement(element, 'remove_me').text = f'{row[0]} {row[1]}'

    def _genxweir(self):
        """Write the general weir water mover to the xml as a <genxweir> element."""
        group = self._data.cur_grp
        attributes = {
            'id1': str(self._wb_id_1),
            'id2': str(self._wb_id_2),
            'wmID': str(self._data.wmID),
        }
        # each of these parameters is written to an attribute of the same name
        for name in ('fcoeff', 'bcoeff', 'crestelev', 'crestlen', 'dpower', 'spower'):
            attributes[name] = str(group.parameter(name).value)
        label = group.parameter('label').value
        if label:
            attributes['label'] = label
        Et.SubElement(self._data.water_mover_xml, 'genxweir', attrib=attributes)

    def _setflow(self):
        """Write the set flow water mover to the xml.

        A <setflow> element is added for the current water bodies and the boundary condition value
        of the current group is written beneath it by BcValWriter.
        """
        data = self._data
        attributes = {
            'id1': str(self._wb_id_1),
            'id2': str(self._wb_id_2),
            'wmID': str(data.wmID),
        }
        label = data.cur_grp.parameter('label').value
        if label:
            attributes['label'] = label
        element = Et.SubElement(data.water_mover_xml, 'setflow', attrib=attributes)
        writer_data = data.wd
        value_writer = BcValWriter(
            element, data.cur_grp, writer_data.csv_writer, writer_data.rule_curve_label_id, data.msg
        )
        value_writer.write()

    def _store_monitor_info(self):
        """Store monitor information for the water mover that was just written.

        When the current feature has at least one active monitor group, a copy of its parameters
        reduced to those groups is appended to the monitor data as a
        ``(wmID, parameters, feature_id, coverage_name)`` tuple.
        """
        data = self._data
        active_grps = set(data.gm_par.active_group_names)
        monitor_grps = active_grps.intersection(data.monitor_set)
        if not monitor_grps:  # no monitor groups to write
            return
        pp = data.gm_par.copy()  # work on a copy so the shared parameters keep all of their groups
        # Iterate over a snapshot of the names: removing groups while walking the collection that
        # group_names exposes could skip entries if that collection is live.
        for gp in list(pp.group_names):
            if gp not in monitor_grps:  # not an active monitor group
                pp.remove_group(gp)

        data.monitor_data.append((data.wmID, pp, data.feature_id, data.cov_name))


class _WriterData:
    """Data class for writer data."""
    def __init__(self, writer_data):
        """Initialize the state shared while writing water movers.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.wd = writer_data
        self.xms_data = writer_data.xms_data
        self.logger = writer_data.logger
        self.xml_hse = writer_data.xml_hse
        # ID given to the next water mover that is written
        self.wmID = self.xms_data.waterbody_start_id + 400_001
        # {water body info key: {water body label: water body ID}}
        self.lookup = {
            key: {body.label: body.wb_id for body in bodies}
            for key, bodies in writer_data.water_body_info.items()
        }
        # coverage currently being written
        self.cov = self.comp = self.cov_name = None
        # feature currently being written
        self.feature = None
        self.feature_id = 0
        self.cur_grp = None
        self.msg = ''
        self.monitor_data = []
        # items associated with Points or Arcs
        self.feature_type = None
        self.ftype_str = ''
        self.features = None
        self.gm_par = None
        self.monitor_set = None

    def set_feature_type(self, target_type):
        """Make points or arcs of the current coverage the features being written.

        Sets the feature list, the label used in messages, the matching generic model parameters,
        and the names of the groups whose label starts with 'Monitor '.

        Args:
            target_type (TargetType): TargetType.point for points; anything else selects arcs.
        """
        self.feature_type = target_type
        if target_type == TargetType.point:
            self.ftype_str = 'Point'
            self.features = self.cov.get_points(FilterLocation.PT_LOC_DISJOINT)
            self.gm_par = wmdd.generic_model().point_parameters
        else:
            self.ftype_str = 'Arc'
            self.features = self.cov.arcs
            self.gm_par = wmdd.generic_model().arc_parameters
        all_groups = [self.gm_par.group(name) for name in self.gm_par.group_names]
        self.monitor_set = {grp.group_name for grp in all_groups if grp.label.startswith('Monitor ')}

    @cached_property
    def water_mover_xml(self):
        """Get the <watermovers> xml element, created under ``xml_hse`` on first access."""
        watermovers = Et.SubElement(self.xml_hse, 'watermovers')
        return watermovers

    def water_body_ids(self, label_1, label_2):
        """Look up the water body IDs for two labels.

        The per-key label tables in ``lookup`` are searched in order and the search stops as soon
        as both IDs are non-negative.

        Args:
            label_1 (str): Label of the first water body
            label_2 (str): Label of the second water body

        Returns:
            tuple: Tuple containing the IDs of the two water bodies; -1 for a label not found
        """
        first = second = -1
        for ids_by_label in self.lookup.values():
            # a label missing from this table leaves the ID found so far untouched
            first = ids_by_label.get(label_1, first)
            second = ids_by_label.get(label_2, second)
            if first >= 0 and second >= 0:
                break
        return first, second

    def canal_ids_for_bc(self):
        """Get the IDs of the two canals joined by a canal boundary condition.

        For a point feature the first two canals of the canal junction nearest to the point are
        used. For an arc feature the canal nearest to each arc node is used, and the two canals
        must be different and share a junction.

        Returns:
            tuple: Tuple containing the IDs of the two canals

        Raises:
            ValueError: If fewer than 2 canals are defined, no usable canal junction exists for a
                point, or the canals found for an arc are the same or do not share a junction.
        """
        canal_info = self.wd.water_body_info['canal']
        if len(canal_info) < 2:
            raise ValueError('at least 2 canals must be defined for a "Canal boundary condition".')

        if self.feature_type == TargetType.point:
            loc = (self.feature.x, self.feature.y)
            junctions = self._sorted_water_bodies_from_location_and_type(loc, 'canal_junction')
            # Missing junction data is reported as a ValueError, the error type used throughout
            # this method, rather than surfacing as an IndexError from the indexing below.
            if not junctions:
                raise ValueError('no canal junctions are defined for a "Canal boundary condition".')
            canal_ids = junctions[0][1].canal_ids
            if len(canal_ids) < 2:
                raise ValueError('the nearest canal junction has fewer than 2 canals.')
            if len(canal_ids) > 2:  # warn about a junction with more than 2 canals
                msg = f'The nearest canal junction {loc} to {self.msg} has more than 2 canals. '\
                      'The first two canals will be used.'
                self.wd.logger.warning(msg)
            id_1, id_2 = canal_ids[0], canal_ids[1]
        else:  # TargetType.arc
            loc = (self.feature.start_node.x, self.feature.start_node.y)
            canals = self._sorted_water_bodies_from_location_and_type(loc, wmdd.WB_CANAL)
            id_1 = canals[0][1].wb_id
            loc = (self.feature.end_node.x, self.feature.end_node.y)
            canals = self._sorted_water_bodies_from_location_and_type(loc, wmdd.WB_CANAL)
            id_2 = canals[0][1].wb_id
            self._canals_share_node(id_1, id_2)
        return id_1, id_2

    def _canals_share_node(self, id_1, id_2):
        """Verify that two different canals meet at a common canal junction.

        Args:
            id_1 (int): ID of the first canal
            id_2 (int): ID of the second canal

        Raises:
            ValueError: If both IDs are the same canal or no junction lists both canals.
        """
        if id_1 == id_2:
            raise ValueError('the same canal ID was found for each node of the water mover arc.')
        wanted = {id_1, id_2}
        junctions = self.wd.water_body_info['canal_junction']
        if not any(wanted.issubset(junction.canal_ids) for junction in junctions):
            raise ValueError('nearest canals to the arc nodes do not share a junction.')

    def cell_id_from_pt(self, pt):
        """Get the ID of the cell found at a point by the ugrid extractor.

        Args:
            pt (Tuple(x,y)): The point to get the cell ID for.

        Returns:
            int: The first extracted cell index plus one, or -1 when that value is not positive.
        """
        grid_extractor = self.xms_data.ugrid_extractor
        grid_extractor.extract_locations = [pt]
        grid_extractor.extract_data()
        # NOTE(review): indexes look zero-based with a negative value for "no cell" - confirm
        one_based = grid_extractor.cell_indexes[0] + 1
        return one_based if one_based > 0 else -1

    def _sorted_water_bodies_from_location_and_type(self, location, wb_type):
        """Get the water bodies of a given type ordered by distance from a location.

        Args:
            location (Tuple(x,y)): The location distances are measured from.
            wb_type (str): The water body type to search.

        Returns:
            list: ``(distance, WaterBodyInfo)`` tuples sorted by increasing distance. For the mesh
            cell type the list has a single entry at distance 0.0 whose ID is the value returned by
            ``cell_id_from_pt`` (which may be -1).
        """
        if wb_type == wmdd.WB_MESH_CELL:
            cell_id = self.cell_id_from_pt(location)
            return [(0.0, WaterBodyInfo('CELLID', cell_id, Point(location)))]
        shp_node = Point(location)
        info = self.wd.water_body_info[wb_type.lower()]
        features = [(shp_node.distance(item.geometry), item) for item in info]
        # Sort on the distance only. Sorting the bare tuples would compare the water body objects
        # themselves whenever two distances are equal (e.g. two canals touching the same node).
        # The sort is stable, so ties keep the order of the water body info list.
        features.sort(key=lambda entry: entry[0])
        return features

    def arc_water_body_ids(self):
        """Get the water body IDs for the two nodes of the current arc.

        The 'id1' and 'id2' parameters of the current group give the water body type to search for
        at the start and end node; the nearest water body of that type is used for each node.

        Returns:
            tuple: Tuple containing the water body IDs for the start node and the end node

        Raises:
            ValueError: If no water body is found for a node or both nodes give the same ID.
        """
        arc = self.feature
        node_params = ((arc.start_node, 'id1'), (arc.end_node, 'id2'))
        found = []
        # both nodes are resolved before any of the checks below are made
        for node, par_name in node_params:
            xy_loc = (node.x, node.y)
            wb_type = self.cur_grp.parameter(par_name).value
            ranked = self._sorted_water_bodies_from_location_and_type(xy_loc, wb_type)
            found.append(ranked[0][1].wb_id if ranked else -1)
        start_id, end_id = found
        if start_id < 0:
            nd_id = self.feature.start_node.id
            raise ValueError(f'no water body found for the first arc node, ID: "{nd_id}".')
        if end_id < 0:
            nd_id = self.feature.end_node.id
            raise ValueError(f'no water body found for the last arc node, ID: "{nd_id}".')
        if start_id == end_id:
            raise ValueError('the same water body ID was found for each arc node.')
        return start_id, end_id
