"""A reader for STWAVE energy files."""

# 1. Standard Python modules
import datetime
import logging
import math
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage.spectral import SpectralCoverage, SpectralGrid
from xms.data_objects.parameters import Coverage, datetime_to_julian, Point, Projection, RectilinearGrid

# 4. Local modules


# Number of 5-degree direction bins spanning a full 360-degree plane; half-plane spectra are expanded to this count.
FULL_PLANE_ANGLES = 360 // 5


def get_coords_from_ij(i, j, angle, origin_x, origin_y, dx, dy):
    """
    Compute the x,y coordinates of the center of grid cell (i, j).

    Cells are numbered from 1, so cell 1 spans [origin, origin + spacing) along each axis. The cell center is located
    in the unrotated grid frame and then rotated counter-clockwise about the grid origin by ``angle`` degrees.

    Maths...see:
    https://math.stackexchange.com/questions/1384994/rotate-a-point-on-a-circle-with-known-radius-and-position

    Args:
        i (float): 1-based grid coordinate along the i-direction
        j (float): 1-based grid coordinate along the j-direction
        angle (float): Angle of the grid (in degrees)
        origin_x (float): X coordinate of the grid origin
        origin_y (float): Y coordinate of the grid origin
        dx (float): Grid spacing in the i-direction (must be constant)
        dy (float): Grid spacing in the j-direction (must be constant)

    Returns:
        (:obj:`tuple` of :obj:`float`) The x,y coordinates of the center of the requested cell
    """
    rad = math.radians(angle)
    cos_a = math.cos(rad)
    sin_a = math.sin(rad)

    # Unrotated cell center (lower-left corner of the cell plus half a cell), expressed as an offset from the origin.
    off_x = (((i - 1) * dx) + origin_x + dx * 0.5) - origin_x
    off_y = (((j - 1) * dy) + origin_y + dy * 0.5) - origin_y

    # Standard 2D rotation of the offset about the origin.
    rotated_x = origin_x + off_x * cos_a - off_y * sin_a
    rotated_y = origin_y + off_x * sin_a + off_y * cos_a
    return rotated_x, rotated_y


def parse_time(strtime):
    """
    Convert a ``YYYYMMDDHHMM`` timestamp string into a datetime.

    Only the first 12 characters are examined; anything after them is ignored. Seconds are always zero.

    Args:
        strtime(string): Timestamp text beginning with ``YYYYMMDDHHMM``.

    Returns:
        (:obj:`datetime.datetime`): The parsed date and time.

    Raises:
        ValueError: If a field is not an integer or the resulting date/time is out of range.
    """
    # (start, stop) character offsets of the year, month, day, hour and minute fields.
    spans = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12))
    year, month, day, hour, minute = [int(strtime[lo:hi]) for lo, hi in spans]
    return datetime.datetime(year, month, day, hour, minute, 0)


class EngReader:
    """
    STWAVE Spectral Coverage and frequency reader.

    Reads an ASCII STWAVE energy (.eng) file: a dimension namelist, the frequency list, then one record per
    snap/point made of a header line followed by ``numangs * numfreqs`` spectral values. Call :meth:`read`.
    """
    def __init__(self, filename, idd_to_time, times=None, rect_grid=None, reftime=None):
        """
        Constructor.

        Args:
            filename (str): Filename of the STWAVE formatted spectral file. If relative, must be from current working
                directory. Surrounding whitespace and quote characters are stripped.
            idd_to_time (:obj:`dict` of str:`datetime.datetime` or `int`): for idd_spec_types with snap_idds
            times (:obj:`list` of :obj:`datetime.datetime`, optional): Snap times consumed in order, for
                idd_spec_types without snap_idds. Defaults to None.
            rect_grid (:obj:`data_objects.Parameters.Spatial.SpatialVector.RectilinearGrid`, optional): Simulation
                grid used to convert ``obse.out`` i,j point locations to x,y. Default to None.
            reftime (datetime.datetime): simulation reference time. It only seeds the current snap time (midnight
                today when None); ``self.reftime`` is taken from the file's REFTIME or the first snap time.
        """
        self.filename = filename.strip().strip("\"").strip("\'")
        self.idd_to_time = idd_to_time  # for idd_spec_types with snap_idds
        self.times = times  # for idd_spec_types without snap_idds
        self.spec_cov = SpectralCoverage()
        self.numfreqs = 0
        self.minfreq = 0.0
        self.deltafreq = 0.0
        self.numangs = 0
        self.numpts = 0
        self.angle = 0.0
        self.spec_grids = {}  # key=stringified pt location, value=(SpectralGrid for this point, point geometry)
        self.coord_sys = "LOCAL"
        self.spzone = -9999
        self.curr_snap = ""  # used so we know when we have moved to a new snap_idd
        today = datetime.date.today()
        self.curr_time = datetime.datetime(today.year, today.month, today.day, 0, 0, 0) if reftime is None else reftime
        self.snap_count = 0  # for idd_spec_types without snap_idds
        self.reftime = None
        self.grid_angle = None
        self.dx = None
        self.dy = None
        self.origin_x = None
        self.origin_y = None
        # Get the grid definition variables for convenience if they are passed in.
        if rect_grid:
            self.grid_angle = rect_grid.angle
            self.dx = 1.0
            self.dy = 1.0
            i_sizes = rect_grid.i_sizes
            if i_sizes:
                self.dx = i_sizes[0]
            j_sizes = rect_grid.j_sizes
            if j_sizes:
                self.dy = j_sizes[0]
            origin = rect_grid.origin
            self.origin_x = origin.x
            self.origin_y = origin.y

    def build_spec_cov(self):
        """Builds the coverage geometry, sets its projection, and adds the spectral attributes."""
        # create the coverage geometry
        spec_pts = [spec_pt for _, spec_pt in self.spec_grids.values()]
        cov_geom = Coverage()
        cov_geom.set_points(spec_pts)
        cov_geom.name = "Spectral"
        # set the geometry's projection
        proj = Projection()
        proj.vertical_units = "METERS"
        proj.horizontal_units = "METERS"
        proj.coordinate_system = self.coord_sys
        proj.coordinate_zone = self.spzone
        cov_geom.projection = proj
        cov_geom.uuid = str(uuid.uuid4())
        cov_geom.complete()
        self.spec_cov.m_cov = cov_geom
        # add the spectral attributes, keyed by point id
        for spec_grid, spec_pt in self.spec_grids.values():
            self.spec_cov.AddSpectralGrid(spec_pt.id, spec_grid)

    def read_data_dims(self, f):
        """
        Reads dimensional data from a file to data members.

        Parses ``KEY = value,`` lines (keys are case-insensitive) until a line holding only ``/`` or end of file.
        Blank lines, ``#`` comments and ``&`` namelist headers are skipped; unrecognized keys are ignored.

        Args:
            f (TextIOWrapper): The open-read file containing dimensional data.

        Members Written to:
            numfreqs, numangs, numpts, angle, coord_sys, spzone, reftime
        """
        while True:
            raw = f.readline()
            if not raw:  # EOF without a "/" terminator; stop instead of looping forever
                break
            line = raw.strip()
            if not line or line.startswith(("#", "&")):  # skip comments, blank lines and namelist headers
                continue
            if line == "/":  # stop on terminating line
                break

            line = line.replace(",", "")
            line_data = [x.strip() for x in line.split("=") if x != ""]
            if not line_data:  # e.g. a line holding only "," or "="
                continue
            key = line_data[0].upper()
            if key == "NUMFREQ":
                self.numfreqs = int(line_data[1])
            elif key == "NUMANGLE":
                self.numangs = int(line_data[1])
            elif key == "NUMPOINTS":
                self.numpts = int(line_data[1])
            elif key == "AZIMUTH":
                self.angle = float(line_data[1])
            elif key == "COORD_SYS":
                self.coord_sys = line_data[1].strip('"').strip("'")
            elif key == "SPZONE":
                self.spzone = int(line_data[1])
            elif key == "REFTIME":
                self.reftime = parse_time(line_data[1].replace("\"", "").replace("\'", ""))

    def read_point_header(self, f):
        """
        Reads the next point header and returns the spectral grid of that point.

        Advances the current snap time when the header's idd differs from the previous one, and creates the point
        geometry and spectral grid the first time a point location is seen.

        Args:
            f (TextIOWrapper): The open-read file, positioned at a point header (blank and comment lines allowed).

        Returns:
            (:obj:`spectral.SpectralGrid`): The grid for the header's point, or None at end of file.
        """
        # Skip blank lines and comments; end of file at any point means there are no more records.
        while True:
            raw = f.readline()
            if not raw:  # EOF
                return None
            line = raw.strip()
            if line and not line.startswith("#"):
                break

        # [idd, windspeed, winddir, peakfreq, surge, ptx, pty]
        line_data = line.split()
        idd = line_data[0]
        if idd != self.curr_snap:
            self.curr_snap = idd
            if self.times and self.snap_count < len(self.times):  # no snap_idds
                self.curr_time = self.times[self.snap_count]
                self.snap_count += 1
            elif self.idd_to_time and idd in self.idd_to_time:
                self.curr_time = self.idd_to_time[idd]
            else:
                try:
                    # try a timestep idd
                    self.curr_time = parse_time(line.replace("IDD", "").strip())
                except ValueError:  # not a timestamp: advance one hour (a regularly spaced interval)
                    self.curr_time = self.curr_time + datetime.timedelta(hours=1)
            if self.reftime is None:  # store the first time, it is the reftime
                self.reftime = self.curr_time

        # If this is the first idd for a point, create a spectral grid for it.
        pt_loc = f"{line_data[5]} {line_data[6]}"
        if pt_loc not in self.spec_grids:
            self.spec_grids[pt_loc] = self._create_point_grid(float(line_data[5]), float(line_data[6]))
        return self.spec_grids[pt_loc][0]

    def _create_point_grid(self, pt_x, pt_y):
        """
        Create the point geometry and spectral grid for a point location seen for the first time.

        Args:
            pt_x (float): X value from the point header (a grid i coordinate for ``obse.out`` files).
            pt_y (float): Y value from the point header (a grid j coordinate for ``obse.out`` files).

        Returns:
            (:obj:`tuple`): (SpectralGrid, Point) for the new point; point ids start at 1 in order of appearance.
        """
        # If this is a solution spectral file, point locations are actually grid i,j coords.
        if self.filename.lower().endswith('obse.out') and self.grid_angle is not None:
            # Translate the i,j location to x,y coordinates.
            pt_x, pt_y = get_coords_from_ij(pt_x, pt_y, self.grid_angle, self.origin_x, self.origin_y, self.dx,
                                            self.dy)
        spec_pt = Point(pt_x, pt_y)
        spec_pt.id = len(self.spec_grids) + 1
        spec_grid = SpectralGrid(datetime_to_julian(self.reftime))
        spec_grid.m_rectGrid = self._build_rect_grid()
        spec_grid.m_timeUnits = "Days"
        return spec_grid, spec_pt

    def _build_rect_grid(self):
        """
        Construct the frequency/direction computational grid for one spectral point.

        Returns:
            (:obj:`RectilinearGrid`): Frequencies along i (origin at ``minfreq``), 5 degree directions along j.
        """
        rect_grid = RectilinearGrid()
        rect_grid.origin = Point(self.minfreq, 0.0, 0.0)
        rect_grid.angle = self.angle
        numangs = max(self.numangs, FULL_PLANE_ANGLES)  # convert half-plane to full-plane
        if self.deltafreq < 0.0:
            # A negative deltafreq encodes a power distribution: each frequency is -deltafreq times the previous.
            powerfac = -self.deltafreq
            freq_sizes = []
            next_freq = self.minfreq
            for _ in range(self.numfreqs - 1):
                freq_sizes.append(next_freq * powerfac - next_freq)
                next_freq *= powerfac
            dir_sizes = [5.0] * numangs
            rect_grid.set_sizes(freq_sizes, dir_sizes)
        else:
            rect_grid.set_sizes(self.numfreqs - 1, self.deltafreq, numangs, 5.0)
        rect_grid.complete()
        return rect_grid

    def read_frequencies(self, f):
        """
        Reads the frequencies from a file.

        The minimum frequency, delta frequency, and number of frequencies are initialized with this function.
        ``deltafreq`` is a constant spacing, or ``-p`` for a power distribution where each frequency is ``p`` times
        the previous one. An error is logged (and ``deltafreq`` left unchanged) when neither pattern fits.

        Args:
            f (TextIOWrapper): The relevant file opened in read-mode, positioned at the frequency list.
        """
        logger = logging.getLogger('xms.stwave')
        # Only the first few frequencies are needed to classify the distribution.
        num_needed = min(self.numfreqs, 4)
        first_freqs = []
        num_freqs_read = 0
        while num_freqs_read < self.numfreqs:
            raw = f.readline()
            if not raw:  # EOF before all frequencies were read; stop instead of looping forever
                logger.error("Unexpected end of file while reading frequencies from STWAVE ENG file")
                break
            line = raw.strip()
            if not line or line.startswith("#"):  # skip comments and blank lines
                continue

            line_data = line.split()
            num_freqs_read += len(line_data)
            # Frequencies may be spread over any number of lines; keep collecting until enough are known.
            first_freqs.extend(float(x) for x in line_data[:num_needed - len(first_freqs)])

        if first_freqs:
            self.minfreq = first_freqs[0]

        # we still haven't checked for all cases of arbitrary frequency distributions
        # log an error if this doesn't look like a spectra we recognize
        deltafreq = self._classify_frequencies(first_freqs)
        if deltafreq is None:
            logger.error("Unsupported frequency distribution in STWAVE ENG file")
        else:
            self.deltafreq = deltafreq

    @staticmethod
    def _classify_frequencies(freqs):
        """
        Detect a constant-spacing or power distribution from the leading frequencies.

        Args:
            freqs (:obj:`list` of float): The first (up to four) frequencies of the file.

        Returns:
            (float or None): The constant spacing, the negated power factor, or None if neither pattern matches.
        """
        if len(freqs) < 2:
            return None
        # Constant spacing: consecutive differences agree (always true with only two frequencies).
        deltas = [b - a for a, b in zip(freqs, freqs[1:])]
        if all(abs(d2 - d1) < 0.0001 for d1, d2 in zip(deltas, deltas[1:])):
            return deltas[0]
        if any(a == 0.0 for a in freqs[:-1]):  # a power distribution cannot contain a zero frequency
            return None
        # Power distribution: consecutive ratios agree. The three-frequency tolerance is slightly looser than the
        # four-frequency one; both values are kept so existing files classify the same way.
        ratios = [b / a for a, b in zip(freqs, freqs[1:])]
        tolerance = 0.0055 if len(freqs) == 3 else 0.005
        if all(abs(r2 - r1) < tolerance for r1, r2 in zip(ratios, ratios[1:])):
            return -ratios[0]
        return None

    def read_spectral_vals(self, f):
        """
        Reads every point record from the file and builds the spectral coverage.

        Each record is a point header line followed by ``numangs * numfreqs`` values, possibly spread over several
        lines. Reading stops at end of file; a record cut short by end of file is discarded.

        Args:
            f (TextIOWrapper): The open-read file containing the spectral data
        """
        num_vals_to_read = self.numangs * self.numfreqs  # this is the number of values per pt idd in the file
        one_day = datetime.timedelta(days=1)
        spec_grid = self.read_point_header(f)  # read the first point header
        while spec_grid is not None:
            ts_vals = self._read_record_values(f, num_vals_to_read)
            if ts_vals is None:  # EOF in the middle of a record
                break
            trans_vals = self._to_full_plane(ts_vals)
            # xmscoverage lies. It uses the data_objects default for time units, so it wants the time step offset
            # in days.
            ts_offset = (self.curr_time - self.reftime) / one_day
            spec_grid.add_timestep(ts_offset, trans_vals)
            spec_grid = self.read_point_header(f)  # read the next point header
        self.build_spec_cov()

    @staticmethod
    def _read_record_values(f, count):
        """
        Read whitespace-separated floats until at least ``count`` values have been collected.

        Args:
            f (TextIOWrapper): The open-read file, positioned just after a point header.
            count (int): Number of values in one record.

        Returns:
            (:obj:`list` of float): The values read, or None if end of file is reached first.
        """
        values = []
        while len(values) < count:
            raw = f.readline()
            if not raw:
                return None
            line = raw.strip()
            if not line or line.startswith("#"):  # skip comments and blank lines
                continue
            values.extend(float(x) for x in line.split())
        return values

    def _to_full_plane(self, ts_vals):
        """
        Transpose one record from file order (frequency rows) to direction-major order and expand it to full plane.

        Code adapted from STWAVEFileIO::impl::GetSpectralDatasets in the custom interface.

        Args:
            ts_vals (:obj:`list` of float): One record's values as read from the file.

        Returns:
            (:obj:`list` of float): ``(FULL_PLANE_ANGLES + 1) * numfreqs`` values, grouped by direction.
        """
        numfreqs = self.numfreqs
        numangs = self.numangs
        half_plane = numangs == 35
        num_total_vals = numangs * numfreqs if half_plane else (numangs + 1) * numfreqs
        trans_vals = [0.0] * num_total_vals
        vals_moved = 0
        for row in range(numfreqs):
            for col in range(numangs):
                trans_vals[(numfreqs * col) + row] = ts_vals[vals_moved]
                vals_moved += 1
            if not half_plane:
                # Close the circle for full plane: repeat the first direction (col 0) as an extra last direction.
                trans_vals[(numfreqs * numangs) + row] = trans_vals[row]

        # Only full-plane spectra in modern SMS. Half-plane STWAVE spectral files always have 5 degree bins
        # from -85 to 85 degrees around the origin of the grid. Fill in remaining angles with 0.0
        # NOTE(review): with 35 half-plane bins this yields 18 data bins + 38 zero bins + 17 data bins, so the final
        # node holds the last half-plane bin rather than a copy of the first node as in the full-plane case above.
        # Confirm this alignment is intended.
        num_full_ang_vals = (FULL_PLANE_ANGLES + 1) * numfreqs
        if len(trans_vals) < num_full_ang_vals:
            split = 17 * numfreqs
            num_fill = num_full_ang_vals - len(trans_vals)
            trans_vals = trans_vals[split:] + [0.0] * num_fill + trans_vals[:split]
        return trans_vals

    def read(self):
        """
        Calls the self.ascii_read function.

        Returns:
            (:obj:spectral.SpectralCoverage, int, float, float): Spectral Coverage and frequency data.
        """
        # xmdf format no longer supported
        return self.ascii_read()

    # deprecated
    def xmdf_read(self):
        """
        Deprecated XMDF reader stub; the XMDF format is no longer supported.

        Returns:
            (None, int, float, float): Always None, 30, 0.04, 0.01
        """
        return None, 30, 0.04, 0.01

    def ascii_read(self):
        """
        Reads frequency data, dimensional data and spectral coverage from a file.

        Returns:
            (:obj:spectral.SpectralCoverage, int, float, float): spectral coverage, number of frequencies, minimum
                frequency and delta frequency (negative for a power distribution)
        """
        # .eng files are grouped by snap_idd with an entry for each point per snap_idd
        with open(self.filename, "r") as f:
            # read the namelists
            self.read_data_dims(f)
            self.read_frequencies(f)
            # read the spectral values
            self.read_spectral_vals(f)

        return self.spec_cov, self.numfreqs, self.minfreq, self.deltafreq
