"""Class for writing impoundments to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import Polygon

# 3. Aquaveo modules
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules
from xms.rsm.data import impoundment_data_def as idd
from xms.rsm.file_io import util
from xms.rsm.file_io.bc_val_writer import BcValWriter
from xms.rsm.file_io.impoundment_monitor_writer import ImpoundmentMonitorWriter
from xms.rsm.file_io.water_body_info import WaterBodyInfo


class ImpoundmentWriter:
    """Writer class for impoundments in the RSM control file."""
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.label_id = {}  # impoundment label -> impoundment id, filled in by _write_impoundment
        self._data = _WriterData(writer_data)
        self._ID = 0  # NOTE(review): not referenced anywhere in this class; kept for compatibility
        self._ug = None  # grid of the co-grid; assigned the first time _poly_snap is accessed
        self._full_cells = None  # cell indexes inside the current polygon
        self._adj_cells = None  # sorted indexes of cells that share an edge with the cells in the polygon

    def write(self):
        """Write the impoundments of every impoundment coverage to the xml.

        Coverages without polygons are skipped. The boundary condition elements collected while writing the
        polygons are appended afterwards so that they follow all of the impoundment elements.
        """
        for cov, comp in self._data.xms_data.impoundment_coverages:
            if len(cov.polygons) < 1:
                continue
            self._data.cov = cov
            self._data.comp = comp
            self._data.cov_name = cov.name
            self._write_coverage()
        # put BCs after all the impoundments
        for bc in self._data.bcs:
            self._data.bcs_xml.append(bc)

    def write_monitor(self, xml_parent):
        """Write the monitor data to the xml.

        Args:
            xml_parent (xml.etree.cElementTree.Element): xml parent element
        """
        wm = ImpoundmentMonitorWriter(xml_parent, self._data.monitor_data)
        wm.write()

    @cached_property
    def _poly_snap(self):
        """Polygon snapper targeting the cells of the model grid (created once, on first access).

        As a side effect the first access also stores the grid in ``self._ug``.
        """
        co_grid = self._data.xms_data.cogrid
        self._ug = co_grid.ugrid
        poly_snap = SnapPolygon()
        poly_snap.set_grid(co_grid, target_cells=True)
        return poly_snap

    def _write_coverage(self):
        """Write the impoundments of the current coverage to the xml.

        Polygons are skipped when they have no component id, no active groups, or an inactive 'impoundment'
        group. A ValueError raised while processing a polygon is logged as a warning and the polygon is skipped.
        """
        self._poly_snap.add_polygons(self._data.cov.polygons)
        for poly in self._data.cov.polygons:
            self._data.cur_poly_pts = get_polygon_point_lists(poly)
            self._data.poly_id = poly.id
            self._data.msg = f'Polygon id: "{poly.id}" in Impoundment coverage "{self._data.cov_name}"'
            comp_id = self._data.comp.get_comp_id(TargetType.polygon, poly.id)
            if comp_id is None or comp_id < 0:
                comp_id = util.UNINITIALIZED_COMP_ID
            if comp_id == util.UNINITIALIZED_COMP_ID:
                continue  # no attributes were assigned to this polygon
            _, p_val = self._data.comp.data.feature_type_values(TargetType.polygon, comp_id)
            self._data.pp.restore_values(p_val)
            active_groups = self._data.pp.active_group_names
            if not active_groups:
                continue
            if 'impoundment' not in active_groups:  # if the impoundment group is not active then skip this polygon
                msg = f'{self._data.msg} skipped because the "Impoundment" item is not checked.'
                self._data.logger.warning(msg)
                continue

            try:
                self._get_poly_cells()
                self._write_impoundment()
                self._write_bc()
                self._store_monitor_info()
            except ValueError as err:
                self._data.logger.warning(f'{self._data.msg} {err}')
                continue

    def _get_poly_cells(self):
        """Get the grid cells for the current polygon.

        Stores the cells inside the polygon in ``self._full_cells`` and the sorted cells that are edge-adjacent
        to them (but not inside the polygon) in ``self._adj_cells``.

        Raises:
            ValueError: If no grid cells are inside the polygon.
        """
        cell_idxs = list(self._poly_snap.get_cells_in_polygon(self._data.poly_id))
        if len(cell_idxs) < 1:
            raise ValueError('skipped because no grid cells are inside the polygon.')
        set_cells = set(cell_idxs)
        # get cells adjacent to the cells that are inside the polygon
        adj_cells = set()
        for idx in cell_idxs:
            num_edge = self._ug.get_cell_edge_count(idx)
            for edge_idx in range(num_edge):
                adj_cell = self._ug.get_cell_2d_edge_adjacent_cell(idx, edge_idx)
                # negative indexes are kept here and skipped later by _write_impoundment
                if adj_cell not in set_cells:
                    adj_cells.add(adj_cell)
        self._full_cells = cell_idxs
        self._adj_cells = sorted(adj_cells)

    def _write_impoundment(self):
        """Write the impoundment element for the current polygon to the xml.

        Adds one 'cellConnect' child for every cell inside the polygon (fullCoverFlag 1) and for every adjacent
        cell with a non-negative index (fullCoverFlag 0), then records the impoundment's label, id and outer
        polygon in the water body info list.
        """
        self._data.id += 1
        gp = self._data.pp.group('impoundment')
        atts = {
            'id': f'{self._data.id}',
            'head': f'{gp.parameter("head").value}',
            'bottom': f'{gp.parameter("bottom").value}',
            'owCoeff': f'{gp.parameter("ow_coeff").value}',
            'swCoeff': f'{gp.parameter("sw_coeff").value}',
            'swDepth': f'{gp.parameter("sw_depth").value}',
        }
        lab = gp.parameter('label').value
        if lab:
            atts['label'] = lab
            self.label_id[lab] = self._data.id
        imp_xml = Et.SubElement(self._data.impoundments_xml, 'impoundment', atts)

        # these values are the same for every cell so look them up once
        k_over_delv = f'{gp.parameter("k_over_delv").value}'
        sc_conf = f'{gp.parameter("sc_conf").value}'
        k_over_delh = f'{gp.parameter("k_over_delh").value}'

        for cell_idx in self._full_cells:
            atts = {
                'cellId': f'{cell_idx + 1}',  # cell ids are written as index + 1
                'fullCoverFlag': '1',
                'kOverDelV': k_over_delv,
                'SCconf': sc_conf,
            }
            Et.SubElement(imp_xml, 'cellConnect', atts)
        for cell_idx in self._adj_cells:
            if cell_idx < 0:  # not a valid cell index
                continue
            atts = {
                'cellId': f'{cell_idx + 1}',
                'fullCoverFlag': '0',
                'kOverDelH': k_over_delh,
            }
            Et.SubElement(imp_xml, 'cellConnect', atts)

        # only the first point list of the polygon is used to build the shapely polygon
        poly_pts = [(p[0], p[1]) for p in self._data.cur_poly_pts[0]]
        sh_poly = Polygon(poly_pts)
        self._data.impound_info.append(WaterBodyInfo(lab, self._data.id, sh_poly))

    def _write_bc(self):
        """Write the boundary condition of the current polygon using its 'bc' group."""
        self._data.bc_grp = self._data.pp.group('bc')
        self._data.write_bc_val()

    def _store_monitor_info(self):
        """Store monitor information for the current impoundment.

        Nothing is stored when none of the active groups is a monitor group. Otherwise a copy of the polygon
        parameters holding only the active monitor groups is appended to the monitor data.
        """
        active_grps = set(self._data.pp.active_group_names)
        if not active_grps.intersection(self._data.monitor_set):  # no monitor groups to write
            return
        pp = self._data.pp.copy()  # make a copy of the polygon parameters
        # Iterate over a snapshot of the names: removing groups while looping over the live collection
        # could skip entries and leave non-monitor groups behind.
        for gp in list(pp.group_names):  # remove groups that are not monitor groups
            if gp not in self._data.monitor_set or gp not in active_grps:
                pp.remove_group(gp)

        md = (self._data.id, pp, self._data.poly_id, self._data.cov_name)
        self._data.monitor_data.append(md)


class _WriterData:
    """Data class holding the state shared while writing impoundments and their boundary conditions."""
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.wd = writer_data
        self.logger = writer_data.logger
        self.xms_data = writer_data.xms_data
        self.xml_hse = writer_data.xml_hse
        self.impound_info = writer_data.water_body_info['impoundment']
        self.pp = idd.generic_model().polygon_parameters
        self.cov = None  # coverage currently being written
        self.comp = None  # component that belongs to the current coverage
        self.cov_name = None
        self.poly_id = 0  # id of the polygon currently being written
        self.cur_poly_pts = []  # point lists of the current polygon
        # impoundment ids start at a fixed offset from the first water body id
        self.id = self.xms_data.waterbody_start_id + 200_000
        self.msg = ''  # prefix identifying the current polygon in log messages
        self.cur_bc_xml = None  # NOTE(review): never assigned again in this class
        self.bc_grp = None  # 'bc' group of the current polygon
        self.bcs = []  # boundary condition elements waiting to be appended to bcs_xml
        # names of the groups whose label marks them as monitor groups
        groups = (self.pp.group(nm) for nm in self.pp.group_names)
        self.monitor_set = {gp.group_name for gp in groups if gp.label.startswith('Monitor ')}
        self.monitor_data = []

    @cached_property
    def impoundments_xml(self):
        """Get the 'impoundments' xml element (created under the hse element on first access)."""
        return Et.SubElement(self.xml_hse, 'impoundments')

    @cached_property
    def bcs_xml(self):
        """Get the 'impoundment_bc' xml element (created under the impoundments element on first access)."""
        return Et.SubElement(self.impoundments_xml, 'impoundment_bc')

    def write_bc_val(self):
        """Build the boundary condition xml element for the current impoundment.

        Does nothing when the current 'bc' group is not active. Otherwise an element named after the group's
        'bc_type' value is created, its value is written with BcValWriter, and the element is stored in
        ``self.bcs``. It is not attached to the xml tree here.
        """
        gp = self.bc_grp
        if not gp.is_active:
            return
        atts = {'impoundmentID': f'{self.id}'}
        label = gp.parameter('label').value
        if label:
            atts['label'] = label
        bc_id = gp.parameter('bc_id').value
        if bc_id > 0:  # non-positive ids are not written
            atts['id'] = str(bc_id)
        tag = gp.parameter('tag').value
        if tag:
            atts['tag'] = tag
        bc_type = gp.parameter('bc_type').value
        cur_bc_xml = Et.Element(bc_type, atts)
        d = self.wd
        bc_val = BcValWriter(cur_bc_xml, self.bc_grp, d.csv_writer, d.rule_curve_label_id, self.msg)
        bc_val.write()
        self.bcs.append(cur_bc_xml)
