"""A reader for spatial datasets in STWAVE."""

# 1. Standard Python modules
import copy
import datetime
import math

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as io_util
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules


class DatasetReaderSTWAVE:
    """STWAVE spatial dataset reader for input and solution files."""

    def __init__(self, dset_name, filename, ts_times, geom_uuid, angle=0.0, reftime=None):
        """
        Constructor.

        Args:
            dset_name (str): Name of the dataset.
            filename (str): Filename of the STWAVE formatted dataset file. If relative, must be from current working
                directory. Surrounding whitespace and quote characters are stripped.
            ts_times (:obj:`list` of :obj:`datetime.datetime`): Timestep times if not in the STWAVE file. IDD
                timestamps used if empty.
            geom_uuid (str): UUID of the dataset's geometry in SMS - not always needed.
            angle (float): Grid angle in degrees, used to convert local directions/vectors to world coordinates.
            reftime (datetime.datetime): Simulation reference time. Defaults to midnight of the current day.
        """
        self.name = dset_name
        self.filename = filename.strip().strip("\"").strip("\'")
        self.times = ts_times
        self.geom_uuid = geom_uuid
        self.dataset = None
        self.num_components = 1
        self.recinc = 60 * 60  # default to hour timesteps
        self.recunits = 1  # seconds per RECINC unit; timestep interval in seconds = recinc * recunits
        today = datetime.date.today()  # default reftime to current day
        self.reftime = datetime.datetime(today.year, today.month, today.day, 0, 0, 0) if reftime is None else reftime
        self.currtime = self.reftime  # last known timestep time; base for regularly spaced fallback times
        self.first_time = True  # True until the reference time has been taken from the file/times list
        self.grid_angle = angle  # used when converting from local to global

        # Wave files are split into four separate datasets, each with its own writer.
        self.wave_height = None
        self.wave_period = None
        self.wave_direction = None
        self.wave_ht_by_dir = None
        self.ni = 0
        self.nj = 0

        # Precomputed cos/sin of the grid angle (not referenced by the methods of this class).
        rad = math.radians(self.grid_angle)
        self.COSANG = math.cos(rad)
        self.SINANG = math.sin(rad)

    def _reorder_timestep(self, vals):
        """Reverse the row (j) order of a timestep's values to match the cell index order.

        The values are treated as ``nj`` rows of ``ni`` values each. Rows are emitted from last to first while the
        i values within each row keep their original ascending order. Entries beyond ``ni * nj`` stay where they
        are, so the input is returned unchanged (as a new list) when the grid dimensions are unknown (zero).

        Args:
            vals (list): The timestep values, one entry (scalar or [x, y] list) per cell.

        Returns:
            list: A new list with the rows reordered.
        """
        # A shallow copy is enough: every slot in the grid range is overwritten below, and callers discard the
        # source list afterwards. (deepcopy was needlessly expensive for large grids.)
        ordered_vals = list(vals)
        cell_idx = 0
        for j in range(self.nj, 0, -1):  # j-descending
            row_start = (j - 1) * self.ni
            for i in range(self.ni):  # i-ascending
                ordered_vals[cell_idx] = vals[row_start + i]
                cell_idx += 1
        return ordered_vals

    def _get_output_filename(self):
        """Returns a filename suitable for writing an XMDF dataset."""
        # TODO: Do something to make this work with tests.
        return io_util.temp_filename()

    def _set_ref_time(self, reftime):
        """Set the reference date for the dataset writers as soon as we know it.

        Need to set this on the dataset writers before appending the first time step.

        Args:
            reftime (datetime.datetime): The dataset's reference date
        """
        self.reftime = reftime
        for writer in (self.dataset, self.wave_height, self.wave_period, self.wave_direction, self.wave_ht_by_dir):
            if writer:
                writer.ref_time = reftime
        self.first_time = False

    def _set_num_components(self, num_components):
        """Set the number of components for the dataset writer as soon as we know it.

        Args:
            num_components (int): The number of dataset components (1=scalar, 2=vector)
        """
        self.num_components = num_components
        if self.dataset:  # Only need to update the generic writer, the wave writers already know what they are.
            self.dataset.num_components = num_components

    def shore_normal_to_cart(self, angle):
        """
        Convert a shore-normal (grid-local) angle to a cartesian angle by adding the grid angle.

        Args:
            angle (float): The shore-normal angle to be converted in degrees

        Returns:
            (float): The cartesian angle in degrees, normalized to [0, 360). Non-finite input is returned as-is.
        """
        angle += self.grid_angle
        if not math.isfinite(angle):
            return angle  # cannot be normalized; the loops below would never terminate for +/-inf
        # fmod is exact and leaves |angle| < 360 untouched, so the loops below run at most once. This keeps very
        # large magnitudes from looping (almost) forever.
        angle = math.fmod(angle, 360.0)
        while angle < 0.0:  # normalize 0-360
            angle += 360.0
        while angle >= 360.0:
            angle -= 360.0
        return angle

    def _lookup_time(self, a_str, ts_idx):
        """Return a timestep's time from self.times or from its IDD card.

        Args:
            a_str (str): The IDD card line, expected as ``IDD yyyymmddhhmm``.
            ts_idx (int): Index into self.times.

        Returns:
            datetime.datetime: The timestep time.

        Raises:
            IndexError: If self.times is non-empty but has no entry for ts_idx.
            ValueError: If the IDD card does not contain a valid yyyymmddhhmm timestamp.
        """
        if self.times:  # explicitly supplied times take priority over IDD timestamps
            return self.times[ts_idx]

        trim_str = a_str.replace("IDD", "").strip()
        year = int(trim_str[:4])
        month = int(trim_str[4:6])
        day = int(trim_str[6:8])
        hour = int(trim_str[8:10])
        minute = int(trim_str[10:])
        return datetime.datetime(year, month, day, hour, minute, 0)

    def parse_time(self, a_str, ts_idx):
        """
        Determine a timestep's time.

        Times come from the self.times list when it is non-empty (indexed by ts_idx). Otherwise they are parsed from
        the IDD card in a_str (``IDD yyyymmddhhmm``). The first time determined this way becomes the reference time
        (see _set_ref_time). If no time can be determined (index out of range or malformed IDD), the time advances
        one regular interval (recinc * recunits seconds) past the last known time.

        Args:
            a_str (string): The IDD card line with the date-time information.
            ts_idx (int): An index to access the self.times member

        Returns:
            (:obj:`datetime.datetime`): The timestep value
        """
        try:
            ts_time = self._lookup_time(a_str, ts_idx)
        except (IndexError, ValueError):  # return a regularly spaced interval
            ts_offset = datetime.timedelta(seconds=self.recinc * self.recunits)
            self.currtime = self.currtime + ts_offset
            return self.currtime

        if self.first_time:
            self._set_ref_time(ts_time)
        # Remember the last known time so that a later fallback interval continues from here instead of from the
        # default (current-day) reference time.
        self.currtime = ts_time
        return ts_time

    def _timestep_seconds(self, idd, ts_idx):
        """Return a timestep's offset from the reference time in seconds.

        Args:
            idd (str): The timestep's IDD card line (may be empty).
            ts_idx (int): Zero-based index of the timestep.

        Returns:
            float: Seconds between the timestep time and self.reftime.
        """
        ts_time = self.parse_time(idd, ts_idx)  # may set self.reftime on the first call, so subtract afterwards
        return (ts_time - self.reftime).total_seconds()

    def translate_vals(self, vals):
        """Converts local polar arrows (mag/dir) to world cartesian arrows (x/y), in place.

        Args:
            vals (:obj:`list` of float): [magnitude, shore-normal direction in degrees]; replaced with [x, y].
        """
        # Local polar to world polar to world Cartesian
        vals[1] = self.shore_normal_to_cart(vals[1])
        rad_dir = math.radians(vals[1])
        temp_val = vals[0]
        vals[0] = math.cos(rad_dir) * vals[0]
        vals[1] = math.sin(rad_dir) * temp_val

        # With default tolerances, math.isclose(x, 0.0) only matches an exact zero, so this effectively turns a
        # zero-magnitude (-0.0/0.0) vector into +0.0 components.
        if math.isclose(vals[0], 0.0) and math.isclose(vals[1], 0.0):
            vals[0] = 0.0
            vals[1] = 0.0

    def read_wave_file(self):
        """
        Initializes members Wave Height, Wave Period, Wave Direction, and Wave Height by Direction.

        Reads the STWAVE wave file in self.filename and loads its data into the members self.wave_height,
        self.wave_period, self.wave_direction, and self.wave_ht_by_dir. It will also return these values.

        Returns:
            tuple of DatasetWriter: The wave height, period, direction, and height by direction as XMDF formatted
                files.
        """
        self.wave_height = DatasetWriter(h5_filename=self._get_output_filename(), name='Wave Height',
                                         geom_uuid=self.geom_uuid, time_units='Seconds', location='cells')
        self.wave_period = DatasetWriter(h5_filename=self._get_output_filename(), name='Wave Period',
                                         geom_uuid=self.geom_uuid, time_units='Seconds', location='cells')
        self.wave_direction = DatasetWriter(h5_filename=self._get_output_filename(), name='Wave Direction',
                                            geom_uuid=self.geom_uuid, time_units='Seconds', location='cells')
        self.wave_ht_by_dir = DatasetWriter(h5_filename=self._get_output_filename(), name='Wave Height by Direction',
                                            geom_uuid=self.geom_uuid, num_components=2, time_units='Seconds',
                                            location='cells')
        self.translate = 1  # NOTE(review): not read by this class; kept in case external code inspects it
        self.ascii_read_wave()
        return self.wave_height, self.wave_period, self.wave_direction, self.wave_ht_by_dir

    def read(self):
        """
        Instantiates the self.dataset member by reading from the file in the self.filename member.

        Returns:
            (DatasetWriter): The self.dataset member, with all timesteps appended and appending finished.
        """
        self.dataset = DatasetWriter(h5_filename=self._get_output_filename(), name=self.name, geom_uuid=self.geom_uuid,
                                     time_units='Seconds', location='cells', num_components=self.num_components)
        # xmdf input no longer supported
        self.ascii_read()
        return self.dataset

    # xmdf input no longer supported
    def xmdf_read(self):
        """
        Placeholder for the removed XMDF input path.

        Returns:
            None: Always; nothing is read.
        """
        if not self.times:  # no times, empty dataset
            return
        pass

    def ascii_read(self):
        """
        Read the STWAVE ASCII spatial dataset in self.filename into self.dataset (helper for read()).

        Header cards are parsed until the first IDD card. After that, each non-comment line holds one cell's values
        until the next IDD card. Scalar 'Depth'/'Transient Depth' values are negated (with -9999.0 mapped to 0.0),
        and 2+ component values are rotated by the grid angle.
        """
        is_depth = self.name in ('Depth', 'Transient Depth')  # Depths below sea level are negative in SMS
        ts_vals = []
        ts_idx = -1
        idd = ""
        with open(self.filename, "r") as f:
            reading_vals = False
            for line in f:
                line = line.strip().upper()
                if not line or line[0] == "#":  # skip comments and blank lines
                    continue

                if not reading_vals and self.parse_ascii_cards(line):  # Still in the header section
                    continue
                if line.startswith("IDD"):  # new timestep
                    reading_vals = True
                    if ts_idx > -1:  # flush the previous timestep
                        ts_secs = self._timestep_seconds(idd, ts_idx)
                        self.dataset.append_timestep(ts_secs, self._reorder_timestep(ts_vals))
                        ts_vals = []
                    idd = line
                    ts_idx += 1
                elif reading_vals:  # one cell's values within the current timestep
                    line_data = line.split()
                    if self.num_components == 1:
                        dset_val = float(line_data[0])
                        if is_depth:
                            dset_val = 0.0 if dset_val == -9999.0 else dset_val * -1.0
                        ts_vals.append(dset_val)
                    else:
                        node_vals = [float(line_data[i]) for i in range(self.num_components)]
                        # Rotate the vector by the grid angle
                        ts_vals.append(self.rotate_vector_by_angle(node_vals[0], node_vals[1], self.grid_angle))

        # add the last timestep
        ts_secs = self._timestep_seconds(idd, ts_idx)
        self.dataset.append_timestep(ts_secs, self._reorder_timestep(ts_vals))
        self.dataset.appending_finished()

    @staticmethod
    def rotate_vector_by_angle(vx, vy, rotation_angle):
        """Rotates the XY vector passed in by the rotation angle on the Z axis.

        Args:
            vx (double): x component of the vector.
            vy (double): y component of the vector.
            rotation_angle (double): angle in degrees to rotate the vector by.

        Returns:
            (list of double): list of rotated [vx, vy] vector values.
        """
        # Rotate vector about Z Axis by the grid angle:
        # Matrix multiplication for rotation about Z, where t = angle in radians:
        # |cos(t), -sin(t), 0|       |x|     |cos(t) * x + -sin(t) * y + 0 * 0|
        # |sin(t), cos(t),  0|   *   |y|  =  |sin(t) * x + cos(t) * y + 0 * 0 |
        # |0,      0,       1|       |0|     |0 * x + 0 * y + 1 * 0           |
        # x' = vx * cos(t) - vy * sin(t)
        # y' = vx * sin(t) + vy * cos(t)
        theta = math.radians(rotation_angle)
        cs = math.cos(theta)
        sn = math.sin(theta)
        return [vx * cs - vy * sn, vx * sn + vy * cs]

    @staticmethod
    def _card_value(line):
        """Return the text after the first '=' of a card line, with commas and surrounding whitespace removed."""
        return line.split("=")[1].replace(",", "").strip()

    def parse_ascii_cards(self, line):
        """
        Parse a header card from a line of text.

        Args:
            line (str): A single line of an STWAVE ASCII spatial dataset. Should be upper case and trimmed of
                whitespace.

        Returns:
            (bool): True if a card was parsed, False if no known card found.
        """
        if line.startswith("NUMFLDS"):  # get the number of components
            self._set_num_components(int(self._card_value(line)))
        elif line.startswith("RECINC"):  # regularly-spaced time interval count
            self.recinc = int(self._card_value(line))
        elif line.startswith("RECUNITS"):  # units of the regularly-spaced time interval, as seconds per unit
            units = self._card_value(line).replace("'", "").replace("\"", "").strip()
            if "MM" in units:
                self.recunits = 60
            elif "HH" in units:
                self.recunits = 60 * 60
            elif "DD" in units:
                self.recunits = 60 * 60 * 24
        elif line.startswith("NI"):
            self.ni = int(self._card_value(line))
        elif line.startswith("NJ"):
            self.nj = int(self._card_value(line))
        else:
            return False  # End of header section or error
        return True

    def _append_wave_timestep(self, idd, ts_idx, height_vals, period_vals, direction_vals, hbd_vals):
        """Append one timestep to each of the four wave dataset writers.

        Args:
            idd (str): The timestep's IDD card line (may be empty).
            ts_idx (int): Zero-based index of the timestep.
            height_vals (list): Wave height per cell.
            period_vals (list): Wave period per cell.
            direction_vals (list): Wave direction per cell.
            hbd_vals (list): [x, y] height-by-direction vector per cell.
        """
        ts_secs = self._timestep_seconds(idd, ts_idx)
        self.wave_height.append_timestep(ts_secs, self._reorder_timestep(height_vals))
        self.wave_period.append_timestep(ts_secs, self._reorder_timestep(period_vals))
        self.wave_direction.append_timestep(ts_secs, self._reorder_timestep(direction_vals))
        self.wave_ht_by_dir.append_timestep(ts_secs, self._reorder_timestep(hbd_vals))

    def ascii_read_wave(self):
        """Read the STWAVE ASCII wave file in self.filename into the four wave writers (helper for read_wave_file).

        Each value line holds height, period and direction. Height and period are stored as read. Direction is also
        stored as read, without conversion. Height-by-direction is built from height and direction and converted
        to world x/y components with translate_vals().
        """
        height_ts_vals = []
        period_ts_vals = []
        direction_ts_vals = []
        hbd_ts_vals = []
        ts_idx = -1
        idd = ""
        with open(self.filename, "r") as f:
            reading_vals = False
            for line in f:
                line = line.strip().upper()
                if not line or line[0] == "#":  # skip comments and blank lines
                    continue

                if not reading_vals and self.parse_ascii_cards(line):  # Still in the header section
                    continue
                if line.startswith("IDD"):  # new timestep
                    reading_vals = True
                    if ts_idx > -1:  # flush the previous timestep
                        self._append_wave_timestep(idd, ts_idx, height_ts_vals, period_ts_vals, direction_ts_vals,
                                                   hbd_ts_vals)
                        height_ts_vals = []
                        period_ts_vals = []
                        direction_ts_vals = []
                        hbd_ts_vals = []
                    idd = line
                    ts_idx += 1
                elif reading_vals:  # still in a timestep
                    line_data = line.split()
                    height = float(line_data[0])
                    period = float(line_data[1])
                    direction = float(line_data[2])
                    height_ts_vals.append(height)
                    period_ts_vals.append(period)
                    direction_ts_vals.append(direction)
                    node_vals = [height, direction]
                    # height by local direction to world (Vx/Vy)
                    self.translate_vals(node_vals)
                    hbd_ts_vals.append(node_vals)

        # add the last timestep
        self._append_wave_timestep(idd, ts_idx, height_ts_vals, period_ts_vals, direction_ts_vals, hbd_ts_vals)
        self.wave_height.appending_finished()
        self.wave_period.appending_finished()
        self.wave_direction.appending_finished()
        self.wave_ht_by_dir.appending_finished()
