"""A reader for spatial datasets in CMS-Wave."""

# 1. Standard Python modules
import datetime
import logging
import math

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as io_util
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules


class DatasetReaderCMSWAVE:
    """CMS-Wave spatial dataset reader for input and solution files.

    The ASCII files read here start with a header line whose first two tokens are the grid dimensions
    (ni, nj), followed by records of cell values. Except for steady-state datasets, each record starts with
    its own header line (normally an ``IDD`` timestamp). Values are written through ``DatasetWriter``
    objects to temporary files.
    """

    def __init__(self, dset_name, filename, ts_times, geom_uuid, angle=0.0, reftime=None):
        """
        Constructor.

        Args:
            dset_name (:obj:`str`): Name of the dataset. Also selects how the file header is parsed
                (see parse_ascii_cards) and whether the dataset is treated as steady state.
            filename (:obj:`str`): Filename of the CMS-Wave formatted dataset file. If relative, must be from
                current working directory. Surrounding whitespace and quotes are stripped.
            ts_times (:obj:`list[datetime.datetime]`): Timestep times if not in the CMS-Wave file. IDD
                timestamps used if empty.
            geom_uuid (:obj:`str`): UUID of the dataset's geometry in SMS - not always needed.
            angle (:obj:`float`): Grid rotation angle in degrees, used to convert local directions and
                vectors to world coordinates.
            reftime (:obj:`datetime.datetime`): simulation reference time. Defaults to midnight of the
                current day.
        """
        self.name = dset_name
        self.filename = filename.strip().strip("\"").strip("\'")
        self.times = ts_times
        self.geom_uuid = geom_uuid
        self.dataset = None
        self.num_components = 1
        self.recinc = 60 * 60  # default to hour timesteps
        self.recunits = 1  # timestep interval in seconds = recinc * recunits (used by parse_time's fallback)
        today = datetime.date.today()  # default reftime to current day
        self.reftime = datetime.datetime(today.year, today.month, today.day, 0, 0, 0) if reftime is None else reftime
        self.currtime = self.reftime
        self.first_time = True
        self.first_message = True  # Only write first warning about indices instead of timestamps.
        self.grid_angle = angle  # used when converting from local to global
        self._logger = logging.getLogger('xms.cmswave')

        self.is_steady_state = dset_name in ['mud', 'fric', 'fref', 'bref']

        # special case for wave datasets
        self.wave_height = None
        self.wave_period = None
        self.wave_direction = None
        self.wave_ht_by_dir = None
        self.ni = 0
        self.nj = 0

        # set up constants used for translations
        rad = math.radians(self.grid_angle)
        self.COSANG = math.cos(rad)
        self.SINANG = math.sin(rad)

    @staticmethod
    def _get_output_filename():
        """Returns a filename suitable for writing an XMDF dataset."""
        # TODO: Do something to make this work with tests.
        return io_util.temp_filename()

    def _set_ref_time(self, reftime):
        """Set the reference date for the dataset writers as soon as we know it.

        Need to set this on the dataset writers before appending the first time step.

        Args:
            reftime (:obj:`datetime.datetime`): The dataset's reference date
        """
        self.reftime = reftime
        writers = (self.dataset, self.wave_height, self.wave_period, self.wave_direction, self.wave_ht_by_dir)
        for writer in writers:
            if writer:  # writers that have not been created yet are None
                writer.ref_time = reftime
        self.first_time = False

    def _set_num_components(self, num_components):
        """Set the number of components for the dataset writer as soon as we know it.

        Args:
            num_components (:obj:`int`): The number of dataset components (1=scalar, 2=vector)
        """
        self.num_components = num_components
        if self.dataset:  # Only need to update the generic writer, the wave writers already know what they are.
            self.dataset.num_components = num_components

    def shore_normal_to_cart(self, angle):
        """
        This converts a shore-normal degree angle to a cartesian angle.

        The grid angle is added and the result is normalized to [0, 360).

        Args:
            angle (:obj:`float`): The shore-normal angle to be converted in degrees

        Returns:
            (:obj:`float`): The cartesian angle conversion in degrees (NaN if the input is NaN or infinite)
        """
        angle += self.grid_angle
        if not 0.0 <= angle < 360.0:
            # Modulo instead of repeatedly adding/subtracting 360: the loop form never terminates for
            # infinite input, nor for large magnitudes where e.g. 1e20 - 360.0 == 1e20.
            angle %= 360.0
            if angle >= 360.0:  # tiny negative angles round up to exactly 360.0
                angle -= 360.0
        return angle

    def parse_time(self, a_str, ts_idx):
        """
        Parse a record header string into a datetime.datetime.

        If self.times is non-empty, self.times[ts_idx] is always returned and a_str is ignored. Otherwise
        a_str (with any "IDD" removed) is parsed as 'YYYYMMDDHHMM' (12 digits) or 'YYMMDDHH' (8 digits,
        two-digit years below 50 are 20xx). Any other token is treated as a plain record index, giving
        1950-01-01 00:00 plus ts_idx hours. The first time returned on these paths becomes the reference
        time. If anything raises, a time recinc * recunits seconds after the previous fallback time
        (initially the reference time) is returned instead.

        Args:
            a_str (:obj:`string`): A string formatted with the date-time information.
            ts_idx (:obj:`int`): An index to access the self.times member

        Returns:
            (:obj:`datetime.datetime`): The timestep value
        """
        try:
            if self.times:  # check for matching snap_idd first
                if self.first_time:
                    self._set_ref_time(self.times[ts_idx])
                return self.times[ts_idx]

            # try a timestep idd
            trim_str = a_str.replace("IDD", "").strip()
            year = 1950
            month = 1
            day = 1
            hour = 0
            minute = 0
            timepart = trim_str.strip().split(' ')[0]
            if len(timepart) == 12:  # YYYYMMDDHHMM
                year = int(timepart[:4])
                month = int(timepart[4:6])
                day = int(timepart[6:8])
                hour = int(timepart[8:10])
                minute = int(timepart[10:12])
                ts_time = datetime.datetime(year, month, day, hour, minute, 0)
            elif len(timepart) == 8:  # YYMMDDHH
                year = int(timepart[:2])
                year = year + 2000 if year < 50 else year + 1900
                month = int(timepart[2:4])
                day = int(timepart[4:6])
                hour = int(timepart[6:8])
                ts_time = datetime.datetime(year, month, day, hour, minute, 0)
            else:  # If record header is not a timestamp but plain index, ignore the index and increment hourly
                ts_time = datetime.datetime(year, month, day, hour, minute, 0) + datetime.timedelta(hours=ts_idx)
                if self.first_message:
                    self._logger.warning('  Index found instead of timestamp.')
                    self._logger.warning('  SMS will use hourly increments from a default time.')
                    self.first_message = False
            if self.first_time:
                self._set_ref_time(ts_time)
            return ts_time
        except Exception:  # deliberately broad: any unparsable record falls back to a regularly spaced interval
            ts_offset = datetime.timedelta(seconds=self.recinc * self.recunits)
            self.currtime = self.currtime + ts_offset
            return self.currtime

    def translate_vals(self, vals):
        """Converts local polar arrows (mag/dir) to world cartesian arrows (x/y), in place.

        Args:
            vals (:obj:`list[float]`): [magnitude, shore-normal direction in degrees]; overwritten with [x, y].
        """
        # Local polar to world polar to world Cartesian
        vals[1] = self.shore_normal_to_cart(vals[1])
        rad_dir = math.radians(vals[1])
        temp_val = vals[0]
        vals[0] = math.cos(rad_dir) * vals[0]
        vals[1] = math.sin(rad_dir) * temp_val

        if math.isclose(vals[0], 0.0) and math.isclose(vals[1], 0.0):
            vals[0] = 0.0
            vals[1] = 0.0

    def read_wave_file(self):
        """
        Initializes members Wave Height, Wave Period, Wave Direction, and Wave Height by Direction.

        Creates one dataset writer per wave quantity, reads self.filename into them (see ascii_read_wave),
        and returns them.

        Returns:
            (:obj:`tuple(DatasetWriter)`): The wave height, period, direction, and height by direction, as XMDF
            formatted files.
        """
        self.wave_height = DatasetWriter(
            h5_filename=self._get_output_filename(),
            name='Wave Height',
            geom_uuid=self.geom_uuid,
            time_units='Seconds',
            location='cells'
        )
        self.wave_period = DatasetWriter(
            h5_filename=self._get_output_filename(),
            name='Wave Period',
            geom_uuid=self.geom_uuid,
            time_units='Seconds',
            location='cells'
        )
        self.wave_direction = DatasetWriter(
            h5_filename=self._get_output_filename(),
            name='Wave Direction',
            geom_uuid=self.geom_uuid,
            time_units='Seconds',
            location='cells'
        )
        self.wave_ht_by_dir = DatasetWriter(
            h5_filename=self._get_output_filename(),
            name='Wave Height by Direction',
            geom_uuid=self.geom_uuid,
            num_components=2,
            time_units='Seconds',
            location='cells'
        )
        # NOTE(review): self.translate is not read anywhere in this class; presumably kept for outside callers.
        self.translate = 1
        self.ascii_read_wave()
        return self.wave_height, self.wave_period, self.wave_direction, self.wave_ht_by_dir

    def read(self):
        """
        Instantiates the self.dataset member by reading from the file in the self.filename member.

        Returns:
            (:obj:`DatasetWriter`): The self.dataset member, with all timesteps appended.
        """
        self.dataset = DatasetWriter(
            h5_filename=self._get_output_filename(),
            name=self.name,
            geom_uuid=self.geom_uuid,
            time_units='Seconds',
            location='cells',
            num_components=self.num_components
        )
        # xmdf input no longer supported
        self.ascii_read()
        return self.dataset

    # xmdf input no longer supported
    def xmdf_read(self):
        """Placeholder for the removed XMDF input path; does nothing and returns None."""
        if not self.times:  # no times, empty dataset
            return
        pass

    def _time_offset_seconds(self, idd, ts_idx):
        """Return a record's time as seconds after the reference time.

        parse_time may set the reference time on its first call, so the offset is computed from
        self.reftime only after parsing.

        Args:
            idd (:obj:`str`): The record's header line.
            ts_idx (:obj:`int`): Zero-based index of the record.

        Returns:
            (:obj:`float`): Seconds since the reference time.
        """
        ts_time = self.parse_time(idd, ts_idx)
        return (ts_time - self.reftime).total_seconds()

    def _store_timestep(self, ts_vals, idd, ts_idx, num_vals_to_read, steady_state=False):
        """Reformat one record's raw values and append them to self.dataset if the record is complete.

        Args:
            ts_vals (:obj:`list[float]`): Raw values in file order (X Y X Y ... for vector datasets).
            idd (:obj:`str`): The record's header line.
            ts_idx (:obj:`int`): Zero-based index of the record.
            num_vals_to_read (:obj:`int`): Expected number of raw values in a complete record.
            steady_state (:obj:`bool`): If True the record is stored at time 0.0 without parsing idd.
        """
        # Put values into vector format if necessary
        if self.num_components > 1:
            ts_vals = self._convert_list_to_vector(ts_vals)
        # Swap the order, because the values start away from the origin, then go down to the origin
        ts_vals = self._swap_j_order(ts_vals)
        ts_secs = 0.0 if steady_state else self._time_offset_seconds(idd, ts_idx)
        # Incomplete records are skipped rather than written with the wrong size.
        if len(ts_vals) == num_vals_to_read / self.num_components:
            self.dataset.append_timestep(ts_secs, ts_vals)

    def ascii_read(self):
        """
        This is the helper function for read().

        Reads self.filename and appends each record to self.dataset, then finishes the writer.
        """
        ts_vals = []
        ts_idx = -1
        idd = ""
        num_vals_to_read = 0
        with open(self.filename, "r") as f:
            reading_vals = False
            reading_header = True
            read_vals = 0
            for line in f:
                line = line.strip().upper()
                if not line or line[0] == "#":  # skip comments and blank lines
                    continue

                if reading_header and not reading_vals and self.parse_ascii_cards(line):  # Still in the header section
                    reading_header = False
                    # parse_ascii_cards may have just changed num_components, so compute after it.
                    num_vals_to_read = self.ni * self.nj * self.num_components
                    if self.is_steady_state:  # steady-state files have no record header line
                        reading_vals = True
                    continue
                elif not reading_header and not reading_vals:  # a record header line
                    reading_vals = True
                    if ts_idx > -1:
                        # Store the previous record before we continue reading the next one
                        self._store_timestep(ts_vals, idd, ts_idx, num_vals_to_read)
                        ts_vals = []
                        read_vals = 0  # Reset for next timestep
                    idd = line
                    ts_idx += 1
                elif reading_vals:  # still in a record
                    line_data = [float(x) for x in line.split()]
                    if self.num_components == 1 and self.name in ['Depth', 'Transient Depth']:
                        line_data = [dset_val * -1.0 for dset_val in line_data]
                    # Vector data comes in as X Y but may be split across lines, so collect everything
                    # and pair it up when the record is stored.
                    ts_vals.extend(line_data)
                    read_vals += len(line_data)

                if read_vals == num_vals_to_read:
                    reading_vals = False

        # add the last timestep
        self._store_timestep(ts_vals, idd, ts_idx, num_vals_to_read, steady_state=self.is_steady_state)
        self.dataset.appending_finished()

    def _convert_list_to_vector(self, ts_vals):
        """Converts the data passed in to vector format if necessary, based on the number of components.

        Each vector is also rotated by the grid angle.

        Arguments:
            ts_vals (:obj:`list[float]`):  List of X Y X Y... X Y data

        Returns
            (:obj:`list[list[float]]`):  Data formatted into a list of vectors, such as [[X, Y], [X, Y], ... [X, Y]]
        """
        # Put values into vector format if necessary
        if self.num_components > 1:
            node_vals = []
            for i in range(0, len(ts_vals), self.num_components):
                cur_node = []
                for j in range(self.num_components):
                    cur_node.append(ts_vals[i + j])
                cur_node = self.rotate_vector_by_angle(cur_node[0], cur_node[1], self.grid_angle)
                node_vals.append(cur_node)
            ts_vals = node_vals
        return ts_vals

    @staticmethod
    def rotate_vector_by_angle(vx, vy, rotation_angle):
        """Rotates the XY vector passed in by the rotation angle on the Z axis.

        Arguments:
            vx (:obj:`double`): x component of the vector.
            vy (:obj:`double`): y component of the vector.
            rotation_angle (:obj:`double`): angle in degrees to rotate the vector by.

        Returns:
            (:obj:`list[double]`): list of rotated [vx, vy] vector values.
        """
        # Rotate vector about Z Axis by the grid angle:
        # Matrix multiplication for rotation about Z, where t = angle in radians:
        # |cos(t), -sin(t), 0|       |x|     |cos(t) * x + -sin(t) * y + 0 * 0|
        # |sin(t), cos(t),  0|   *   |y|  =  |sin(t) * x + cos(t) * y + 0 * 0 |
        # |0,      0,       1|       |0|     |0 * x + 0 * y + 1 * 0           |
        # x' = vx * cos(t) - vy * sin(t)
        # y' = vx * sin(t) + vy * cos(t)
        theta = math.radians(rotation_angle)
        cs = math.cos(theta)
        sn = math.sin(theta)
        return [vx * cs - vy * sn, vx * sn + vy * cs]

    def parse_ascii_cards(self, line):
        """
        Parse a header card from a line of text.

        For recognized dataset names the first two tokens of the line are read as ni and nj. Vector dataset
        names also switch the reader to two components.

        Args:
            line (:obj:`str`): A single line of an CMS-Wave ASCII spatial dataset. Should be upper case and trimmed of
                whitespace.

        Returns:
            (:obj:`bool`): True if a card was parsed, False if no known card found.
        """
        upper_name = self.name.upper()
        if upper_name in ['ETA', 'WAVE', 'WAVE BREAKING', 'BREAK', 'MUD', 'FREF', 'BREF', 'FRIC', 'DEP']:
            is_vector = False
        elif upper_name in ['RADIATION STRESSES', 'RADS', 'CURR', 'WIND']:
            is_vector = True
        else:
            return False

        # Number of columns, number of rows, cell dimensions
        line_data = line.split()
        self.ni = int(line_data[0])
        self.nj = int(line_data[1])
        if is_vector:
            self._set_num_components(2)
        return True

    @staticmethod
    def _distribute_wave_values(line_data, read_vals, section_size, sections):
        """Append each value of a wave-file line to the section its position in the record falls in.

        A wave record holds section_size height values, then section_size period values, then section_size
        direction values. Values are assigned individually, so a line that straddles two sections is split
        between them. Values past the last section are ignored.

        Args:
            line_data (:obj:`list[float]`): Values parsed from one line.
            read_vals (:obj:`int`): Number of values of the current record read before this line.
            section_size (:obj:`int`): Number of values in each section (ni * nj).
            sections (:obj:`tuple[list[float]]`): The per-section lists to extend, in file order.
        """
        line_end = read_vals + len(line_data)
        for section_idx, section in enumerate(sections):
            section_start = section_idx * section_size
            section_end = section_start + section_size
            first = max(section_start, read_vals)
            last = min(section_end, line_end)
            if first < last:
                section.extend(line_data[first - read_vals:last - read_vals])

    def _store_wave_timestep(self, height_vals, period_vals, direction_vals, idd, ts_idx):
        """Reorder one wave record, derive height-by-direction vectors, and append it to the wave writers.

        Args:
            height_vals (:obj:`list[float]`): Wave heights in file order.
            period_vals (:obj:`list[float]`): Wave periods in file order.
            direction_vals (:obj:`list[float]`): Shore-normal wave directions (degrees) in file order.
            idd (:obj:`str`): The record's header line.
            ts_idx (:obj:`int`): Zero-based index of the record.
        """
        # Swap the order, because the values start away from the origin, then go down to the origin
        height_vals = self._swap_j_order(height_vals)
        period_vals = self._swap_j_order(period_vals)
        direction_vals = self._swap_j_order(direction_vals)
        # Build the vector dataset from height (magnitude) and direction
        hbd_vals = []
        for height_val, dir_val in zip(height_vals, direction_vals):
            node_vals = [height_val, dir_val]
            self.translate_vals(node_vals)  # Translate, and add grid angle
            hbd_vals.append(node_vals)

        ts_secs = self._time_offset_seconds(idd, ts_idx)
        self.wave_height.append_timestep(ts_secs, height_vals)
        self.wave_period.append_timestep(ts_secs, period_vals)
        self.wave_direction.append_timestep(ts_secs, direction_vals)
        self.wave_ht_by_dir.append_timestep(ts_secs, hbd_vals)

    def ascii_read_wave(self):
        """This is a helper function for read_wave_file.

        Reads self.filename into the four wave writers and finishes them.
        """
        height_ts_vals = []
        period_ts_vals = []
        direction_ts_vals = []
        ts_idx = -1
        idd = ""
        with open(self.filename, "r") as f:
            reading_vals = False
            reading_header = True
            read_vals = 0
            num_vals_to_read = 0
            base_num_to_read = 0
            for line in f:
                line = line.strip().upper()
                if not line or line[0] == "#":  # skip comments and blank lines
                    continue

                if reading_header and not reading_vals and self.parse_ascii_cards(line):  # Still in the header section
                    # After reading first line with num i, num j, calculate number of values to read per timestep
                    reading_header = False
                    base_num_to_read = self.ni * self.nj
                    num_vals_to_read = base_num_to_read * 3  # height, period, direction
                    continue
                elif not reading_header and not reading_vals:  # a record header line
                    reading_vals = True
                    if ts_idx > -1:
                        # Store the previous record before we continue reading the next one
                        self._store_wave_timestep(height_ts_vals, period_ts_vals, direction_ts_vals, idd, ts_idx)
                        height_ts_vals = []
                        period_ts_vals = []
                        direction_ts_vals = []
                        read_vals = 0  # Reset for next timestep
                    idd = line
                    ts_idx += 1
                    if ts_idx and (ts_idx - 1) % 25 == 0:
                        self._logger.info(f'Reading wave timestep - {ts_idx-1}')
                elif reading_vals:  # still in a timestep
                    line_data = [float(x) for x in line.split()]
                    self._distribute_wave_values(
                        line_data, read_vals, base_num_to_read, (height_ts_vals, period_ts_vals, direction_ts_vals)
                    )
                    read_vals += len(line_data)

                if read_vals == num_vals_to_read:
                    reading_vals = False

        # add the last timestep
        self._store_wave_timestep(height_ts_vals, period_ts_vals, direction_ts_vals, idd, ts_idx)
        self.wave_height.appending_finished()
        self.wave_period.appending_finished()
        self.wave_direction.appending_finished()
        self.wave_ht_by_dir.appending_finished()

    def _swap_j_order(self, values):
        """Reverse the row (j) order of a flat list of ni * nj cell values stored row by row.

        The file lists rows starting away from the origin; the dataset expects them starting at the origin.
        Only the first ni * nj entries are used. A shorter list yields a correspondingly shorter result.

        Args:
            values (:obj:`list`): Cell values, ni values per row.

        Returns:
            (:obj:`list`): The values with the row order reversed.
        """
        swapped_values = []
        for i in range(self.nj):
            cur_row_values = values[(self.nj - i - 1) * self.ni:(self.nj - i - 1) * self.ni + self.ni]
            swapped_values.extend(cur_row_values)
        return swapped_values
