"""A reader for CMS-Wave simulation files."""

# 1. Standard Python modules
import logging
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.constraint.rectilinear_geometry import Numbering, Orientation
from xms.constraint.rectilinear_grid_builder import RectilinearGridBuilder
from xms.data_objects.parameters import Projection

# 4. Local modules


def logging_filename(filename):
    """Choose how a file path should appear in log messages.

    Args:
        filename (:obj:`str`): The path that is about to be logged

    Returns:
        (:obj:`str`): Only the basename of ``filename`` when ``XmEnv.xms_environ_running_tests()`` reports that
        tests are running; otherwise ``filename`` exactly as given.
    """
    running_tests = XmEnv.xms_environ_running_tests()
    return os.path.basename(filename) if running_tests else filename


class SimFileReader:
    """A class for reading CMS-Wave simulation files.

    After a successful :meth:`read`:

    - ``grid`` holds the built rectilinear grid.
    - ``input_files`` maps each lower-cased card name to the path given on that card, joined onto the
      simulation file's directory.
    - ``grid_proj`` holds the grid projection.
    """
    # Lines whose first character is one of these carry no data and are skipped.
    _SKIP_PREFIXES = ('#', '&', '/')
    # A 4th depth-file header value equal to this means cell sizes are listed after the depths.
    _VARIABLE_SPACING_FLAG = 999.0

    def __init__(self, filename):
        """Constructor.

        Args:
            filename (:obj:`str`): Simulation file filename
        """
        self._logger = logging.getLogger('xms.cmswave')
        self._filename = filename
        self._grid_data = {}
        self._grid_z = []  # cell elevations (negated depths), rows in reverse depth-file order
        self.grid = None
        self.input_files = {}  # {lower-cased card name: path joined onto the sim file's directory}
        self.grid_proj = Projection()

    @classmethod
    def _clean_line(cls, line):
        """Normalize a raw line from a simulation or depth file.

        Args:
            line (:obj:`str`): A raw line read from the file.

        Returns:
            (:obj:`str`): The line with surrounding whitespace and all single quotes removed, or an empty
            string if the line is blank or starts with one of ``_SKIP_PREFIXES``.
        """
        # Strip again after removing quotes so a line such as "'  '" counts as blank rather than
        # producing a line with no tokens.
        line = line.strip().replace("'", '').strip()
        if line.startswith(cls._SKIP_PREFIXES):
            return ''
        return line

    @staticmethod
    def _cell_edges(sizes):
        """Convert cell sizes along one axis into cumulative edge offsets.

        Args:
            sizes (:obj:`list[float]`): Cell sizes along one axis.

        Returns:
            (:obj:`list`): ``len(sizes) + 1`` offsets; the first is 0 and each later one adds the next size.
        """
        edges = [0]
        for size in sizes:
            edges.append(edges[-1] + size)
        return edges

    def _read_depth_file(self, filename):
        """Reads a CMS-Wave depth file (.dep), storing the depths and grid discretization.

        File layout:

        - The first data line is ``n_cell_i n_cell_j dx [dy]``; ``dy`` defaults to ``dx``.
        - The following values are the cell depths, ``n_cell_i`` per row.
        - When a ``dy`` of 999 is given, the depths are followed by ``n_cell_i`` cell sizes along i and
          then ``n_cell_j`` cell sizes along j.

        Values may be separated by any whitespace. Missing values are treated as 0.0 and surplus values
        are ignored. Depths are negated into elevations, and the rows are stored in reverse file order.

        Arguments:
            filename (:obj:`str`): The filename of the file to read.
        """
        header = None
        tokens = []  # every value after the header line, kept as text until needed
        with open(filename, 'r') as in_file:
            for raw_line in in_file:
                line = self._clean_line(raw_line)
                if not line:
                    continue
                if header is None:
                    header = line.split()
                else:
                    tokens.extend(line.split())
        if header is None:
            self._logger.warning(f'No grid dimensions found in depth file {logging_filename(filename)}')
            return

        n_i = int(header[0])
        n_j = int(header[1])
        dx = float(header[2])
        has_dy = len(header) > 3
        dy = float(header[3]) if has_dy else dx
        self._grid_data.update({'n_cell_i': n_i, 'n_cell_j': n_j, 'dx': dx, 'dy': dy})
        # Only an explicitly written dy can be the flag; a defaulted dy (copied from dx) cannot.
        variable_spacing = has_dy and dy == self._VARIABLE_SPACING_FLAG

        def _floats(start, count):
            """Convert ``count`` tokens starting at ``start`` to floats, padding with 0.0 if too few."""
            values = [float(token) for token in tokens[start:start + count]]
            return values + [0.0] * (count - len(values))

        numcells = n_i * n_j
        depths = _floats(0, numcells)
        # Swap the rows: the last row in the file becomes the first stored row.
        for row in reversed(range(n_j)):
            start = row * n_i
            self._grid_z.extend(-depth for depth in depths[start:start + n_i])

        if variable_spacing:
            self._grid_data['locations_x'] = _floats(numcells, n_i)
            self._grid_data['locations_y'] = _floats(numcells + n_i, n_j)

    def _build_grid_from_data(self):
        """Builds a Rectilinear Grid from grid data members.

        Spacing is chosen in this order:

        1. Explicit cell sizes from the depth file.
        2. Square spacing (``dx == dy``).
        3. Separate uniform ``dx``/``dy`` spacings.

        Cell elevations from the depth file are assigned when present.

        Returns:
            (:obj:`bool`): False on error (missing grid dimensions or missing origin/azimuth card)
        """
        data = self._grid_data
        if not data.get('n_cell_i') or not data.get('n_cell_j') or 'dx' not in data:
            self._logger.error('No grid definition found. Aborting import of CMS-Wave model.')
            return False
        if any(key not in data for key in ('x0', 'y0', 'azimuth')):
            self._logger.error('No CMS-Wave grid origin and azimuth found. Aborting import of CMS-Wave model.')
            return False

        builder = RectilinearGridBuilder()
        builder.angle = data['azimuth']
        builder.origin = (data['x0'], data['y0'], 0.0)
        builder.numbering = Numbering.kji
        builder.orientation = (Orientation.x_increase, Orientation.y_increase)
        builder.is_2d_grid = True
        builder.is_3d_grid = False
        if 'locations_x' in data:
            # Checked first: with variable spacing the header dy is the 999 flag (dx may be too),
            # not a real cell size, so a dx == dy comparison must not win here.
            builder.locations_x = self._cell_edges(data['locations_x'])
            builder.locations_y = self._cell_edges(data['locations_y'])
        elif data['dx'] == data['dy']:
            builder.set_square_xy_locations(data['n_cell_i'] + 1, data['n_cell_j'] + 1, data['dx'])
        else:
            builder.locations_x = [data['dx'] * i for i in range(data['n_cell_i'] + 1)]
            builder.locations_y = [data['dy'] * j for j in range(data['n_cell_j'] + 1)]
        self.grid = builder.build_grid()
        if self._grid_z:
            self.grid.cell_elevations = self._grid_z  # already in elevations
        return True

    def _set_projection(self):
        """Set the grid projection from imported data.

        Units are always meters. The coordinate system and zone are copied only if they were
        stored in the grid data.
        """
        self.grid_proj.vertical_units = 'METERS'
        self.grid_proj.horizontal_units = 'METERS'
        if 'coord_sys' in self._grid_data:
            self.grid_proj.coordinate_system = self._grid_data['coord_sys']
        if 'spzone' in self._grid_data:
            self.grid_proj.coordinate_zone = self._grid_data['spzone']

    def read(self):
        """Reads all the cards in file and loads data members with the info.

        Processing steps:

        - The ``CMS-Wave`` card supplies the grid origin and azimuth.
        - Every other card is recorded in ``input_files``, keyed by its lower-cased name.
        - The depth file named by the ``dep`` card is read if it exists; a warning is logged if
          the card is present but the file is missing.
        - The grid is then built and the projection set.

        Returns:
            (:obj:`bool`): False on error
        """
        self._logger.info(f'Reading simulation file {logging_filename(self._filename)}')
        sim_dir = os.path.dirname(self._filename)
        with open(self._filename, 'r') as in_file:
            for raw_line in in_file:
                line = self._clean_line(raw_line)
                if not line:
                    continue

                line_data = line.split()
                line_data[-1] = line_data[-1].replace(',', '')  # remove comma at end of the line
                token = line_data[0].lower()

                if token == 'cms-wave':
                    # First line of file contains the grid origin and rotation
                    self._grid_data['x0'] = float(line_data[1])
                    self._grid_data['y0'] = float(line_data[2])
                    self._grid_data['azimuth'] = float(line_data[3])
                else:
                    # Everything after the card name (embedded spaces included) is the file path.
                    card_value = line.split(None, 1)[1:]
                    self.input_files[token] = os.path.join(sim_dir, ''.join(card_value))

        dep_filename = self.input_files.get('dep', '')
        if os.path.isfile(dep_filename):
            self._logger.info(f'Reading depth file {logging_filename(dep_filename)}')
            self._read_depth_file(dep_filename)
        elif dep_filename:
            self._logger.warning(f'Depth file {logging_filename(dep_filename)} not found.')
        if not self._build_grid_from_data():
            return False
        self._set_projection()
        return True
