"""Class for writing water movers to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import LineString, Polygon

# 3. Aquaveo modules
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.data_objects.parameters import FilterLocation
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.rsm.data import monitor_data_def as mdd
from xms.rsm.data import wcd_data_def as wcd_dd
from xms.rsm.file_io import util
from xms.rsm.file_io.bc_val_writer import BcValWriter
from xms.rsm.file_io.water_body_info import WaterBodyInfo
from xms.rsm.file_io.wcd_monitor_writer import WcdMonitorWriter


class WcdWriter:
    """Writer class for water movers in the RSM control file.

    Every polygon of every WCD coverage that has the 'wcd' group active becomes a ``<wcd>`` element holding its
    stage-volume table and the ``<wcdmover>`` children of the canal arcs that intersect it. Boundary conditions
    and monitor definitions found on the way are collected and written afterwards.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self._data = _WriterData(writer_data)
        self._wcd_mover = _WcdMover(self._data)
        self._write_methods = {
            'wcd': self._wcd_poly,
            'bc': self._bc_poly,
        }

    def write(self):
        """Write the water mover data to the xml."""
        for cov, comp in self._data.xms_data.wcd_coverages:
            self._data.cov = cov
            self._data.comp = comp
            self._data.cov_name = cov.name
            self._write_coverage()
        # The initial condition check has to run before 'wcdwaterbodies_xml' is touched for the first time: that
        # element is created on first access and only then decides whether it gets an 'initialcondfile' attribute.
        self._write_init_cond()
        # Explicit loops (not extend) so the parent elements are only created when there is something to add.
        for wcd in self._data.wcds_xml:
            self._data.wcdwaterbodies_xml.append(wcd)
        # put BCs after all the basins
        for bc in self._data.bcs:
            self._data.wcd_bc_xml.append(bc)

    def write_monitor(self, xml_parent):
        """Write the monitor data to the xml.

        Two passes are made: one for the monitors stored for the WCD polygons and one for the monitors stored for
        the canal water movers (which also carry the water mover ids).

        Args:
            xml_parent (xml.etree.cElementTree.Element): xml parent element
        """
        wm = WcdMonitorWriter(xml_parent, self._data.monitor_data)
        wm.write()
        wm = WcdMonitorWriter(xml_parent, self._wcd_mover.monitor_data, self._wcd_mover.monitor_wm_ids)
        wm.write()

    def _write_init_cond(self):
        """Write the initial condition file, 'wcd_init_cond.dat', in the current working directory.

        Nothing is written when no initial heads were given. When heads were given for only some of the WCDs a
        warning is logged and all of them are discarded.
        """
        heads = self._data.initial_cond
        if not heads:
            return
        # make sure that init_head specified for each wcd or give a warning and skip
        if len(heads) != len(self._data.wcds_xml):
            msg = 'Initial head not specified for all WCD polygons so all initial heads will be ignored.'
            self._data.logger.warning(msg)
            self._data.initial_cond = []
            return
        with open('wcd_init_cond.dat', 'w') as f:
            f.write('netinit\n')
            f.writelines(f'{head}\n' for head in heads)

    def _write_coverage(self):
        """Write the current WCD coverage to the xml.

        The arcs are processed first so that the canal arcs are known when the polygons are written.
        """
        self._data.set_feature_type(TargetType.arc)
        self._wcd_mover.monitor_set = self._data.monitor_set.copy()
        self._get_all_arcs_with_data()
        self._data.set_feature_type(TargetType.polygon)
        self._write_coverage_features()

    def _restore_feature_values(self, item):
        """Load the stored values of a feature into the current generic model parameters.

        Args:
            item: Arc or polygon of the current coverage.

        Returns:
            The active group names of the feature, or None if the feature has no component data assigned.
        """
        data = self._data
        comp_id = data.comp.get_comp_id(data.feature_type, item.id)
        if comp_id is None or comp_id < 0 or comp_id == util.UNINITIALIZED_COMP_ID:
            return None
        _, values = data.comp.data.feature_type_values(data.feature_type, comp_id)
        data.gm_par.restore_values(values)
        return data.gm_par.active_group_names

    def _get_all_arcs_with_data(self):
        """Find all arcs with an active 'wcd_canal' group and save them with a copy of their parameters."""
        # NOTE(review): 'arcs_with_data' is never cleared, so arcs of previously written coverages are also
        # tested against the polygons of this coverage - confirm that this is intended.
        for arc in self._data.features:
            active_groups = self._restore_feature_values(arc)
            if active_groups is None or 'wcd_canal' not in active_groups:
                continue

            arc_pts = [(pt.x, pt.y, 0.0) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]
            self._wcd_mover.arcs_with_data.append((arc, self._data.gm_par.copy(), LineString(arc_pts)))

    def _write_coverage_features(self):
        """Write the polygons of the current WCD coverage to the xml.

        Polygons without component data are skipped silently; polygons without an active 'wcd' group, or whose
        data raises a ValueError while being written, are skipped with a warning.
        """
        for item in self._data.features:
            self._data.feature = item
            self._data.feature_id = item.id
            self._data.msg = f'{self._data.ftype_str} id: "{item.id}" in WCD coverage "{self._data.cov_name}"'
            active_groups = self._restore_feature_values(item)
            if active_groups is None:
                continue
            if 'wcd' not in active_groups:
                msg = f'{self._data.msg} was skipped because the "Water Control District" is not checked.'
                self._data.logger.warning(msg)
                continue

            try:
                self._wcd_poly()
            except ValueError as err:
                self._data.logger.warning(f'{self._data.msg} was skipped because {err}')

    def _wcd_poly(self):
        """Write the current wcd polygon to the xml.

        Raises:
            ValueError: If the polygon has no stage-volume table.
        """
        self._data.cur_grp = self._data.gm_par.group('wcd')
        sv_tab = self._data.cur_grp.parameter('sv').value
        if len(sv_tab) < 1:
            raise ValueError('no stage-volume table defined.')
        self._data.wcdID += 1
        # Keep this water body's own id: processing the canals below advances the shared counter for every
        # seepage/bank mover, so 'self._data.wcdID' no longer identifies the <wcd> element afterwards.
        wcd_id = self._data.wcdID
        atts = {'id': str(wcd_id)}
        lab = self._data.cur_grp.parameter('label').value
        if lab:
            atts['label'] = lab
        self._wcd_mover.cur_wcd_xml = Et.Element('wcd', atts)
        table_xml = Et.SubElement(self._wcd_mover.cur_wcd_xml, 'sv')
        for row in sv_tab:
            # one 'remove_me' element per table row; presumably stripped when the file is finalized - confirm
            line_elem = Et.SubElement(table_xml, 'remove_me')
            line_elem.text = f'{row[0]} {row[1]}'
        # store monitor info if we have created the wcd element
        self._store_monitor_info()
        self._bc_poly()
        self._data.wcds_xml.append(self._wcd_mover.cur_wcd_xml)
        if self._data.cur_grp.parameter('b_init_head').value:
            self._data.initial_cond.append(self._data.cur_grp.parameter('init_head').value)

        # only the first point list of the polygon is used to build the shapely polygon
        pts = get_polygon_point_lists(self._data.feature)
        sh_poly = Polygon([(p[0], p[1]) for p in pts[0]])
        self._wcd_mover.process_canals_in_polygon(sh_poly)

        self._data.wcd_info.append(WaterBodyInfo(lab, wcd_id, sh_poly))

    def _bc_poly(self):
        """Write the bc of the current polygon to the xml, if its 'bc' group is active."""
        grp = self._data.gm_par.group('bc')
        if not grp.is_active:
            return
        bc_type = grp.parameter('bc_type').value
        atts = {'id': str(self._data.wcdID)}
        lab = grp.parameter('label').value
        if lab:
            atts['label'] = lab
        bc_id = grp.parameter('bc_id').value
        if bc_id > 0:
            atts['bcID'] = str(bc_id)
        cur_bc_xml = Et.Element(bc_type, atts)

        d = self._data.wd
        bc_val = BcValWriter(cur_bc_xml, grp, d.csv_writer, d.rule_curve_label_id, self._data.msg)
        bc_val.write()
        self._data.bcs.append(cur_bc_xml)

    def _store_monitor_info(self):
        """Store monitor information for the current wcd polygon.

        A copy of the polygon parameters, reduced to the active monitor groups, is stored together with the
        wcd id, the feature id and the coverage name. Nothing is stored if no monitor group is active.
        """
        active_grps = set(self._data.gm_par.active_group_names)
        keep = active_grps.intersection(self._data.monitor_set)
        if not keep:  # no monitor groups to write
            return
        pp = self._data.gm_par.copy()  # make a copy of the feature parameters
        # iterate over a snapshot because groups are removed inside the loop
        for gp in list(pp.group_names):
            if gp not in keep:  # not an active monitor group
                pp.remove_group(gp)

        md = (self._data.wcdID, pp, self._data.feature_id, self._data.cov_name)
        self._data.monitor_data.append(md)


class _WriterData:
    """Data class for writer data."""
    def __init__(self, writer_data):
        self.wd = writer_data
        self.wcd_info = writer_data.water_body_info['wcd waterbody']
        self.wcdID = writer_data.xms_data.waterbody_start_id + 500_000
        self.logger = writer_data.logger
        self.xms_data = writer_data.xms_data
        self.xml_hse = writer_data.xml_hse
        self.cov = None
        self.comp = None
        self.cov_name = None
        self.feature_id = 0
        self.feature = None
        self.monitor_data = []
        self.msg = ''
        self.cur_grp = None
        self.bcs = []
        self.wcds_xml = []
        self.initial_cond = []  # list of starting heads, one for each wcd if specified
        # items associated with Polygons or Arcs
        self.ftype_str = ''
        self.feature_type = None
        self.features = None
        self.gm_par = None
        self.monitor_set = None

    @cached_property
    def wcdwaterbodies_xml(self):
        """Get the mesh wcd xml element."""
        atts = {}
        if self.initial_cond:
            atts['initialcondfile'] = 'wcd_init_cond.dat'
        return Et.SubElement(self.xml_hse, 'wcdwaterbodies', atts)

    @cached_property
    def wcd_bc_xml(self):
        """Get the mesh wcd xml element."""
        return Et.SubElement(self.wcdwaterbodies_xml, 'wcd_bc')

    def set_feature_type(self, target_type):
        self.feature_type = target_type
        if target_type == TargetType.polygon:
            self.ftype_str = 'Polygon'
            self.features = self.cov.polygons
            self.gm_par = wcd_dd.generic_model().polygon_parameters
        else:
            self.ftype_str = 'Arc'
            self.features = self.cov.arcs
            self.gm_par = wcd_dd.generic_model().arc_parameters
        monitor_grps = [self.gm_par.group(nm) for nm in self.gm_par.group_names]
        self.monitor_set = set([gp.group_name for gp in monitor_grps if gp.label.startswith('Monitor ')])


class _WcdMover:
    """Class for storing information about a water mover."""
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (_WriterData): WCD writer data
        """
        self.wd = writer_data
        self.xms_data = writer_data.xms_data
        self.logger = writer_data.logger
        self.arcs_with_data = []
        self.poly_canals = []
        self.canal_cells = []
        self.monitor_data = []
        self.monitor_wm_ids = []
        self.cur_wcd_xml = None
        self._par_cpy = None

    def process_canals_in_polygon(self, polygon):
        """Find arcs in the polygon and store them."""
        self._find_arcs_in_polygon(polygon)
        for arc, gm_par, line in self.poly_canals:
            cell_idx = self._ugrid_cells_from_arc(line)
            if not cell_idx:
                msg = f'No cells found for canal arc id: "{arc.id}" associated with {self.wd.msg}. Skipping canal.'
                self.logger.warning(msg)
                continue
            self._intersect_arc_with_cells(line, cell_idx)
            self._write_canal_cells(gm_par)

    def _find_arcs_in_polygon(self, polygon):
        """Find arcs in the polygon and store them."""
        self.poly_canals = []
        for arc, gm_par, line in self.arcs_with_data:
            if not polygon.intersects(line):
                continue
            self.poly_canals.append((arc, gm_par, line))

    def _ugrid_cells_from_arc(self, arc_line):
        """Get the ugrid cell indexes from the arc.

        Args:
            arc_line (LineString): linestring for the arc

        Returns:
            list: cell indexes
        """
        ln_ext = self.poly_line_extractor
        ln_ext.set_polyline(arc_line.coords)
        ln_ext.extract_data()
        locs = [(p[0], p[1]) for p in ln_ext.extract_locations]
        canal_pts = []
        for idx, loc in enumerate(locs):
            if idx == 0:
                continue
            p0 = locs[idx - 1]
            pt = (0.5 * (p0[0] + loc[0]), 0.5 * (p0[1] + loc[1]), 0.0)
            canal_pts.append(pt)

        pt_ext = self.xms_data.ugrid_extractor
        pt_ext.extract_locations = canal_pts
        pt_ext.extract_data()

        cell_idx = []
        last = -1
        for idx in pt_ext.cell_indexes:
            if idx == -1:
                continue
            if last != idx:
                cell_idx.append(idx)
                last = idx
        return cell_idx

    @cached_property
    def poly_line_extractor(self):
        """Get the polyline extractor."""
        from xms.extractor.ugrid_2d_polyline_data_extractor import UGrid2dPolylineDataExtractor
        ug = self.xms_data.xmugrid
        scalar = [0.0] * ug.point_count
        act = [1] * len(scalar)
        extractor = UGrid2dPolylineDataExtractor(ugrid=ug, scalar_location='points')
        extractor.set_grid_scalars(scalar, act, 'points')
        return extractor

    def _intersect_arc_with_cells(self, arc_line, cell_idx):
        """Intersect the arc with the cells.

        Args:
            arc_line (LineString): linestring for the arc
            cell_idx (list): list of cell indexes
        """
        ug = self.xms_data.xmugrid
        self.canal_cells = []
        for idx in cell_idx:
            _, cell_poly = ug.get_cell_plan_view_polygon(idx)
            poly = Polygon(cell_poly)
            o = arc_line.intersection(poly)
            # no test cases for this yet
            # if type(o) is not LineString:
            #     poly = poly.buffer(0.95 * self._tol)
            #     o = arc_line.intersection(poly)
            # if o.length > self._tol:
            #     self.canal_cells.append((idx + 1, o.length))
            self.canal_cells.append((idx + 1, o.length))

    # @cached_property
    # def _tol(self):
    #     """Get a tolerance value based on the model domain size."""
    #     tol = 1.0e-9
    #     ug = self.xms_data.xmugrid
    #     ls = LineString(ug.extents)
    #     return tol * ls.length

    def _setup_monitor_data_item(self, gm_par):
        """Set up a monitor data item."""
        self._par_cpy = gm_par.copy()
        for gp in self._par_cpy.group_names:  # remove all groups to make a monitor group for use later
            self._par_cpy.remove_group(gp)
        gp = self._par_cpy.add_group(group_name='flow', label='wcdmover flow (<flow>)')
        gp.is_active = True
        mdd.add_output_to_group(gp)

    def _write_canal_cells(self, gm_par):
        """Write the canal cells to the xml."""
        self._setup_monitor_data_item(gm_par)
        cur_wcd_id = self.wd.wcdID

        grp = gm_par.group('wcd_canal')
        atts = {
            'cellid': '-1',  # filled in below
            'seepid': '-1',  # filled in below
            'bankid': '-1',  # filled in below
            'segwidth': str(grp.parameter('segwidth').value),
            'seglength': '-1',  # filled in below
            'botelv': str(grp.parameter('botelv').value),
            'leakagecoeff': str(grp.parameter('leakagecoeff').value),
            'bankheight': str(grp.parameter('bankheight').value),
            'bankcoeff': str(grp.parameter('bankcoeff').value),
        }
        for cell_id, length in self.canal_cells:
            self.wd.wcdID += 1
            seepid = self.wd.wcdID
            self.wd.wcdID += 1
            bankid = self.wd.wcdID
            atts['cellid'] = str(cell_id)
            atts['seepid'] = str(seepid)
            atts['bankid'] = str(bankid)
            atts['seglength'] = f'{round(length, 3)}'
            Et.SubElement(self.cur_wcd_xml, 'wcdmover', atts)
            self._store_monitor_data(gm_par, cell_id, seepid, bankid, cur_wcd_id)

    def _store_monitor_data(self, gm_par, cell_id, seepid, bankid, cur_wcd_id):
        """Store monitor data for the canal cell."""
        for item in [('seep_flow', seepid), ('bank_flow', bankid)]:
            mon_grp = gm_par.group(item[0])
            if mon_grp.is_active:
                par_cpy1 = self._par_cpy.copy()
                gp = par_cpy1.group('flow')
                for pname in gp.parameter_names:
                    gp.parameter(pname).value = mon_grp.parameter(pname).value
                # append cell id to label
                gp.parameter('label').value += f'_c{cell_id}'
                # if not dss append cell id to the file name
                if gp.parameter('output_type').value != mdd.MONITOR_DSS:
                    gp.parameter('out_file').value += f'_c{cell_id}'
                # replace %CELLID% in dss path
                pn = gp.parameter('dss_path').value
                pn = pn.replace('%CELLID%', f'_c{cell_id}')
                gp.parameter('dss_path').value = pn

                md = (cur_wcd_id, par_cpy1, self.wd.feature_id, self.wd.cov_name)
                self.monitor_data.append(md)
                self.monitor_wm_ids.append(item[1])
