"""Write a WaveWatch3 unstructured grid file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import StringIO
import logging
import os
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_boundaries import UGridBoundaries
from xms.grid.ugrid import UGrid

# 4. Local modules


class GmshWriter:
    """Class for exporting the WaveWatch3 unstructured grid file.

    The output is an ASCII Gmsh ``.msh`` file (header ``2 0 8``) with a ``$Nodes`` section holding every UGrid
    point, and an ``$Elements`` section holding one single-node element per boundary-loop point followed by
    one element per UGrid cell (only triangles and quads are supported).
    """
    # Number of tags written on each boundary-node element line.
    NUM_BOUNDARY_TAGS = 2
    # Number of tags written on each cell element line.
    NUM_CELL_TAGS = 3
    # Element type written for each boundary node (boundary nodes are exported as single-node elements).
    BOUNDARY_TYPE = 15

    def __init__(self, cogrid, grid_name, lateral_node_ids=None, ocean_node_ids=None):
        """Constructor.

        Args:
            cogrid (:obj:`CoGrid`): The domain grid
            grid_name (:obj:`str`): Name of the domain grid. Filename will be: <cwd>/<grid_name>.msh, unless
                lateral node ids are given, in which case it is <cwd>/meshbnd.msh
            lateral_node_ids (:obj:`list[int]`): List of 1-based lateral node id's when making the meshbnd.msh
                version. Any iterable of ids is accepted.
            ocean_node_ids (:obj:`list[int]`): List of 1-based ocean/input node id's when making the meshbnd.msh
                version. Any iterable of ids is accepted.
        """
        self._ugrid = cogrid.ugrid
        self._ss = StringIO()
        self._logger = logging.getLogger('xms.wavewatch3')
        # Id of the last element written; stays -1 until write() actually exports something.
        self.last_id = -1

        # Special case for writing the meshbnd.msh version. Copy into lists with an explicit None check so
        # array-like inputs (whose truth value can be ambiguous) work as well as plain lists.
        self._lateral = list(lateral_node_ids) if lateral_node_ids is not None else []
        self._ocean = list(ocean_node_ids) if ocean_node_ids is not None else []
        file_base = 'meshbnd' if self._lateral else grid_name
        self._filename = os.path.join(os.getcwd(), f'{file_base}.msh')

    def _get_cell_type_and_size(self, cellstream, idx, cell_id=None):
        """Get an elements WW3 type and number of points.

        Args:
            cellstream (:obj:`list`): The UGrid cellstream
            idx (:obj:`int`): Index into the cellstream where the cell's entries start
            cell_id (:obj:`int`, optional): 1-based cell number, used only in the error message. Defaults to
                ``idx + 1`` for backward compatibility.

        Returns:
            (:obj:`tuple(int,int)`): The element's WW3 type and number of points

        Raises:
            RuntimeError: If the cell is neither a triangle nor a quad
        """
        cell_type = cellstream[idx]
        if cell_type == UGrid.cell_type_enum.TRIANGLE:
            return 2, 3
        if cell_type == UGrid.cell_type_enum.QUAD:
            return 3, 4
        cell_num = idx + 1 if cell_id is None else cell_id
        raise RuntimeError(f'Unsupported element type found for cell {cell_num}')

    def _write_header(self):
        """Write the header lines to the file."""
        self._ss.write("$MeshFormat\n")
        self._ss.write("2 0 8\n")  # Version, ASCII, sizeof(size_t)
        self._ss.write("$EndMeshFormat\n")

    def _write_nodes(self):
        """Write mesh node locations to the file.

        Each node line is ``<1-based id> <x> <y> <z>`` where the written z value is:
          * 0.0 for any node in the ocean node list (ocean has priority over lateral),
          * otherwise, when lateral node ids were given (meshbnd.msh version), 2.0 for lateral nodes and 0.0
            for every other node,
          * otherwise the negated z value of the UGrid point.
        """
        self._logger.info('Writing mesh node locations to file...')
        self._ss.write("$Nodes\n")
        self._ss.write(f"{self._ugrid.point_count:>12}\n")
        # Build sets once: the id lists can be as long as the node count, so list membership per node is O(n*m).
        lateral_ids = set(self._lateral)
        ocean_ids = set(self._ocean)
        for pt_id, pt in enumerate(self._ugrid.locations, start=1):
            if pt_id in ocean_ids:
                z = 0.0
            elif self._lateral:
                z = 2.0 if pt_id in lateral_ids else 0.0
            else:
                z = -pt[2]
            self._ss.write(f'{pt_id:>10}  {pt[0]:>18.8f}  {pt[1]:>18.8f}  {z:>18.8f}\n')
        self._ss.write("$EndNodes\n")

    def _write_boundary(self):
        """Open the ``$Elements`` section and write the boundary nodes as single node elements.

        The element count written at the top of the section covers both the boundary-node elements written
        here and the UGrid cells written afterwards by ``_write_elements``.

        Returns:
            (:obj:`int`): The id to use for the next element
        """
        self._logger.info('Computing mesh boundary...')
        boundaries = UGridBoundaries(self._ugrid)
        loops = boundaries.get_loops()
        # Sum over the values directly instead of indexing by position, so the count does not depend on
        # how the loops are keyed.
        num_loop_pts = sum(len(loop_vals['id']) for loop_vals in loops.values())
        self._ss.write("$Elements\n")  # Start of the elements section (boundary nodes and UGrid cells)
        self._ss.write(f"{self._ugrid.cell_count + num_loop_pts:>12}\n")

        # write boundary points as elements
        self._logger.info('Writing boundary nodes to file...')
        elem_id = 1
        for loop_key, loop_vals in loops.items():
            for b_pt in loop_vals['id']:
                # For each boundary node: <count starting at 1> 15 2 0 <loop key> <node_id (1-based)>
                # NOTE(review): the loop key is written as-is; confirm whether get_loops() keys start at 0 or 1.
                cur_line = f'{elem_id:>10} {self.BOUNDARY_TYPE:>9} {self.NUM_BOUNDARY_TAGS:>9} {0:>9} {loop_key:>9} ' \
                           f'{b_pt + 1:>9}\n'
                self._ss.write(cur_line)
                elem_id += 1
        return elem_id

    def _write_elements(self, elem_id):
        """Write the UGrid cells to the file as elements and close the ``$Elements`` section.

        Each line is ``ElementID ElementType NumTags Tag1 Tag2 Tag3 NodeIDs...``:
          * ElementType is 2 for triangles and 3 for quads,
          * Tag1 and Tag3 are written as 0,
          * Tag2 is the 1-based cell number, or the element id itself for the meshbnd.msh version,
          * NodeIDs are the 1-based ids of the cell's points.

        Args:
            elem_id (:obj:`int`): Id to use for the first cell element

        Returns:
            (:obj:`int`): The last element id
        """
        self._logger.info('Writing mesh elements to file...')
        cell_stream = self._ugrid.cellstream
        stream_len = len(cell_stream)
        index = 0
        cell_id = 1
        while index < stream_len:
            elem_type, elem_size = self._get_cell_type_and_size(cell_stream, index, cell_id)
            cur_id = elem_id if self._lateral else cell_id  # meshbnd version is slightly different
            # Each cell occupies two leading cellstream entries (the first is the cell type) followed by its
            # point ids.
            first_pt = index + 2
            node_ids = ''.join(f' {cell_stream[first_pt + j] + 1:>7}' for j in range(elem_size))
            cur_line = f'{elem_id:>8} {elem_type:>7} {self.NUM_CELL_TAGS:>7} {0:>7} {cur_id:>7} {0:>7}'
            self._ss.write(f'{cur_line}{node_ids}\n')
            index = first_pt + elem_size
            elem_id += 1
            cell_id += 1
        self._ss.write("$EndElements\n")
        return elem_id - 1

    def _flush(self):
        """Write in-memory stream to disk."""
        self._logger.info('Flushing to disk...')
        # Context manager guarantees the file is closed even if the copy fails part way through.
        with open(self._filename, 'w', encoding='utf-8') as f:
            self._ss.seek(0)
            shutil.copyfileobj(self._ss, f, 1000000)

    def write(self):
        """Write the .msh file.

        Does nothing except log a warning if the grid has no points or no cells. On success, ``last_id`` is
        set to the id of the last element written.
        """
        self._logger.info('Exporting WaveWatch3 domain mesh to unstructured grid file...')
        if self._ugrid.point_count == 0 or self._ugrid.cell_count == 0:
            self._logger.warning('No mesh data to export.')
            return
        # Start from an empty buffer so calling write() more than once does not duplicate every section.
        self._ss = StringIO()
        self._write_header()
        self._write_nodes()
        elem_id = self._write_boundary()
        self.last_id = self._write_elements(elem_id)
        self._flush()
        self._logger.info('Successfully exported WaveWatch3 domain mesh.')
