"""Class for managing interprocess communication with XMS."""

# 1. Standard Python modules
from io import StringIO
from logging import Logger
from pathlib import Path
from typing import Optional, Sequence, TextIO

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_2d import UGrid2d
from xms.coverage.xy.xy_series import XySeries
from xms.data_objects.parameters import Projection
from xms.gmi.data.generic_model import GenericModel, Group, GroupSet, Parameter, Type, UNASSIGNED_MATERIAL_ID
from xms.guipy.dialogs.log_timer import Timer
from xms.snap import SnapExteriorArc, SnapInteriorArc, SnapPoint

# 4. Local modules
from xms.hydroas.file_io.errors import Messages as Message, write_error
from xms.hydroas.file_io.extra_io import write_extra_file


class GmiWriter:
    """A class for writing out a GMI to disk."""
    def __init__(
        self,
        logger: Optional[Logger] = None,
        ugrid: Optional[UGrid2d] = None,
        name: str = '',
        model: Optional[GenericModel] = None,
        model_instantiation: Optional[str] = None,
        global_instantiation: Optional[str] = None,
        points: Optional[list[int]] = None,
        point_values: Optional[list[str]] = None,
        arcs: Optional[list[list[int]]] = None,
        arc_ids: Optional[list[int]] = None,
        arc_values: Optional[list[str]] = None,
        arc_names: Optional[list[str]] = None,
        material_names: Optional[list[str]] = None,
        material_numbers: Optional[list[int]] = None,
        material_groups: Optional[list[Group]] = None,
        cell_materials: Optional[Sequence[int]] = None,
        projection: Optional[Projection] = None,
        xms_version: Optional[str] = None,
    ):
        """
        Initialize the writer.

        Args:
            logger: Where to log any output.
            ugrid: A `UGrid` describing the mesh to be written. If `None`, no points or cells are written and `points`,
                `arcs`, and `materials` must all be empty.
            name: Name of the mesh. Applied to the `MESHNAME` card.
            model: Definitions for the GMI parameters to store in the file. If `None`, all sections are treated as
                empty. If a section is empty, no definitions are written for that section, and no values should be
                provided for it.
            model_instantiation: Values for variables in the model section. Cards corresponding to unassigned values
                will not be written. If `None` or empty, all values are considered unassigned.
            global_instantiation: Values for variables in the global section. Value cards will only be written for
                variables with values defined in this parameter. If `None` or empty, no global value cards will be
                written, but any definitions for them will still be written.
            points: Points which have values assigned to them. Requires `ugrid` if this is a nonempty list. Each element
                is an index into the point list in `ugrid`. Duplicates will raise `GmiError`.
            point_values: Values to assign to points. Parallel to `points`.
            arcs: Arcs which have values assigned to them. Requires `ugrid` if this is a nonempty list. Each element is
                a list of indexes into the point list in `ugrid`.
            arc_ids: IDs to assign each arc. Parallel to `arcs`. Duplicates will raise `GmiError`.
            arc_values: Values to assign to arcs. Parallel to `arcs`.
            arc_names: Names to assign to arcs. Parallel to `arcs`.
            material_names: Names of materials. Parallel to `material_groups`. Duplicate names are allowed and result in
                two distinct materials with the same name.
            material_numbers: Integer IDs of materials. Parallel to `material_groups`. Duplicate numbers are forbidden
                and raise an exception.
            material_groups: `Group`s with variables assigned. parallel to `materials`. These values are used as the
                values for the corresponding material.
            cell_materials: A mapping from cell_index -> material_index. cell_index is the *zero-based* index of a cell
                in the UGrid, and material_index is the *one-based* index of a material in material_names or
                material_groups. A material_index of 0 indicates no material is applied to the cell. An empty sequence
                or None is treated as being all zeros.
            projection: The projection that `ugrid` is in.
            xms_version: The current XMS version string.
        """
        self._logger = logger if logger else Logger('xms.gmi.writer')
        self._out: Optional[TextIO] = None
        self.__point_snapper: Optional[SnapPoint] = None
        self.__interior_arc_snapper: Optional[SnapInteriorArc] = None
        self.__exterior_arc_snapper: Optional[SnapExteriorArc] = None
        self._curves: list[XySeries] = []

        self._xms_version = xms_version

        self._ugrid = ugrid
        self._mesh_name = name
        self._projection = projection

        self._model = model
        self._model_instantiation = model_instantiation
        self._global_instantiation = global_instantiation

        self._points = points or []
        self._point_values = point_values or []

        self._arcs = arcs or []
        self._arc_ids = arc_ids or []
        self._arc_values = arc_values or []
        self._arc_names = arc_names or []

        self._material_names = material_names or []
        self._material_numbers = material_numbers or []
        self._material_values = material_groups or []
        if cell_materials is not None:
            self._cell_materials = cell_materials
        else:
            length = ugrid.ugrid.cell_count if ugrid else 0
            self._cell_materials = np.zeros(length, dtype=int)

    def write(self, path: Path):
        """
        Write the file.

        Args:
            path: Where to write the output.
        """
        with open(path, 'w') as file:
            # It would be simpler to just use the file here, but it turns out writing to memory and then dumping
            # everything to disk all at once reduces the disk I/O time by about half, which adds up to a few seconds
            # on the kind of huge meshes Hydro-AS uses.
            self._out = StringIO()

            if (self._points or self._arcs) and not self._ugrid:
                write_error(Message.features_without_domain, 0)
            if np.any(self._cell_materials) and not self._ugrid:
                write_error(Message.materials_without_domain, 0)

            self._write_mesh()
            self._write_definitions()
            self._write_values()
            self._write_curve_data()
            self._write_extra_file(path)

            self._out.seek(0, 0)
            file.write(self._out.read())

    def _write_mesh(self):
        """Write the mesh part of the file."""
        self._logger.info('Writing mesh header...')
        self._out.write('MESH2D\n')

        if not self._ugrid:
            return

        self._out.write(f'MESHNAME {_str(self._mesh_name)}\n')
        self._out.write('NUM_MATERIALS_PER_ELEM 1\n')
        self._logger.info('Snapping material polygons to mesh...')

        self._logger.info('Writing mesh elements...')
        grid = self._ugrid.ugrid
        timer = Timer()
        assert len(self._cell_materials) == self._ugrid.ugrid.cell_count
        for cell_index, material_id in enumerate(self._cell_materials):
            if timer.report_due:
                self._logger.info(f'Wrote {cell_index} elements...')
            point_ids = [pt_idx + 1 for pt_idx in grid.get_cell_points(cell_index)]
            if len(point_ids) == 3:
                self._out.write(f'E3T {cell_index+1} {point_ids[0]} {point_ids[1]} {point_ids[2]} {material_id}\n')
            elif len(point_ids) == 4:
                self._out.write(
                    f'E4Q {cell_index+1} {point_ids[0]} {point_ids[1]} {point_ids[2]} {point_ids[3]} {material_id}\n'
                )
            else:  # pragma: no cover
                raise AssertionError('Invalid number of nodes')  # New element card added without support

        self._logger.info('Writing mesh nodes...')
        for node_id, (x, y, z) in enumerate(grid.locations, start=1):
            if timer.report_due:
                self._logger.info(f'Wrote {node_id} nodes...')
            self._out.write(f'ND {node_id} {_float(x)} {_float(y)} {_float(z)}\n')

        self._write_nodestrings()

    def _write_nodestrings(self):
        """Write the NS cards, if necessary."""
        if not self._arcs:
            return

        self._logger.info('Writing nodestrings...')

        used_ids = set()

        for arc, arc_id, name in zip(self._arcs, self._arc_ids, self._arc_names, strict=True):
            if arc_id in used_ids:
                raise write_error(Message.duplicate_arc, arc_id)
            used_ids.add(arc_id)
            self._write_nodestring(arc, arc_id, name)

    def _write_nodestring(self, arc: list[int], arc_id: int, name: str):
        """
        Write a nodestring to the output.

        Args:
            arc: List of indexes of nodes in the mesh defining the arc.
            arc_id: ID to assign the arc.
            name: Optional name of the arc.
        """
        arc: list[int | str] = [index + 1 for index in arc]
        arc[-1] = -arc[-1]
        arc.append(arc_id)

        min_length = 2  # Last node ID + the arc ID
        if name:
            arc.append(_str(name))
            min_length += 1  # Add the arc's name

        chunks = _chunk(arc)
        while len(chunks[-1]) < min_length:
            # The last chunk must contain the last node ID. If it doesn't, we take elements from the second-to-last
            # chunk until it does.
            chunks[-1].insert(0, chunks[-2].pop())

        for chunk in chunks:
            self._out.write('NS')
            for item in chunk:
                self._out.write(f' {item}')
            self._out.write('\n')

    def _should_write_definitions(self):
        if self._model and not self._model.is_default():
            return True

        if self._material_names:
            return True

        return False

    def _write_definitions(self):
        """Write all the definitions."""
        if not self._should_write_definitions():
            return

        self._out.write('BEGPARAMDEF\n')

        if self._model:
            self._write_model_definitions()
            self._write_global_definitions()
            self._write_boundary_definitions(self._model.point_parameters, 0)
            self._write_boundary_definitions(self._model.arc_parameters, 1)
            self._write_boundary_definitions(self._model.polygon_parameters, 2)

        if self._material_names:
            self._write_material_names()

        if self._model:
            self._write_material_definitions()

        self._out.write('ENDPARAMDEF\n')

    def _write_model_definitions(self):
        """Write the model definitions."""
        if not self._model or not self._model_instantiation:
            return

        model_variables = self._model.model_parameters
        model_variables.restore_values(self._model_instantiation)
        model_group = model_variables.group('model')

        model_name = model_group.parameter('model_name')
        if model_name.has_value:
            self._out.write(f'GM {_str(model_name.value)}\n')

        distance_units = model_group.parameter('distance_units')
        if distance_units.has_value:
            value = {'feet': 0, 'meters': 1, 'geographic (lat/lon)': 2, 'feet (international)': 3}[distance_units.value]
            if value < 3:
                self._out.write(f'SI {value}\n')
            else:
                self._out.write('SI 0 "International"\n')

        is_dynamic = model_group.parameter('is_dynamic')
        if is_dynamic.has_value:
            self._out.write(f'DY {_bool(is_dynamic.value)}\n')

        time_units = model_group.parameter('time_units')
        if time_units.has_value:
            self._out.write(f'TU {_str(time_units.value)}\n')

        time_step = model_group.parameter('time_step')
        total_time = model_group.parameter('total_time')
        if time_step.has_value or total_time.has_value:
            self._out.write(f'TD {_float(time_step.value)} {_float(total_time.value)}\n')

        key = model_group.parameter('key')
        if key.has_value:
            self._out.write(f'KEY {_str(key.value)}\n')

    def _write_global_definitions(self):
        """Write global definitions."""
        section = self._model.global_parameters.copy()
        section.restore_values(self._global_instantiation)

        for group_id in section.group_names:
            group = section.group(group_id)
            self._out.write(f'GP {group_id} {_str(group.label)} {_bool(group.is_active)}\n')
            for parameter_id in group.parameter_names:
                self._write_parameter('GP', section, group_id, parameter_id)

    def _write_boundary_definitions(self, section: GroupSet, entity_id):
        """Write boundary definitions."""
        for group_id in section.group_names:
            group = section.group(group_id)
            correlation_id = str(group.correlation) if group.correlation else ""
            correlation = self._model.global_parameters.group(correlation_id).label if correlation_id else ""
            self._out.write(f'BC {entity_id} {_str(group.label)} {group_id} ')
            self._out.write(f'0 {_bool(group.legal_on_interior)} {_str(correlation)}\n')
            for parameter_id in group.parameter_names:
                self._write_parameter('BC', section, group_id, parameter_id)

    def _write_parameter(self, prefix: str, section: GroupSet, internal_group_id: str, internal_parameter_id: str):
        """
        Write a parameter.

        Args:
            prefix: First part of the name of the cards to write ('GP', 'BC', 'MAT').
            section: Part of the definitions to write from.
            internal_group_id: ID of the group containing the parameter to write.
            internal_parameter_id: ID of the parameter to write.
        """
        parameter = section.group(internal_group_id).parameter(internal_parameter_id)

        if prefix == 'MAT':
            # Material parameters are stored in the UNASSIGNED_MATERIAL_GROUP internally for convenience, but we need
            # to restore those groups on export. The reader prepends the group ID to the parameter ID to enable this.
            external_group_id, external_parameter_id = internal_parameter_id.split()
        else:
            external_group_id, external_parameter_id = internal_group_id, internal_parameter_id

        self._out.write(f'{prefix}_DEF {external_group_id} {external_parameter_id} {_str(parameter.label)} ')
        if parameter.parameter_type == Type.BOOLEAN:
            self._out.write(f'0 {_bool(parameter.default)}\n')
        elif parameter.parameter_type == Type.INTEGER:
            self._out.write(f'1 {parameter.default} {parameter.low} {parameter.high}\n')
        elif parameter.parameter_type == Type.FLOAT:
            self._out.write(f'2 {_float(parameter.default)} {_float(parameter.low)} {_float(parameter.high)}\n')
        elif parameter.parameter_type == Type.TEXT:
            self._out.write(f'3 {_str(parameter.default)}\n')
        elif parameter.parameter_type == Type.OPTION:
            self._out.write(f'4 {_str(parameter.default)}\n')
            self._out.write(f'{prefix}_OPTS {external_group_id} {external_parameter_id} ')
            options = ' '.join(_str(option) for option in parameter.options)
            self._out.write(f'{options}\n')
        elif parameter.parameter_type == Type.CURVE:
            self._out.write(f'5 {_str(parameter.axis_titles[0])} {_str(parameter.axis_titles[1])}\n')
        elif parameter.parameter_type == Type.FLOAT_CURVE:
            self._out.write(f'6 {_float(parameter.default[1])} {_float(parameter.low)} {_float(parameter.high)} ')
            self._out.write(f'{_str(parameter.default[0])} {_str(parameter.axis_titles[0])} ')
            self._out.write(f'{_str(parameter.axis_titles[1])}\n')
        else:  # pragma: nocover
            raise AssertionError('Unrecognized type')  # Added a new type without adding support

        if parameter.parent:
            parent_id = parameter.parent[-1]
            parent_name = section.group(internal_group_id).parameter(parent_id).label
            self._out.write(
                f'{prefix}_DEP {external_group_id} {external_parameter_id} "PARENT_LOCAL" {_str(parent_name)} 0'
            )
            for parent_value in parameter.dependency_flags:
                enabled = parameter.dependency_flags[parent_value]
                self._out.write(f' {_str(parent_value)} {_bool(enabled)}')
            self._out.write('\n')

    def _write_material_names(self):
        """Write out the material name cards."""
        for material_id, material_name in zip(self._material_numbers, self._material_names, strict=True):
            self._out.write(f'MAT {material_id} {_str(material_name)}\n')

    def _write_material_definitions(self):
        """Write material definitions."""
        if not self._model or len(self._model.material_parameters.group_names) == 0:
            return

        section = self._model.material_parameters

        # All material groups have identical definitions, so we'll just write the definitions for the unassigned group.
        group = section.group(UNASSIGNED_MATERIAL_ID)
        for parameter_id in group.parameter_names:
            self._write_parameter('MAT', section, UNASSIGNED_MATERIAL_ID, parameter_id)

    def _write_values(self):
        """Write all the GMI values, if necessary."""
        if not self._should_write_values():
            return

        self._out.write('BEG2DMBC\n')
        self._write_global_values()
        self._write_material_values()
        self._write_point_values()
        self._write_arc_values()
        self._out.write('END2DMBC\n')

    def _should_write_values(self) -> bool:
        """Check whether parameter values should be written."""
        return any(
            [
                self._global_instantiation,
                self._point_values,
                self._arc_values,
                # self.polygon_values is not None,
                self._material_values,
            ]
        )

    def _write_global_values(self):
        """Write out the values for global parameters."""
        if not self._global_instantiation:
            return

        self._model.global_parameters.restore_values(self._global_instantiation)

        for group_id in self._model.global_parameters.group_names:
            group = self._model.global_parameters.group(group_id)
            # There's only one instance of the global variables, so the .2dm format can specify whether each group is
            # active, which means we don't have to make the activity/values decision that features have.
            for parameter_id in group.parameter_names:
                if group.parameter(parameter_id).has_value:
                    self._out.write(f'GP_VAL {group_id} {parameter_id} ')
                    self._write_parameter_value(group.parameter(parameter_id))

    def _write_point_values(self):
        """Write the GMI values for points, if necessary."""
        if not self._point_values:
            return

        self._logger.info('Writing node values...')
        point_assigned = np.zeros(self._ugrid.ugrid.point_count, dtype=int)

        section = self._model.to_template().point_parameters
        for point_id, values in zip(self._points, self._point_values):
            feature_id = point_id + 1
            if point_assigned[feature_id - 1]:
                raise write_error(Message.duplicate_point, feature_id)
            else:
                point_assigned[feature_id - 1] = True
            section.restore_values(values)
            for group_id in section.group_names:
                group = section.group(group_id)
                # The .2dm format has no way to specify whether boundary conditions are active for individual features,
                # so activity is implied by whether anything in the group is assigned. This means preserving values and
                # preserving activity are mutually exclusive. Activity is more important, so we write all values for
                # active groups and none for inactive ones, which implicitly writes the activity.
                if group.is_active:
                    for parameter_id in group.parameter_names:
                        parameter = group.parameter(parameter_id)
                        self._out.write(f'BC_VAL N {feature_id} {group_id} {parameter_id} ')
                        self._write_parameter_value(parameter)

    def _write_arc_values(self):
        """Write the GMI values for arcs, if necessary."""
        if not self._arc_values:
            return

        self._logger.info('Writing nodestring values...')
        section = self._model.to_template().arc_parameters

        for arc_id, values in zip(self._arc_ids, self._arc_values):
            section.restore_values(values)
            for group_id in section.group_names:
                group = section.group(group_id)
                if not group.is_active:
                    # The .2dm format has no way to specify whether boundary conditions are active for individual
                    # features, so activity is implied by whether anything in the group is assigned. This means
                    # preserving values and preserving activity are mutually exclusive. Activity is more important, so
                    # we don't write values for inactive groups, which implicitly writes the activity.
                    continue
                for parameter_id in group.parameter_names:
                    parameter = group.parameter(parameter_id)
                    if parameter.has_value:
                        self._out.write(f'BC_VAL S {arc_id} {group_id} {parameter_id} ')
                        self._write_parameter_value(parameter)

    def _write_material_values(self):
        """Write the GMI values for materials, if necessary."""
        if not self._material_values:
            return

        for material_id, group in zip(self._material_numbers, self._material_values, strict=True):
            for parameter_id in group.parameter_names:
                if group.parameter(parameter_id).has_value:
                    # Materials are all stored in a single group internally for convenience, but we want to restore
                    # those groups on export. The reader prefixes the parameter ID with the group ID to enable this.
                    write_group, write_parameter = parameter_id.split()
                    self._out.write(f'MAT_VAL {material_id} {write_group} {write_parameter} ')
                    self._write_parameter_value(group.parameter(parameter_id))

    def _write_parameter_value(self, parameter: Parameter):
        """
        Write the value for a parameter.

        Args:
            parameter: The parameter to write the value of.
        """
        t = parameter.parameter_type
        if t == Type.BOOLEAN:
            self._out.write(_bool(parameter.value))
        elif t == Type.INTEGER:
            self._out.write(str(parameter.value))
        elif t == Type.FLOAT:
            self._out.write(_float(parameter.value))
        elif t == Type.TEXT or t == Type.OPTION:
            self._out.write(_str(parameter.value))
        elif t == Type.CURVE:
            curve_x, curve_y = parameter.value
            self._write_curve_id(parameter.label, curve_x, curve_y)
        elif t == Type.FLOAT_CURVE:
            mode, float_value, curve_x, curve_y = parameter.value
            if mode == 'FLOAT':
                self._out.write(f'VALUE {_float(float_value)}')
            elif mode == 'CURVE':
                self._out.write('CURVE ')
                self._write_curve_id(parameter.label, curve_x, curve_y)
            else:  # pragma: nocover
                raise AssertionError("Unknown mode")  # Added a new mode without adding support here
        else:  # pragma: nocover
            raise AssertionError(f"Unknown type {t}")  # Added a new type without adding support here
        self._out.write('\n')

    def _write_curve_id(self, name: str, curve_x: list[int], curve_y: list[int]):
        """
        Write the ID for a curve and ensure its data will be written later.

        Args:
            name: Name to give the curve.
            curve_x: Curve X values.
            curve_y: Curve Y values.
        """
        if curve_x == curve_y == [0.0]:
            self._out.write('-1')
            return

        series = XySeries()
        self._curves.append(series)
        series.series_id = len(self._curves)
        series.name = name
        series.x = curve_x
        series.y = curve_y
        self._out.write(f'{series.series_id}')

    def _write_curve_data(self):
        """Write the curve section."""
        if not self._curves:
            return

        self._out.write('BEGCURVE Version: 1\n')

        for curve in self._curves:
            curve.write(self._out)

        self._out.write('ENDCURVE\n')

    def _write_extra_file(self, path: Path):
        if not self._ugrid:
            return

        write_extra_file(
            mesh_name=self._mesh_name,
            mesh_path=path,
            mesh_uuid=self._ugrid.uuid,
            projection=self._projection,
            xms_version=self._xms_version
        )


def _bool(b: bool):
    """Format a truth value as the text '1' (truthy) or '0' (falsy)."""
    return str(int(bool(b)))


def _float(f: float):
    """Format a number as text using its plain `str` conversion."""
    return f'{f!s}'


def _str(s: str):
    """Format a value as text wrapped in double quotes; quotes already inside the value are left untouched."""
    return '"{}"'.format(s)


def _chunk(lst, size: int = 10):
    """
    Split a sequence into consecutive slices of at most `size` elements.

    Args:
        lst: The sequence to split. It must support `len()` and slicing.
        size: Maximum number of elements per slice. Defaults to 10.

    Returns:
        A list of slices in their original order. Every slice except possibly the last has exactly `size` elements.
        An empty input gives an empty list.
    """
    return [lst[start:start + size] for start in range(0, len(lst), size)]
