"""Class for writing mesh properties to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from dataclasses import dataclass, field
import xml.etree.cElementTree as Et

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.gmi.data.generic_model import Section
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules
from xms.rsm.data import mesh_data_def as mdd
from xms.rsm.data.mesh_property_data import CoverageMeshPropertyData, MeshPropertyData, SimulationMeshPropertyData
from xms.rsm.file_io import util
from xms.rsm.file_io.hpm_data_writer import HpmDataWriter
from xms.rsm.file_io.prop_const_dataset_writer import PropConstDatasetWriter
from xms.rsm.file_io.rain_et_writer import RainETWriter


class MeshDataWriter:
    """Writer for the mesh property portion of the RSM control file.

    Mesh properties may be defined globally in the simulation (model control) and/or on polygons of
    Mesh Dataset coverages. Polygon values are assigned to the grid cells whose centers fall inside the
    polygon, identical property definitions are grouped and given an index value, and the resulting
    groups plus a per-cell index array are written to the XML by ``_MeshPropertyDataXml``.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.filename = ''
        self._logger = writer_data.logger
        self._sim_mesh_prop = SimulationMeshPropertyData(writer_data.xms_data.sim_global_parameters)
        self._xms_data = writer_data.xms_data
        self._xml_mesh = writer_data.xml_mesh
        self._md_xml = _MeshPropertyDataXml(writer_data)
        self._cur_xml_node = None
        self._ug = self._xms_data.xmugrid
        self._poly_data = None
        self._poly_snap = None  # SnapPolygon; created in _process_coverages()
        self._mesh_dataset_cov = None
        self._mesh_dataset_comp = None
        self._cov_data_strings = []
        # property name -> list of _CoverageMeshProp, one per polygon that specifies that property
        self._poly_data_to_cell = {}

    def write(self):
        """Write the mesh data portion of the control file."""
        # intersect coverages with the grid/mesh
        if self._xms_data.mesh_dataset_coverages:
            self._process_coverages()
        self._write_mesh_properties()

    def write_hpms(self):
        """Write the hpm portion of the control file."""
        self._md_xml.write_hpms()

    def _write_mesh_properties(self):
        """Write every property in ``mdd.PROP_LIST`` to the mesh XML.

        A ``ValueError`` for one property is logged and that property is skipped; the rest are still written.
        """
        self._cur_xml_node = self._xml_mesh
        for prop in mdd.PROP_LIST:
            try:
                self._add_mesh_property(prop)
            except ValueError as e:
                self._logger.error(f'{e}')

    def _process_coverages(self):
        """Collect, per property, the polygon definitions of all Mesh Dataset coverages."""
        self._logger.info('Processing Mesh Dataset coverages.')
        self._poly_snap = SnapPolygon()
        self._poly_snap.set_grid(self._xms_data.cogrid, target_cells=True)
        self._poly_data_to_cell = {prop: [] for prop in mdd.PROP_LIST}
        for cov_comp in self._xms_data.mesh_dataset_coverages:
            self._mesh_dataset_cov = cov_comp[0]
            self._mesh_dataset_comp = cov_comp[1]
            self._logger.info(f'Processing coverage: {self._mesh_dataset_cov.name}')

            # NOTE(review): polygons of every coverage are added to the same SnapPolygon and later looked
            # up by poly.id alone; if polygon ids can repeat across coverages the lookup may be
            # ambiguous - confirm SnapPolygon's behavior.
            self._poly_snap.add_polygons(self._mesh_dataset_cov.polygons)
            for poly in self._mesh_dataset_cov.polygons:
                self._get_poly_data(poly)

    def _polygon_parameters_from_mesh_ds_component(self, comp_id: int) -> Section:
        """Get the polygon parameters stored in the current Mesh Dataset coverage component.

        Args:
            comp_id (int): The component id.

        Returns:
            (xms.gmi.data.generic_model.Section): The generic model polygon section with the stored values restored.
        """
        pp = mdd.generic_model().polygon_parameters
        comp_data = self._mesh_dataset_comp.data
        _, fval = comp_data.feature_type_values(TargetType.polygon, comp_id)
        pp.restore_values(fval)
        return pp

    def _get_poly_data(self, poly):
        """Record the mesh properties specified on a coverage polygon.

        For every property whose option is not "not specified", a ``_CoverageMeshProp`` holding the cells
        whose centers are inside the polygon is appended to ``self._poly_data_to_cell[prop]``.

        Args:
            poly (xms.data_objects.parameters.Spatial.Polygon.Polygon): coverage polygon
        """
        comp_id = self._mesh_dataset_comp.get_comp_id(TargetType.polygon, poly.id)
        # Polygons without a valid component id have no attributes assigned.
        if comp_id is None or comp_id < 0 or comp_id == util.UNINITIALIZED_COMP_ID:
            return

        # get the cells that have the cell center in the polygon
        cell_idxs = list(self._poly_snap.get_cells_in_polygon(poly.id))
        if not cell_idxs:
            return

        cov_mp = CoverageMeshPropertyData(self._polygon_parameters_from_mesh_ds_component(comp_id))
        cov_name = self._mesh_dataset_cov.name
        for prop in mdd.PROP_LIST:
            mp = cov_mp.mesh_property_data(prop)
            if mp.option != mdd.OPT_NOT_SPECIFIED:
                cmp = _CoverageMeshProp(mesh_prop=mp, cell_idxs=cell_idxs, poly_ids=[poly.id], cov_names=[cov_name])
                self._poly_data_to_cell[prop].append(cmp)

    def _preprocess_poly_data(self, prop):
        """Group the global and polygon definitions of a property that have identical attributes.

        Args:
            prop (str): variable name of property

        Returns:
            list: ``_CoverageMeshProp`` groups in first-seen order. The global definition, when specified,
            comes first and covers every cell. An empty list means the property is not specified anywhere.
        """
        global_prop = self._sim_mesh_prop.mesh_property_data(prop)
        cov_mps = self._poly_data_to_cell.get(prop, [])
        # if the global property is not specified and there are no polygons with the property then skip
        if global_prop.option == mdd.OPT_NOT_SPECIFIED and not cov_mps:
            return []

        # Group by the string form of the property. Dicts keep insertion order, so the global definition
        # stays first and its cells can be overwritten by later groups in _compute_index_array().
        groups = {}
        if global_prop.option != mdd.OPT_NOT_SPECIFIED:
            groups[global_prop.to_string()] = _CoverageMeshProp(
                mesh_prop=global_prop,
                cell_idxs=list(range(self._ug.cell_count)),
                poly_ids=[0],
                cov_names=['model_control'],
            )
        for cov_mp in cov_mps:
            key = cov_mp.mesh_prop.to_string()
            group = groups.get(key)
            if group is None:
                # Store copies: _get_poly_data() shares one cell list among all of a polygon's properties,
                # so extending it in place would leak merged cells into the polygon's other properties.
                groups[key] = _CoverageMeshProp(
                    mesh_prop=cov_mp.mesh_prop,
                    cell_idxs=list(cov_mp.cell_idxs),
                    poly_ids=list(cov_mp.poly_ids),
                    cov_names=list(cov_mp.cov_names),
                )
            else:
                group.cell_idxs.extend(cov_mp.cell_idxs)
                group.poly_ids.append(cov_mp.poly_ids[0])
                group.cov_names.append(cov_mp.cov_names[0])
        return list(groups.values())

    def _set_index_values(self, poly_data_new):
        """Give every property group a unique index value.

        User-specified index values are kept unless they duplicate one already seen; duplicates are logged
        and replaced. Groups without a specified index get the smallest unused positive integers.

        Args:
            poly_data_new (list): list of mesh property data
        """
        idx_set = set()
        # do one pass to get any specified index values and make sure that the index values are unique
        for cov_mp in poly_data_new:
            if not cov_mp.mesh_prop.specify_index:
                continue
            index_val = cov_mp.mesh_prop.index_value
            if index_val in idx_set:
                poly_id_cov_str = _poly_ids_cov_names_str(cov_mp)
                msg = (
                    f'Specified index value "{index_val}" for polygon(s) is not unique for the following:\n'
                    f'(polygon id, coverage name)\n{poly_id_cov_str}'
                    f'The specified index value will be replaced with a unique value.'
                )
                self._logger.warning(msg)
                cov_mp.mesh_prop.specify_index = False
            else:
                idx_set.add(index_val)

        # now set the index on each mesh property that does not have a (unique) specified index
        next_index = 1
        for cov_mp in poly_data_new:
            mesh_prop = cov_mp.mesh_prop
            if not mesh_prop.specify_index:
                while next_index in idx_set:
                    next_index += 1
                mesh_prop.index_value = next_index
                idx_set.add(next_index)

    def _compute_index_array(self, poly_data_new, prop):
        """Compute the per-cell index array for the mesh property groups.

        Groups are applied in list order, so a later group overwrites the cells of an earlier one.

        Args:
            poly_data_new (list): list of mesh property data
            prop (str): variable name of property

        Returns:
            numpy.ndarray: index value for each grid cell.

        Raises:
            ValueError: if any cell is not covered by any group.
        """
        # sentinel larger than every real index value marks cells not covered by any group
        undefined_idx = max(p.mesh_prop.index_value for p in poly_data_new) + 1
        index = np.full(self._ug.cell_count, undefined_idx)
        for cov_mp in poly_data_new:
            cell_idxs = np.asarray(cov_mp.cell_idxs, dtype=np.int32)
            index[cell_idxs] = cov_mp.mesh_prop.index_value
        if np.any(index == undefined_idx):
            msg = (
                f'Error writing mesh property "{prop}". This property requires an index '
                f'but not all grid cells have the property defined. Skipping this property.'
            )
            raise ValueError(msg)
        return index

    def _add_mesh_property(self, prop):
        """Add one property to the mesh portion of the XML.

        Args:
            prop (str): variable name of property

        Raises:
            ValueError: if the property cannot be written (propagated from the helpers).
        """
        # preprocess the poly data so that we combine mesh_props that have the same values
        poly_data_new = self._preprocess_poly_data(prop)
        if not poly_data_new:
            return
        # compute the index array. If all the values in the index are the same then it will be written
        # as a global option. If not, then write the index.
        self._set_index_values(poly_data_new)  # use or make an index value for each property in poly_data_new
        index = self._compute_index_array(poly_data_new, prop)
        self._md_xml.write_mesh_property_xml(prop, poly_data_new, index)


@dataclass
class _CoverageMeshProp:
    """A mesh property definition together with the grid cells and polygons it applies to.

    MeshDataWriter creates one instance per polygon and property. Instances with identical property
    definitions are later merged, so the lists can describe several polygons.
    """
    # The property definition; None only for a default-constructed instance.
    mesh_prop: 'MeshPropertyData | None' = None
    # Indices of the grid cells the property applies to.
    cell_idxs: list = field(default_factory=list)
    # Source polygon ids (0 is used for the model-control/global definition).
    poly_ids: list = field(default_factory=list)
    # Coverage name for each entry of poly_ids ('model_control' for the global definition).
    cov_names: list = field(default_factory=list)


class _MeshPropertyDataXml:
    """Helper class for writing mesh property data to the input file.

    ``write_mesh_property_xml`` is called once per property. When the index array holds a single value,
    the property is written as one element under the mesh element. Otherwise an ``indexed`` element
    referencing ``<prop>.index`` is written with one ``entry`` per property group. HPM data is only
    collected there and is written later by ``write_hpms``.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.wd = writer_data
        self.xms_data = writer_data.xms_data
        self.xml_mesh = writer_data.xml_mesh
        self.csv_writer = writer_data.csv_writer
        self.logger = writer_data.logger
        self.ug = self.xms_data.xmugrid
        # State of the property currently being written.
        self._prop = None
        self._poly_data = None
        self._cur_poly_data = None
        self._mesh_prop = None
        self._index = None
        # Writers for property options other than a plain constant or dataset.
        self._additional_xml = {
            mdd.PROP_TRANSMISSIVITY: self._transmissivity_more_xml,
            mdd.PROP_RAIN: self._rain_et_more_xml,
            mdd.PROP_REFET: self._rain_et_more_xml,
            mdd.PROP_CONVEYANCE: self._conveyance_more_xml,
            mdd.PROP_SVCONVERTER: self._svconverter_more_xml,
            mdd.PROP_HPM: self._hpm_more_xml,
        }
        # HPM groups and index saved for write_hpms().
        self._hpms = []
        self._hpm_index = []

    def write_mesh_property_xml(self, prop, poly_data_new, index):
        """Write one mesh property, either as a global property or as an indexed property.

        Args:
            prop (str): variable name of property
            poly_data_new (list): property groups (_CoverageMeshProp), each with a unique index value
            index (numpy.ndarray): index value for each grid cell
        """
        self._prop = prop
        self._index = index
        # Distinct index values actually used by the grid (np.unique avoids a Python loop over every cell).
        used_index_values = set(np.unique(index).tolist())
        # Drop groups whose cells were all overwritten by later groups; they would be unused entries.
        self._poly_data = [p for p in poly_data_new if p.mesh_prop.index_value in used_index_values]
        if len(used_index_values) == 1:  # write a global property
            self._write_property()
        else:  # write an indexed property
            self._write_property_index()

    def _write_property_index(self):
        """Write an indexed property: the index file plus one ``entry`` element per property group."""
        if self._prop == mdd.PROP_HPM:
            # HPMs are written later by write_hpms(); only remember the groups and index here.
            self._hpm_more_xml(None)
            return

        elem = Et.SubElement(self.xml_mesh, self._prop)
        util.export_ds_file(self._prop, self._index, 'index')
        index_filename = f'{self._prop}.index'
        idx_xml = Et.SubElement(elem, 'indexed', {'file': index_filename})
        for poly_data in self._poly_data:
            self._cur_poly_data = poly_data
            self._mesh_prop = poly_data.mesh_prop
            atts = {'id': f'{self._mesh_prop.index_value}'}
            if self._mesh_prop.label:
                atts['label'] = self._mesh_prop.label
            entry_xml = Et.SubElement(idx_xml, 'entry', atts)
            self._write_mesh_prop_values(entry_xml)

    def _write_property(self):
        """Write a property that has the same definition on every cell."""
        elem = None
        if self._prop != mdd.PROP_HPM:
            # HPMs get no element here; they are written under 'hpModules' by write_hpms().
            elem = Et.SubElement(self.xml_mesh, self._prop)
        self._cur_poly_data = self._poly_data[0]
        self._mesh_prop = self._cur_poly_data.mesh_prop
        self._write_mesh_prop_values(elem)

    def _const_xml_tag(self):
        """Get the XML tag for constant properties."""
        tag = 'const'
        if self._prop == mdd.PROP_TRANSMISSIVITY:
            tag = 'confined'
            if self._mesh_prop.transmissivity.unconfined:
                tag = 'unconfined'
        elif self._prop == mdd.PROP_SVCONVERTER:
            tag = 'constsv'
        return tag

    def _const_value_xml_tag(self):
        """Get the XML tag for constant value."""
        tag = 'value'
        if self._prop == mdd.PROP_TRANSMISSIVITY:
            tag = 'trans'
            if self._mesh_prop.transmissivity.unconfined:
                tag = 'k'
        elif self._prop == mdd.PROP_SVCONVERTER:
            tag = 'sc'
        return tag

    def _dataset_xml_tag(self):
        """Get the XML tag for dataset value."""
        tag = 'gms'
        if self._prop == mdd.PROP_TRANSMISSIVITY:
            tag = 'confined_gms'
            if self._mesh_prop.transmissivity.unconfined:
                tag = 'unconfined_gms'
            if self._mesh_prop.dataset.specify_layer:
                tag = tag + '_layer'
        elif self._prop == mdd.PROP_SVCONVERTER:
            tag = 'layersv'
        elif self._mesh_prop.dataset.specify_layer:
            tag = 'gmslayer'
        return tag

    def _write_mesh_prop_values(self, elem):
        """Write the values of the current mesh property group.

        Constant and dataset options use PropConstDatasetWriter; other options are dispatched to the
        property-specific writer in ``self._additional_xml`` (if any).

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element
        """
        if self._mesh_prop.option in [mdd.OPT_CONSTANT, mdd.OPT_DATASET]:
            self._write_const_or_dataset(elem)
        elif self._prop in self._additional_xml:
            self._additional_xml[self._prop](elem)

    def _write_const_or_dataset(self, elem):
        """Write the constant or dataset mesh property values to the input file.

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element
        """
        poly_id_str = _poly_ids_cov_names_str(self._cur_poly_data)
        writer = PropConstDatasetWriter(
            self.xms_data, self._prop, self._mesh_prop.option, self._mesh_prop.const, self._const_xml_tag(),
            self._const_value_xml_tag(), self._mesh_prop.dataset, self._dataset_xml_tag(), elem, poly_id_str
        )
        writer.write()

    def _feature_error(self, err_str):
        """Build the ValueError reported when a property is not fully defined for the current group.

        Args:
            err_str (str): leading description of the problem

        Returns:
            ValueError: error whose message lists the affected (polygon id, coverage name) pairs
        """
        msg = (
            f'{err_str} for the following feature(s):\n'
            '(polygon id, coverage name)\n'
            f'{_poly_ids_cov_names_str(self._cur_poly_data)}'
        )
        return ValueError(msg)

    @staticmethod
    def _add_table_rows(parent, table):
        """Append one child element per lookup-table row, with the text "<col0> <col1>".

        The rows use the placeholder tag 'remove_me'; presumably that tag is stripped when the XML is
        serialized so only the text remains - TODO confirm in the XML writer.

        Args:
            parent (xml.etree.cElementTree.Element): element receiving the rows
            table (iterable): rows whose first two items are written
        """
        for item in table:
            line_elem = Et.SubElement(parent, 'remove_me')
            line_elem.text = f'{item[0]} {item[1]}'

    def _transmissivity_more_xml(self, elem):
        """Write a transmissivity lookup table.

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element

        Raises:
            RuntimeError: if the option is not the lookup option.
            ValueError: if the lookup table is empty.
        """
        if self._mesh_prop.option != mdd.OPT_TR_LOOKUP:
            raise RuntimeError('Unknown option for transmissivity property.')
        transmissivity = self._mesh_prop.transmissivity
        if not transmissivity.lookup_table:
            raise self._feature_error('Error writing transmissivity lookup table. The table is not defined')
        atts = {
            'below': f'{transmissivity.lookup_below}',
            'above': f'{transmissivity.lookup_above}',
        }
        lookup_xml = Et.SubElement(elem, 'lookup', atts)
        self._add_table_rows(Et.SubElement(lookup_xml, 'transm'), transmissivity.lookup_table)

    def _rain_et_more_xml(self, elem):
        """Write rain or reference ET values with RainETWriter.

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element
        """
        poly_ids_cov = csv_name = ''
        if self._cur_poly_data is not None:
            poly_ids_cov = _poly_ids_cov_names_str(self._cur_poly_data)
            csv_name = f'{self._prop}_{self._cur_poly_data.cov_names[0]}_{self._cur_poly_data.poly_ids[0]}'
        writer = RainETWriter(
            option=self._mesh_prop.option,
            rain_et=self._mesh_prop.rain_et,
            xml_parent=elem,
            csv_writer=self.csv_writer,
            csv_fname=csv_name,
            poly_ids_cov=poly_ids_cov
        )
        writer.write()

    def _conveyance_more_xml(self, elem):
        """Write Manning, Kadlec, or lookup-table conveyance values.

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element

        Raises:
            RuntimeError: if the option is not a known conveyance option.
            ValueError: if the selected option has no values defined.
        """
        option = self._mesh_prop.option
        conveyance = self._mesh_prop.conveyance
        atts = {}
        err_str = ''
        if option == mdd.OPT_CONVEY_MANNING:
            if conveyance.specify_man_a:
                atts['a'] = f'{conveyance.manning_a}'
            if conveyance.specify_man_b:
                atts['b'] = f'{conveyance.manning_b}'
            if conveyance.specify_man_detent:
                atts['detent'] = f'{conveyance.manning_detent}'
            if atts:
                Et.SubElement(elem, 'mannings', atts)
            else:
                err_str = 'Error writing conveyance mannings values. The values are not defined'
        elif option == mdd.OPT_CONVEY_KADLEC:
            if conveyance.specify_kad_k:
                atts['K'] = f'{conveyance.kadlec_k}'
            if conveyance.specify_kad_alpha:
                atts['alpha'] = f'{conveyance.kadlec_alpha}'
            if conveyance.specify_kad_beta:
                atts['beta'] = f'{conveyance.kadlec_beta}'
            if conveyance.specify_kad_detent:
                atts['detent'] = f'{conveyance.kadlec_detent}'
            if atts:
                Et.SubElement(elem, 'kadlec', atts)
            else:
                err_str = 'Error writing conveyance kadlec values. The values are not defined'
        elif option == mdd.OPT_CONVEY_LOOKUP:
            if conveyance.lookup_table:
                atts = {
                    'below': f'{conveyance.lookup_below}',
                    'above': f'{conveyance.lookup_above}',
                    'base': f'{conveyance.lookup_base}',
                    'exponent': f'{conveyance.lookup_exponent}',
                }
                lookup_xml = Et.SubElement(elem, 'lookup', atts)
                self._add_table_rows(Et.SubElement(lookup_xml, 'convey'), conveyance.lookup_table)
            else:
                err_str = 'Error writing conveyance lookup table. The table is not defined'
        else:
            raise RuntimeError('Unknown option for conveyance property.')
        if err_str:
            raise self._feature_error(err_str)

    def _svconverter_more_xml(self, elem):
        """Write an svconverter lookup table.

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element

        Raises:
            RuntimeError: if the option is not the lookup option.
            ValueError: if the lookup table is empty.
        """
        if self._mesh_prop.option != mdd.OPT_SV_LOOKUP:
            raise RuntimeError('Unknown option for svconverter property.')
        svconverter = self._mesh_prop.svconverter
        if not svconverter.lookup_table:
            raise self._feature_error('Error writing svconverter lookup table. The table is not defined')
        atts = {
            'sc': f'{svconverter.lookup_sc}',
            'below': f'{svconverter.lookup_below}',
            'above': f'{svconverter.lookup_above}',
        }
        lookup_xml = Et.SubElement(elem, 'lookup', atts)
        self._add_table_rows(Et.SubElement(lookup_xml, 'sv'), svconverter.lookup_table)

    def _hpm_more_xml(self, elem):
        """Save the HPM groups and index so write_hpms() can write them later.

        Args:
            elem (xml.etree.cElementTree.Element): mesh property element (unused; HPMs have no element here)
        """
        self._hpms = self._poly_data.copy()
        self._hpm_index = self._index

    def write_hpms(self):
        """Write the hpm portion of the control file (no-op when no HPMs were collected)."""
        if not self._hpms:
            return
        hpm_root_xml = Et.SubElement(self.xml_mesh, 'hpModules')
        util.export_ds_file('hpm', self._hpm_index, 'index')
        _ = Et.SubElement(hpm_root_xml, 'indexed', {'file': 'hpm.index'})
        id_str = [_poly_ids_cov_names_str(hpm) for hpm in self._hpms]
        hpms = [p.mesh_prop for p in self._hpms]
        writer = HpmDataWriter(self.wd, hpm_root_xml, hpms, id_str)
        writer.write_hpm()


def _poly_ids_cov_names_str(cov_mp):
    """Format the (polygon id, coverage name) pairs of a coverage mesh property, one per line.

    The pair used for the model-control (global) definition is replaced by a readable description.

    Args:
        cov_mp (CoverageMeshProp): coverage mesh property

    Returns:
        str: newline-terminated lines of polygon ids and coverage names (empty if there are none)
    """
    global_marker = '(0, model_control)'
    global_text = 'Property specified in the model control (not assigned to a polygon)'
    lines = [f'({pid}, {name})\n' for pid, name in zip(cov_mp.poly_ids, cov_mp.cov_names)]
    return ''.join(lines).replace(global_marker, global_text)
