"""Class for writing rule curves to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import calendar
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules


class RuleCurveWriter:
    """Writer class for rule curves in the RSM control file.

    Every rule curve coverage contributes ``rcentry`` elements to a single ``rulecurves`` element that is
    appended to ``writer_data.xml_hse``. Curve ids are assigned sequentially across all coverages. The label
    of every written curve is recorded in ``writer_data.rule_curve_label_id`` (label -> id), so a label is
    only ever written once.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self._wd = writer_data
        self._start_id = 1  # Id given to the next rule curve written; carried over between coverages.
        self._comp = None  # Component of the coverage currently being written.
        self._cov_name = ''  # Name of the coverage currently being written (used in log messages).
        self._msg_part = ''  # Suffix appended to per-curve warning messages.

    def write(self):
        """Write the rule curves of every rule curve coverage to the xml."""
        for cov, comp in self._wd.xms_data.rule_curve_coverages:
            self._comp = comp
            self._cov_name = cov.name
            self._write_coverage()

    def _write_coverage(self):
        """Write the rule curves of the current coverage to the xml.

        A curve is skipped (with a logged warning) when:
            - a curve with the same label was already written,
            - it has no table values or no "yunits" (no warning when the label is 'Default curve'),
            - its table contains invalid dates.
        """
        section = self._comp.data.generic_model.material_parameters
        section.restore_values(self._comp.data.values)

        curve_id = self._start_id
        for grp_name in section.group_names:
            grp = section.group(grp_name)
            curve_name = grp.label
            if curve_name in self._wd.rule_curve_label_id:
                msg = (
                    f'Curve: "{curve_name}" in coverage "{self._cov_name}" was skipped because a rule curve with '
                    f'that name was already included in the simulation.'
                )
                self._wd.logger.warning(msg)
                continue
            self._msg_part = f' for curve: "{curve_name}" in coverage "{self._cov_name}". Curve skipped.'
            xunits = grp.parameter('xunits').value
            yunits = grp.parameter('yunits').value
            cycle = grp.parameter('cycle').value
            dtype = grp.parameter('type').value
            tab_vals = grp.parameter('table').value
            if len(tab_vals) < 1 or not yunits:
                # NOTE(review): presumably 'Default curve' is an unedited placeholder group, so an empty one is
                # skipped silently instead of warning the user.
                if curve_name != 'Default curve':
                    self._yunits_warning(yunits, curve_name)
                    self._table_warning(tab_vals, curve_name)
                continue
            if not self._table_is_valid(tab_vals):
                continue  # _table_is_valid already logged the offending dates.

            atts = {
                'id': str(curve_id),
                'name': curve_name,
                'xunits': xunits,
                'yunits': yunits,
                'cycle': cycle,
                'type': dtype
            }
            self._wd.rule_curve_label_id[curve_name] = curve_id
            curve_id += 1
            rc = Et.SubElement(self._rule_curve_xml, 'rcentry', atts)
            for item in tab_vals:
                # NOTE(review): 'remove_me' looks like a placeholder tag that is stripped when the xml is
                # serialized, leaving one "x y" line per table row - confirm against the serializer.
                line_elem = Et.SubElement(rc, 'remove_me')
                line_elem.text = f'{item[0]} {item[1]}'

        # Carry the next id over so the following coverage does not reuse ids already written.
        self._start_id = curve_id

    def _yunits_warning(self, yunits, crv):
        """Log a warning if yunits is not specified.

        Args:
            yunits: The "yunits" value of the curve.
            crv (str): Curve label (unused; the label is already part of ``self._msg_part``).
        """
        if not yunits:
            msg = 'No "yunits" units specified' + self._msg_part
            self._wd.logger.warning(msg)

    def _table_warning(self, tab_vals, crv):
        """Log a warning if table values are not specified.

        Args:
            tab_vals: The rule curve table rows.
            crv (str): Curve label (unused; the label is already part of ``self._msg_part``).
        """
        if len(tab_vals) < 1:
            msg = 'No "Rule curve table" values specified' + self._msg_part
            self._wd.logger.warning(msg)

    def _table_is_valid(self, tab_vals):
        """Make sure the x values (dates such as '01jan') in the table are valid.

        Args:
            tab_vals: Table rows; the first item of each row is the date string.

        Returns:
            bool: True if every date is valid; otherwise a warning listing the bad dates is logged and
            False is returned.
        """
        # Month abbreviation -> month number (1-12); index 0 of month_abbr is the empty string.
        months = {abbr: idx for idx, abbr in enumerate(calendar.month_abbr) if abbr}
        bad_dates = []
        for item in tab_vals:
            date_str = item[0].strip()
            if not self._is_valid_date(date_str, months):
                bad_dates.append(date_str)

        if bad_dates:
            msg = f'Invalid date values "{bad_dates}" in rule curve table' + self._msg_part
            self._wd.logger.warning(msg)
            return False
        return True

    @staticmethod
    def _is_valid_date(date_str, months):
        """Return True if date_str is a 2-digit day followed by a 3-letter month abbreviation (e.g. '01jan').

        Args:
            date_str (str): The stripped date string.
            months (dict): Month abbreviation (title case) -> month number.

        Returns:
            bool: Whether the date is valid.
        """
        if len(date_str) != 5:
            return False
        try:
            day = int(date_str[:2])
        except ValueError:
            return False
        month = months.get(date_str[2:5].title())
        if month is None or day < 1:
            return False
        # 2000 is a leap year, so '29feb' is accepted.
        return day <= calendar.monthrange(2000, month)[1]

    @cached_property
    def _rule_curve_xml(self):
        """Get the 'rulecurves' xml element, creating it under the hse element on first access."""
        return Et.SubElement(self._wd.xml_hse, 'rulecurves')
