"""Class for writing basins to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import Polygon

# 3. Aquaveo modules
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.rsm.data import basin_data_def as bdd
from xms.rsm.data.mesh_property_data import mesh_property_data_from_hpm_group
from xms.rsm.file_io import util
from xms.rsm.file_io.basin_monitor_writer import BasinMonitorWriter
from xms.rsm.file_io.bc_val_writer import BcValWriter
from xms.rsm.file_io.hpm_data_writer import HpmDataWriter
from xms.rsm.file_io.water_body_info import WaterBodyInfo
from xms.rsm.file_io.water_body_rain_et_writer import WaterBodyRainEtData, WaterBodyRainEtWriter


class BasinWriter:
    """Writer class for basins in the RSM control file.

    Each polygon of every basin coverage whose "basin" group is active becomes a ``<basin>`` element under
    ``<basins>``. Optional groups that have a handler in ``_write_methods`` (rain, refet, hpm, bc) are written
    for the same polygon. Boundary condition elements are collected and appended to ``<basin_bc>`` after all
    basins have been written.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        # Maps a basin label to the id assigned to that basin in the xml.
        self.label_id = {}
        self._data = _WriterData(writer_data)
        # Handlers for optional polygon groups, keyed by group name; groups without a handler are ignored.
        self._write_methods = {
            'rain': self._write_rain,
            'refet': self._write_refet,
            'hpm': self._write_hpm,
            # 'natural_system': self._write_natural_system,
            # 'agricultural': self._write_agricultural,
            'bc': self._write_bc,
        }

    def write(self):
        """Write all basin coverages to the xml, followed by the collected basin boundary conditions."""
        for cov, comp in self._data.xms_data.basin_coverages:
            self._data.cov = cov
            self._data.comp = comp
            self._data.cov_name = cov.name
            self._write_coverage()
        # Put BCs after all the basins. bcs_xml is only accessed when there is at least one BC, so an empty
        # <basin_bc> element is never created.
        for bc in self._data.bcs:
            self._data.bcs_xml.append(bc)

    def write_monitor(self, xml_parent):
        """Write the monitor data to the xml.

        Args:
            xml_parent (xml.etree.cElementTree.Element): xml parent element
        """
        bm = BasinMonitorWriter(xml_parent, self._data.monitor_data)
        bm.write()

    def _write_coverage(self):
        """Write the polygons of the current basin coverage to the xml."""
        for poly in self._data.cov.polygons:
            self._data.poly_id = poly.id
            self._data.msg = f'Polygon id: "{poly.id}" in Basin coverage "{self._data.cov_name}"'
            comp_id = self._data.comp.get_comp_id(TargetType.polygon, poly.id)
            # Polygons without assigned attributes are skipped silently.
            if comp_id is None or comp_id < 0 or comp_id == util.UNINITIALIZED_COMP_ID:
                continue
            _p_type, p_val = self._data.comp.data.feature_type_values(TargetType.polygon, comp_id)
            self._data.pp.restore_values(p_val)
            active_groups = self._data.pp.active_group_names
            if not active_groups:
                continue
            if 'basin' not in active_groups:  # if the basin group is not active then skip this polygon
                msg = f'{self._data.msg} skipped because the "Basin" item is not checked.'
                self._data.logger.warning(msg)
                continue

            # Only compute the polygon point lists for polygons that are actually written.
            self._data.cur_poly_pts = get_polygon_point_lists(poly)
            self._write_basin()
            for grp_name in active_groups:
                write_method = self._write_methods.get(grp_name)
                if write_method is not None:
                    write_method()
            self._store_monitor_info()

    def _write_basin(self):
        """Write the current polygon as a <basin> element and record its id and geometry in basin_info.

        Raises:
            RuntimeError: If the polygon has no stage-volume table defined.
        """
        # make sure the table is defined
        gp = self._data.pp.group('basin')
        sv_tab = gp.parameter('sv').value
        if len(sv_tab) < 1:
            msg = f'{self._data.msg} has no stage-volume table defined. Aborting.'
            raise RuntimeError(msg)

        # Ids are incremented before use, so each basin gets a fresh id.
        self._data.id += 1
        atts = {'id': f'{self._data.id}'}
        for att_name in ('init_head', 'area', 'elev', 'above', 'below', 'sscoef'):
            atts[att_name] = f'{gp.parameter(att_name).value}'
        lab = gp.parameter('label').value
        if lab:
            atts['label'] = lab
            self.label_id[lab] = self._data.id
        self._data.basin_xml = Et.SubElement(self._data.basins_xml, 'basin', atts)
        self._write_sv_table(sv_tab)

        # Only x, y of the first point list are used for the basin footprint (presumably the outer
        # boundary - confirm against get_polygon_point_lists); it is later used to find lakes in the basin.
        poly_pts = [(p[0], p[1]) for p in self._data.cur_poly_pts[0]]
        sh_poly = Polygon(poly_pts)
        self._data.basin_info.append(WaterBodyInfo(lab, self._data.id, sh_poly))

    def _write_sv_table(self, tab):
        """Write the stage-volume table as an <sv> child of the current basin element.

        Args:
            tab: Table rows; the first two values of each row are written as one line.
        """
        table_xml = Et.SubElement(self._data.basin_xml, 'sv')
        for item in tab:
            # NOTE(review): 'remove_me' looks like a placeholder tag stripped when the file is written - confirm.
            line_elem = Et.SubElement(table_xml, 'remove_me')
            line_elem.text = f'{item[0]} {item[1]}'

    def _write_rain(self):
        """Write the rain data to the xml."""
        writer = _BasinRainEtWriter('rain', self._data)
        writer.write()

    def _write_refet(self):
        """Write the refet data to the xml."""
        writer = _BasinRainEtWriter('refet', self._data)
        writer.write()

    def _write_hpm(self):
        """Write the HPM data of the current basin, indexed by the basin id."""
        grp = self._data.pp.group('hpm')
        mp = mesh_property_data_from_hpm_group(grp)
        mp.index_value = self._data.id
        writer = HpmDataWriter(self._data.wd, self._data.basins_xml, [mp], [self._data.msg])
        writer.write_hpm()

    def _write_bc(self):
        """Build the boundary condition of the current basin; it is added to the xml at the end of write()."""
        bc = _BasinBcWriter(self._data)
        bc.write()

    def _store_monitor_info(self):
        """Store monitor information for the current basin.

        A copy of the polygon parameters, reduced to the active monitor groups, is saved with the basin id,
        polygon id and coverage name so write_monitor can output it later.
        """
        active_grps = set(self._data.pp.active_group_names)
        if not active_grps.intersection(self._data.monitor_set):  # no monitor groups to write
            return
        pp = self._data.pp.copy()  # make a copy of the point parameters
        # Iterate over a snapshot: groups are removed from pp inside the loop, and removing from the
        # sequence being iterated would skip the element that follows each removed one.
        for gp in list(pp.group_names):
            if gp not in self._data.monitor_set or gp not in active_grps:
                pp.remove_group(gp)

        md = (self._data.id, pp, self._data.poly_id, self._data.cov_name)
        self._data.monitor_data.append(md)


class _WriterData:
    """Shared state used while writing the basins of all basin coverages.

    Holds references taken from the global writer data, the polygon parameter definitions whose values are
    restored for each polygon, and the mutable state of the coverage/polygon currently being written.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.logger = writer_data.logger
        self.wd = writer_data
        # Same list object as writer_data.water_body_info['basin']; written basins are appended to it.
        self.basin_info = writer_data.water_body_info['basin']
        self.xms_data = writer_data.xms_data
        self.xml_hse = writer_data.xml_hse
        self.csv_writer = writer_data.csv_writer
        # Polygon parameter definitions; values are restored for each polygon being written.
        self.pp = bdd.generic_model().polygon_parameters
        # Coverage and component currently being written.
        self.cov = None
        self.comp = None
        self.cov_name = None
        # Polygon currently being written.
        self.poly_id = 0
        self.cur_poly_pts = []
        # Basin ids start at an offset from the water body start id and are incremented before each basin.
        self.id = self.xms_data.waterbody_start_id + 300_000
        self.bc_grp = None
        # Boundary condition elements, appended to bcs_xml after all basins are written.
        self.bcs = []
        self.msg = ''
        self.basin_xml = None
        # Names of the groups whose label starts with 'Monitor '.
        monitor_grps = (self.pp.group(nm) for nm in self.pp.group_names)
        self.monitor_set = {gp.group_name for gp in monitor_grps if gp.label.startswith('Monitor ')}
        # Tuples of (basin id, monitor parameters, polygon id, coverage name).
        self.monitor_data = []

    @cached_property
    def basins_xml(self):
        """Get the <basins> xml element, creating it under xml_hse on first access."""
        return Et.SubElement(self.xml_hse, 'basins')

    @cached_property
    def bcs_xml(self):
        """Get the <basin_bc> xml element, creating it under the <basins> element on first access."""
        return Et.SubElement(self.basins_xml, 'basin_bc')


class _BasinRainEtWriter:
    """Write the rain or refet group of the current basin polygon by delegating to WaterBodyRainEtWriter."""

    def __init__(self, group_name, data):
        """Constructor.

        Args:
            group_name (str): Name of the polygon parameter group to write (BasinWriter passes 'rain' or 'refet').
            data (_WriterData): Shared basin writer data describing the current basin polygon.
        """
        rain_et_data = WaterBodyRainEtData(
            data.xms_data,
            group_name,
            data.basin_xml,
            data.pp,
            data.csv_writer,
            'Basin',  # water body name
            data.poly_id,
            data.cov_name,
        )
        self.wb_rain_et_writer = WaterBodyRainEtWriter(rain_et_data)

    def write(self):
        """Write the rain or refet data to the xml."""
        self.wb_rain_et_writer.write()


class _BasinBcWriter:
    """Class for writing basin boundary conditions to the xml."""
    def __init__(self, data):
        """Constructor.

        Args:
            data (_WriterData): Shared basin writer data; its 'bc' group is stored in data.bc_grp.
        """
        self.data = data
        self.data.bc_grp = self.data.pp.group('bc')
        # Ids of the lakes referenced by an areaCorrected basinsource.
        self.lake_ids = []
        # <B> element listing lake_ids; only created for an areaCorrected basinsource.
        self.b_xml = None

    def write(self):
        """Build the boundary condition element of the current basin and append it to data.bcs.

        The element is not attached to the xml tree here; it is only queued in data.bcs.

        Raises:
            RuntimeError: If an areaCorrected basinsource ends up with no valid lakes.
        """
        gp = self.data.bc_grp
        bc_type = gp.parameter('bc_type').value
        atts = {'basinID': f'{self.data.id}'}
        label = gp.parameter('label').value
        if label:
            atts['label'] = label
        bc_id = gp.parameter('bc_id').value
        if bc_id > 0:  # non-positive ids are not written
            atts['id'] = str(bc_id)
        tag = gp.parameter('tag').value
        if tag:
            atts['tag'] = tag
        package = gp.parameter('package').value
        if bc_type == 'basinsource' and package == 'areaCorrected':
            atts['package'] = package
            self._get_lake_ids()
        cur_bc_xml = Et.Element(bc_type, atts)
        d = self.data.wd
        bc_val = BcValWriter(cur_bc_xml, gp, d.csv_writer, d.rule_curve_label_id, self.data.msg)
        bc_val.write()
        # The <B> lake list goes after whatever BcValWriter added.
        if self.b_xml is not None:
            cur_bc_xml.append(self.b_xml)
        self.data.bcs.append(cur_bc_xml)

    def _get_lake_ids(self):
        """Collect the lake ids for an areaCorrected basinsource and build the <B> element listing them.

        Lakes come either from the user-specified table ('specify_lakes' checked) or from the lake
        polygons inside the current basin polygon.
        """
        gp = self.data.bc_grp
        specify = gp.parameter('specify_lakes').value
        if specify:
            self._specified_lake_ids()
        else:
            self._lake_polys_in_basin()
        self.b_xml = Et.Element('B')
        self.b_xml.text = ' '.join(str(lake_id) for lake_id in self.lake_ids)

    def _lake_polys_in_basin(self):
        """Add the ids of all lakes whose geometry is contained in the most recently written basin polygon.

        Raises:
            RuntimeError: If no lake polygon lies within the basin polygon.
        """
        basin_poly = self.data.basin_info[-1].geometry
        lake_info = self.data.wd.water_body_info['lake']
        for item in lake_info:
            if basin_poly.contains(item.geometry):
                self.lake_ids.append(item.wb_id)
        if not self.lake_ids:
            msg = f'{self.data.msg} has no lake polygons within the basin polygon. Aborting.'
            raise RuntimeError(msg)

    def _specified_lake_ids(self):
        """Add the ids of the lakes listed by label in the boundary condition's 'lakes' table.

        Blank labels are ignored. Unknown labels and labels that repeat an already added lake are skipped
        with a warning, so each lake id is written at most once.

        Raises:
            RuntimeError: If no valid lake is specified.
        """
        gp = self.data.bc_grp
        tab = gp.parameter('lakes').value
        lake_info = self.data.wd.water_body_info['lake']
        label_to_id = {item.label: item.wb_id for item in lake_info if item.label}
        for row in tab:
            lake_label = row[0].strip()
            if not lake_label:
                continue
            if lake_label not in label_to_id:
                msg = f'Invalid lake label "{lake_label}" was skipped for {self.data.msg}.'
                self.data.logger.warning(msg)
                continue
            lake_id = label_to_id[lake_label]
            if lake_id in self.lake_ids:
                msg = f'Duplicate lake label "{lake_label}" was skipped for {self.data.msg}.'
                self.data.logger.warning(msg)
                continue
            self.lake_ids.append(lake_id)

        if not self.lake_ids:
            msg = f'No valid lakes specified with boundary condition for {self.data.msg}. Aborting.'
            raise RuntimeError(msg)
