"""TUFLOW-FV geometry file writer."""
# 1. Standard python modules
from collections import OrderedDict
from io import StringIO
import logging
import os
import shutil

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.guipy.data.target_type import TargetType
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_interior_arc import SnapInteriorArc
from xms.snap.snap_polygon import SnapPolygon

# 4. Local modules
from xms.tuflowfv.components.tuflowfv_component import UNINITIALIZED_COMP_ID
from xms.tuflowfv.file_io.io_util import READ_BUFFER_SIZE


class GeomWriter:
    """Writer class for the TUFLOW-FV geometry (.2dm) file.

    The file contents are assembled in an in-memory text buffer and copied to
    ``<cwd>/model/geo/<mesh name>.2dm`` at the end of :meth:`write`.
    """

    # Maximum number of node ids written on a single 'NS' line before wrapping to a new 'NS' line.
    _NODES_PER_NS_LINE = 10

    def __init__(self, xms_data, coverages):
        """Constructor.

        Args:
            xms_data (XmsData): Simulation data retrieved from SMS
            coverages (CoverageCollector): The simulation coverage data
        """
        self._ss = StringIO()  # In-memory buffer for the file contents; copied to disk by _flush()
        self._logger = logging.getLogger('xms.tuflowfv')
        self._xms_data = xms_data
        self._coverages = coverages
        self._materials_on_cells = None  # Per-cell material ids, filled by _get_material_assignments()
        self._nodestrings = OrderedDict()  # {bc_id: [n0, ..., nN]} 1-based node ids, last id negated
        # Snappers are only constructed on first access through the properties below.
        self._exterior_snapper = None
        self._interior_snapper = None
        self._polygon_snapper = None

    @property
    def exterior_snapper(self):
        """Lazy initialization getter so we only construct the exterior snapper if we need it."""
        if self._exterior_snapper is None:
            self._logger.info('Building exterior arc snapper for mesh...')
            self._exterior_snapper = SnapExteriorArc()
            self._exterior_snapper.set_grid(grid=self._xms_data.cogrid, target_cells=False)
        return self._exterior_snapper

    @property
    def interior_snapper(self):
        """Lazy initialization getter so we only construct the interior snapper if we need it."""
        if self._interior_snapper is None:
            self._logger.info('Building interior arc snapper for mesh...')
            self._interior_snapper = SnapInteriorArc()
            self._interior_snapper.set_grid(grid=self._xms_data.cogrid, target_cells=False)
        return self._interior_snapper

    @property
    def polygon_snapper(self):
        """Lazy initialization getter so we only construct the polygon snapper if we need it."""
        if self._polygon_snapper is None:
            self._logger.info('Building polygon snapper for mesh...')
            self._polygon_snapper = SnapPolygon()
            self._polygon_snapper.set_grid(grid=self._xms_data.cogrid, target_cells=True)
        return self._polygon_snapper

    def _get_material_assignments(self):
        """Fill the cell material assignment array.

        Every cell starts as UNINITIALIZED_COMP_ID. For each material coverage that is set to export to the .2dm
        file, the coverage polygons are snapped to the mesh. Cells inside a polygon receive that polygon's
        material id, remapped through ``coverages.material_lookup``. Polygons without a material assignment are
        skipped. When a cell falls inside several polygons, the last assignment processed wins.
        """
        # Initialize material assignments to -1. If we use 0 (unassigned) TUFLOW complains that all cells are inactive.
        self._materials_on_cells = np.full(self._xms_data.xmugrid.cell_count, UNINITIALIZED_COMP_ID)
        material_cov_indices = self._coverages.twodm_material_coverages()
        if not material_cov_indices:
            return  # No linked material coverages that are set to export to the .2dm file
        for loop_idx, material_cov_index in enumerate(material_cov_indices):
            self._logger.info(
                f'Snapping material polygons to mesh for .2dm export ({loop_idx + 1} of {len(material_cov_indices)})...'
            )
            cov_polys = self._xms_data.mat_covs[material_cov_index].polygons
            self.polygon_snapper.add_polygons(polygons=cov_polys)
            mat_comp = self._xms_data.mat_comps[material_cov_index]
            for cov_poly in cov_polys:
                original_mat_id = mat_comp.get_comp_id(TargetType.polygon, cov_poly.id)
                if original_mat_id == UNINITIALIZED_COMP_ID or original_mat_id is None:
                    continue  # Polygon has no material assigned
                # Remap the coverage-local material id to the id used in the exported simulation
                new_mat_id = self._coverages.material_lookup[material_cov_index][original_mat_id]
                cell_indices = self.polygon_snapper.get_cells_in_polygon(cov_poly.id)
                for cell_idx in cell_indices:
                    self._materials_on_cells[cell_idx] = new_mat_id

    def _get_bc_nodestrings(self):
        """Fill the nodestring dict from the BC coverages that are set to export to the .2dm file.

        Arcs with no BC attributes (monitor lines) are snapped to the mesh interior. All other BC arcs are snapped
        to the mesh exterior. Node ids are stored 1-based, with the last id in each string negated as the
        terminator. Arcs that have no BC id, or that snap to no mesh nodes, are skipped with a warning.
        """
        bc_cov_indices = self._coverages.twodm_bc_coverages()
        if not bc_cov_indices:
            return  # No linked BC coverages that are set to export to the .2dm file
        for loop_idx, bc_cov_index in enumerate(bc_cov_indices):
            self._logger.info(
                f'Snapping BC arcs to mesh for .2dm export ({loop_idx + 1} of {len(bc_cov_indices)})...'
            )
            lookup = self._coverages.bc_lookup[bc_cov_index][TargetType.arc]
            for bc_arc in self._xms_data.bc_covs[bc_cov_index].arcs:
                bc_id = lookup.get(bc_arc.id)
                if bc_id is None:
                    # Without an id the nodestring would be written with a literal 'None' id, and it would collide
                    # with any other unmapped arc in the nodestring dict.
                    self._logger.warning(f'BC line {bc_arc.id} has no BC id and will not be exported.')
                    continue
                if not self._coverages.bcs.get(bc_id):  # If no atts (monitor line) snap to the interior
                    snapped_arc = self.interior_snapper.get_snapped_points(bc_arc)
                else:  # All other BCs snap to the exterior
                    snapped_arc = self.exterior_snapper.get_snapped_points(bc_arc)
                node_ids = [node_id + 1 for node_id in snapped_arc['id']]  # 0-based -> 1-based .2dm ids
                if node_ids:
                    node_ids[-1] *= -1  # Last node in the string needs to be negative
                    self._nodestrings[bc_id] = node_ids
                else:  # Arc outside the domain snapped to the interior? Probably.
                    self._logger.warning(f'Could not map BC line {bc_arc.id} to the domain.')

    def _write_header(self):
        """Write the header lines at the beginning of the file."""
        self._ss.write('MESH2D\n')
        self._ss.write(f'MESHNAME "{self._xms_data.do_ugrid.name}"\n')
        self._ss.write('NUM_MATERIALS_PER_ELEM 1\n')

    def _write_cells(self):
        """Write one element card per mesh cell, using 1-based cell and node ids followed by the cell material."""
        self._logger.info('Writing mesh elements to file...')
        ugrid = self._xms_data.xmugrid
        for cell_idx in range(ugrid.cell_count):
            point_ids = [pt_idx + 1 for pt_idx in ugrid.get_cell_points(cell_idx)]
            # 4-node cells are quads; anything else is written as a triangle.
            # NOTE(review): assumes the mesh only contains 3- and 4-node cells - confirm.
            cell_type = 'E4Q' if len(point_ids) == 4 else 'E3T'
            # cell_type cell_id node1_id, ..., nodeN_id mat_id
            self._ss.write(
                f'{cell_type} {cell_idx + 1} {" ".join(map(str, point_ids))} '
                f'{self._materials_on_cells[cell_idx]}\n'
            )

    def _write_points(self):
        """Write the mesh node locations to the file as 'ND id x y z' cards with 1-based ids."""
        self._logger.info('Writing mesh node locations to file...')
        ugrid = self._xms_data.xmugrid
        locations = ugrid.locations
        for i in range(ugrid.point_count):
            self._ss.write(f'ND {i + 1} {locations[i][0]} {locations[i][1]} {locations[i][2]}\n')

    def _write_nodestrings(self):
        """Write all the .2dm defined bcs to the file.

        Each nodestring is written as one or more 'NS' lines holding at most ``_NODES_PER_NS_LINE`` node ids. The
        nodestring id follows the last (negative) node id, then the BC name when one is set and differs from the id.
        """
        self._logger.info('Writing mesh bcs to file...')
        per_line = self._NODES_PER_NS_LINE
        for nodestring_id, node_ids in self._nodestrings.items():
            if not node_ids:
                continue  # Nothing to write; only non-empty strings are stored, but be defensive
            # Split the string into rows, each starting with its own 'NS' card (no trailing whitespace on wraps).
            rows = [
                ' '.join(str(node_id) for node_id in node_ids[start:start + per_line])
                for start in range(0, len(node_ids), per_line)
            ]
            self._ss.write('\n'.join(f'NS {row}' for row in rows))
            self._ss.write(f' {nodestring_id}')
            atts = self._coverages.bcs.get(nodestring_id)
            if atts:
                name = atts[0].name.item()
                if name:
                    # Check for spaces in the BC name, which are unsupported with the .2dm format. Only really a problem
                    # if we try to read them back in.
                    if ' ' in name:
                        self._logger.warning(f'Nodestring {nodestring_id} contains space(s) in its name ({name}). This '
                                             'is not supported in the .2dm format. Importing the exported simulation '
                                             'files into SMS may result in errors.')
                    if name != str(nodestring_id):
                        self._ss.write(f' {name}')
            self._ss.write('\n')

    def _flush(self):
        """Copy the in-memory buffer to ``<cwd>/model/geo/<mesh name>.2dm``, creating the folder if needed."""
        self._logger.info('Flushing in-memory stream to disk...')
        geo_dir = os.path.join(os.getcwd(), 'model', 'geo')
        os.makedirs(geo_dir, exist_ok=True)  # Don't fail if the geo folder hasn't been created yet
        filename = os.path.join(geo_dir, f'{self._xms_data.do_ugrid.name}.2dm')
        with open(filename, 'w') as f:
            self._ss.seek(0)
            shutil.copyfileobj(self._ss, f, READ_BUFFER_SIZE)

    def write(self):
        """Write the TUFLOW-FV geometry file.

        Computes the material and nodestring data, writes the header, elements, nodes, and nodestrings (in that
        order) to the in-memory buffer, then flushes the buffer to disk.
        """
        self._get_material_assignments()
        self._get_bc_nodestrings()
        self._write_header()
        self._write_cells()
        self._write_points()
        self._write_nodestrings()
        self._flush()