"""Read a WaveWatch3 gmsh unstructured grid file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import os
import uuid

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.constraint.ugrid_builder import UGridBuilder
from xms.data_objects.parameters import Projection, UGrid
from xms.grid.ugrid import UGrid as XmUGrid

# 4. Local modules
from xms.wavewatch3.file_io.io_util import GEOGRAPHIC_WKT, LOCAL_METERS_WKT, READ_BUFFER_SIZE


class GmshReader:
    """Class for reading a WW3 Gmsh (MSH version 2, ASCII) unstructured grid file.

    The whole file is read into memory and parsed in order: the ``$MeshFormat`` line, the node count and node lines,
    then the element count and element lines. Triangle and quadrilateral elements become polygon cells of an XmUGrid,
    which is written to a CoGrid file (``self.co_ugrid``) and wrapped in a data_objects UGrid (``self.do_ugrid``).
    """
    # Column count of a WW3 triangle element line (3 tags + 3 nodes). No longer used by this reader, which selects
    # cells by element type, but retained in case other code references it.
    MIN_NUM_CELL_CARDS = 9
    COLUMNS_BEFORE_NUM_NODES = 3  # Columns before the tag list in an element line: id, element type, number of tags
    ELEMENT_NODE_COUNTS = {2: 3, 3: 4}  # Gmsh element type -> node count for the 2D cells we build (tri, quad)
    # z of the nodes returned by a meshbnd read: the file value 2, negated like every depth in _read_node_locations.
    MESHBND_BOUNDARY_Z = -2.0

    def __init__(self, filename='', grid_name=None, is_geographic=True):
        """Constructor.

        Args:
            filename (:obj:`str`): Path to the mesh file. If not provided (not testing or control file read),
                will retrieve from Query.
            grid_name (:obj:`str`): Name to assign the mesh. If None, the basename of the file is used.
            is_geographic (:obj:`bool`): True if degree units, False if meters
        """
        self._query = None
        self._filename = filename
        self._setup_query()
        self._lines = []
        self._current_line = 0
        self._logger = logging.getLogger('xms.wavewatch3')
        self._grid_name = grid_name
        self._is_geographic = is_geographic
        self.do_ugrid = None
        self.co_ugrid = None

    def _setup_query(self):
        """Setup the xmsapi Query if not testing or running a recording test.

        Only done when no filename was given; the filename is then taken from the Query's read file.
        """
        if not self._filename:
            self._query = Query()
            self._filename = self._query.read_file

    def _parse_next_line(self):
        """Parse the next line of text from the file.

        Skips empty lines and Gmsh section markers (lines starting with '$', e.g. '$Nodes', '$EndElements').

        Returns:
            (:obj:`list[str]`): The next line of text, split on whitespace

        Raises:
            RuntimeError: If the end of the file is reached before a non-empty line is found.
        """
        line = None
        while not line or line.startswith('$'):
            if self._current_line >= len(self._lines):
                raise RuntimeError('Unexpected end of file.')
            line = self._lines[self._current_line].strip()
            self._current_line += 1
        return line.split()

    def _read_node_locations(self, num_nodes):
        """Parse the node locations from the file.

        Each node line is ``node-id x y z``. The z value is negated on read (file depths become elevations).

        Args:
            num_nodes (:obj:`int`): The number of nodes in the file

        Returns:
            (:obj:`tuple(numpy.ndarray,dict)`): The (num_nodes, 3) node locations and the mapping of node id in file
            to point index we will be using to build the XmUGrid
        """
        locations = np.zeros((num_nodes, 3))
        node_id_map = {}
        for point_idx in range(num_nodes):
            line = self._parse_next_line()
            node_id_map[int(line[0])] = point_idx
            locations[point_idx] = (float(line[1]), float(line[2]), -float(line[3]))
        return locations, node_id_map

    @classmethod
    def _parse_cell_nodes(cls, line, node_id_map):
        """Get the point indices of a 2D cell from one split element line.

        MSH v2 element lines are laid out as ``elm-number elm-type number-of-tags <tags...> <node ids...>``, so the
        node ids start after the variable-length tag list, and the node count comes from the element type.

        Args:
            line (:obj:`list[str]`): The element line, split on whitespace
            node_id_map (:obj:`dict`): Mapping of node id in file to point index we will be using to build the XmUGrid

        Returns:
            (:obj:`list[int]`): The 0-based point indices of the cell, or None if the element is not a triangle or
            quadrilateral (e.g. WW3 boundary point elements) or the line is too short to hold all of its nodes.
        """
        if len(line) < cls.COLUMNS_BEFORE_NUM_NODES:
            return None
        num_cell_nodes = cls.ELEMENT_NODE_COUNTS.get(int(line[1]))
        if num_cell_nodes is None:  # Not a cell type we build (points, lines, 3D elements, ...)
            return None
        node_start_col = cls.COLUMNS_BEFORE_NUM_NODES + int(line[2])  # Skip the tag list
        node_ids = line[node_start_col:node_start_col + num_cell_nodes]
        if len(node_ids) < num_cell_nodes:  # Truncated line: skip it, as short lines were skipped before
            return None
        # Convert node ids to 0-based indices.
        return [node_id_map[int(node_id)] for node_id in node_ids]

    def _read_cells(self, num_cells, node_id_map):
        """Parse the cellstream from the file.

        Args:
            num_cells (:obj:`int`): The number of element lines in the file
            node_id_map (:obj:`dict`): Mapping of node id in file to point index we will be using to build the XmUGrid

        Returns:
            (:obj:`list`): The cellstream; each triangle/quad is stored as POLYGON, node count, point indices.
        """
        polygon = XmUGrid.cell_type_enum.POLYGON
        cellstream = []
        for _ in range(num_cells):
            cell_points = self._parse_cell_nodes(self._parse_next_line(), node_id_map)
            if cell_points is None:
                continue
            cellstream.append(polygon)
            cellstream.append(len(cell_points))
            cellstream.extend(cell_points)
        return cellstream

    @classmethod
    def _boundary_nodes(cls, locations, node_id_map):
        """Select the meshbnd nodes whose (negated) z value marks them.

        Args:
            locations (:obj:`numpy.ndarray`): The node locations, as returned by _read_node_locations
            node_id_map (:obj:`dict`): Mapping of node id in file to point index

        Returns:
            (:obj:`list[tuple]`): ``((node_id, point_index), location)`` for each node whose z equals
            MESHBND_BOUNDARY_Z, in point order. NOTE(review): the first element is the (file node id, point index)
            pair rather than the bare node id; kept that way for existing callers - confirm intent.
        """
        return [
            ((node_id, point_idx), locations[point_idx])
            for node_id, point_idx in node_id_map.items()
            if locations[point_idx][2] == cls.MESHBND_BOUNDARY_Z
        ]

    def _build_mesh(self, locations, cellstream):
        """Build the CoGrid and data_objects UGrid.

        Writes the CoGrid to ``<process temp dir>/<uuid>.xmc`` and sets ``self.co_ugrid`` and ``self.do_ugrid``.
        A fixed uuid is used when running tests so output is reproducible.

        Args:
            locations (:obj:`numpy.ndarray`): The grid's points
            cellstream (:obj:`list`): The grid's cellstream
        """
        # Write the CoGrid file
        xmugrid = XmUGrid(locations, cellstream)
        co_builder = UGridBuilder()
        co_builder.set_is_2d()  # This will create a Mesh2D in SMS instead of a UGrid, can take either.
        co_builder.set_ugrid(xmugrid)
        cogrid = co_builder.build_grid()
        self.co_ugrid = cogrid
        null_uuid = '11111111-1111-1111-1111-111111111111'
        running_tests = XmEnv.xms_environ_running_tests() == 'TRUE'
        grid_uuid = null_uuid if running_tests else str(uuid.uuid4())
        cogrid.uuid = grid_uuid
        cogrid_file = os.path.join(XmEnv.xms_environ_process_temp_directory(), f'{grid_uuid}.xmc')
        cogrid.write_to_file(cogrid_file, not running_tests)

        # Create the data_objects UGrid
        proj = Projection(wkt=GEOGRAPHIC_WKT if self._is_geographic else LOCAL_METERS_WKT)
        # If we don't know the grid name from the ww3_grid.nml file, use basename of the file.
        grid_name = self._grid_name if self._grid_name else os.path.basename(self._filename)
        self.do_ugrid = UGrid(cogrid_file, name=grid_name, uuid=grid_uuid, projection=proj)

    def read(self, meshbnd=False):
        """Read a WaveWatch3 gmsh unstructured grid file.

        Args:
            meshbnd(:obj:`bool`):  True if reading a meshbnd.msh mesh file, and you desire only the node locations.

        Return:
            (:obj:`list[tuple]`): If meshbnd, ``((node_id, point_index), location)`` for each node whose negated z
            value is MESHBND_BOUNDARY_Z; otherwise an empty list (the mesh is built into self.co_ugrid/self.do_ugrid).
        """
        self._logger.info('Parsing ASCII text from file...')
        with open(self._filename, 'r', buffering=READ_BUFFER_SIZE) as f:
            self._lines = f.readlines()
        self._current_line = 0  # Restart parsing so the reader can be used more than once

        self._logger.info('Finding number of nodes...')
        # The first data line is the $MeshFormat line ('2 0 8' in WW3 files: version, file type, data size).
        self._parse_next_line()
        num_nodes = int(self._parse_next_line()[0])  # Next data line is the number of nodes
        self._logger.info('Parsing node locations from file...')
        locations, node_id_map = self._read_node_locations(num_nodes)
        if meshbnd:
            return self._boundary_nodes(locations, node_id_map)

        self._logger.info('Parsing element definitions from file...')
        num_cells = int(self._parse_next_line()[0])  # First data line after nodes is number of elements
        cellstream = self._read_cells(num_cells, node_id_map)

        self._logger.info('Writing mesh to SMS file...')
        self._build_mesh(locations, cellstream)
        return []

    def send(self):
        """Send data to SMS if a control file and not testing or running a recording test."""
        if self._query and self.do_ugrid:
            self._query.add_ugrid(self.do_ugrid)
            self._query.send()
