"""Module for reading a bctides.in file."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
from typing import Sequence

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.gmi.data.generic_model import Section

# 4. Local modules
from .bctides_file import BcTidesFile
from ..reader_base import ReaderBase
from ...data.model import get_model


def read_bctides(path: Path | str, open_boundaries: Sequence[Sequence[int]]) -> BcTidesFile:
    """
    Parse a bctides.in file and return its contents.

    This is a convenience wrapper that creates a BcTidesReader, runs it, and hands back what it parsed.

    Args:
        path: Location of the bctides.in file.
        open_boundaries: One sequence of node IDs per open boundary arc. Forwarded unchanged to the reader, which
            uses it to construct the tidal database.

    Returns:
        The BcTidesFile populated by the reader.
    """
    with BcTidesReader(path, open_boundaries) as bctides_reader:
        bctides_reader.read()
        return bctides_reader.bctides


class BcTidesReader(ReaderBase):
    """
    A class for reading SCHISM's bctides.in format.

    Call `read()` to parse the file. The parsed contents are then available in the `bctides` attribute.
    """
    def __init__(self, path: Path | str, open_boundaries: Sequence[Sequence[int]]):
        """
        Initialize the reader.

        Args:
            path: The file to be read.
            open_boundaries: Sequence of arcs, where each arc is a sequence of node IDs defining the arc. Used to
                construct the tidal database.
        """
        super().__init__(path, comment_markers=['!'])
        # The start of a boundary has flags, which are validated as their respective sections are read.
        # This means that, for example, if the flag for section 2 is bad, it won't be reported until
        # after parsing section 1. At this time, the original line number is lost. This allows reporting
        # the correct line so error messages are less confusing.
        self._boundary_start_line = 0
        self._num_frequencies = 0
        self._num_constituents = 0
        # One dataset is appended per boundary that carries type 3 elevation/flow data; `read()` concatenates them.
        self._elevation_datasets = []
        self._flow_datasets = []
        self._open_boundary_strings = open_boundaries
        self.bctides = BcTidesFile()

    def read(self):
        """
        Read the file.

        Populates `self.bctides` with the cutoff depth, the forcing frequencies, the per-boundary values, and the
        elevation and velocity datasets. If no boundary supplied type 3 elevation (or flow) data, the corresponding
        dataset is created with no names and no nodes.
        """
        self._read_line()  # The first line of the file is implicitly a comment.
        self._num_constituents, cutoff_depth = self._parse_line(int, float)
        self.bctides.cutoff_depth = cutoff_depth

        self._read_constituents()

        self._num_frequencies = self._parse_line(int)
        self._read_frequencies()

        num_boundaries = self._parse_line(int)
        self._read_boundaries(num_boundaries)

        # Per-boundary datasets share the 'name' coordinate, so they are joined along 'node'.
        if self._elevation_datasets:
            self.bctides.elevation = xr.concat(self._elevation_datasets, 'node')
        else:
            self.bctides.elevation = _make_elevation_dataset([], [])
        if self._flow_datasets:
            self.bctides.velocity = xr.concat(self._flow_datasets, 'node')
        else:
            self.bctides.velocity = _make_flow_dataset([], [])

    def _read_constituents(self):
        """
        Read the constituents section of the file.

        Only files whose constituent count (ntip) is zero are supported, so there is never anything to parse here.
        Any other count is reported as an error.
        """
        if self._num_constituents != 0:
            raise self._error('ntip>0 is unsupported')

    def _read_frequencies(self):
        """
        Read the forcing frequency section of the file.

        Each frequency takes two lines: its name, then its frequency, factor, and argument. The result is stored in
        `self.bctides.forcing_frequencies` as a dataset indexed by name.
        """
        num_frequencies = self._num_frequencies
        names = np.empty(num_frequencies, dtype=object)
        frequencies = np.empty(num_frequencies, dtype=float)
        factors = np.empty(num_frequencies, dtype=float)
        arguments = np.empty(num_frequencies, dtype=float)

        # Every slot of the arrays above is assigned below, so starting from uninitialized memory is safe.
        for i in range(num_frequencies):
            names[i] = self._parse_line(str)
            frequencies[i], factors[i], arguments[i] = self._parse_line(float, float, float)

        coords = {'name': names}
        data_vars = {
            'frequency': ('name', frequencies),
            'factor': ('name', factors),
            'argument': ('name', arguments),
        }

        self.bctides.forcing_frequencies = xr.Dataset(data_vars=data_vars, coords=coords)

    def _read_boundaries(self, num_boundaries: int):
        """
        Read all the boundaries in the file.

        Args:
            num_boundaries: Number of boundaries to read.
        """
        self.bctides.values = np.empty(num_boundaries, dtype=object)
        # A single section object is reused for every boundary: it is reset, filled, and its values extracted.
        section = get_model().arc_parameters

        for i in range(num_boundaries):
            section.clear_values()
            section.group('open').is_active = True
            self._read_boundary(i, section)
            self.bctides.values[i] = section.extract_values()

    def _read_boundary(self, boundary_index: int, section: Section):
        """
        Read a boundary.

        Args:
            boundary_index: Index of the boundary being read.
            section: Where to put the values.
        """
        # The leading value is not used: node lists come from the open boundaries given to the constructor.
        _length, wse_type, flow_type, temperature_type, salinity_type = self._parse_line(int, int, int, int, int)

        # Remember where the flags were so that a bad flag found later is reported against this line.
        self._boundary_start_line = self._line_number

        self._read_wse(boundary_index, wse_type, section)
        self._read_flow(boundary_index, flow_type, section)
        self._read_temperature(temperature_type, section)
        self._read_salinity(salinity_type, section)

    def _boundary_nodes(self, boundary_index: int) -> Sequence[int]:
        """
        Get the node IDs of the open boundary arc matching a boundary in the file.

        Reports a file error, rather than letting an IndexError escape, when the file describes more boundaries
        than there are arcs available.

        Args:
            boundary_index: Index of the boundary being read.

        Returns:
            The node IDs of the arc at `boundary_index`.
        """
        if boundary_index >= len(self._open_boundary_strings):
            raise self._error(
                f'Boundary {boundary_index + 1} has no matching open boundary arc',
                line_number=self._boundary_start_line,
            )
        return self._open_boundary_strings[boundary_index]

    def _read_wse(self, boundary_index: int, wse_type: int, section: Section):
        """
        Read the elevation section of a boundary.

        Args:
            boundary_index: Index of the boundary being read.
            wse_type: Type of elevation being read. Only types 0 and 3 are accepted.
            section: Where to put the values.
        """
        section.group('open').parameter('wse-type').value = str(wse_type)
        if wse_type == 0:
            pass  # Nothing to read
        elif wse_type == 3:
            self._read_wse_3(boundary_index)
        else:
            raise self._error(f'Unknown wse_type {wse_type}', line_number=self._boundary_start_line)

    def _read_wse_3(self, boundary_index: int):
        """
        Read a type 3 elevation section of a boundary.

        For each forcing frequency the file has a name line followed by one line per boundary node holding that
        node's amplitude and phase. The resulting dataset is queued for concatenation in `read()`.

        Args:
            boundary_index: Index of the boundary being read.
        """
        nodes = self._boundary_nodes(boundary_index)
        names = self.bctides.forcing_frequencies['name']
        ds = _make_elevation_dataset(nodes, names)

        for _ in range(len(names)):
            # Values are stored by the name found in the file, not by the order the names appear in.
            name = self._parse_line(str)
            for node in nodes:
                amplitude, phase = self._parse_line(float, float)
                selector = {'name': name, 'node': node}
                ds['amplitude'].loc[selector] = amplitude
                ds['phase'].loc[selector] = phase

        self._elevation_datasets.append(ds)

    def _read_flow(self, boundary_index: int, flow_type: int, section: Section):
        """
        Read the flow section of a boundary.

        Args:
            boundary_index: Index of the boundary being read.
            flow_type: Which type of flow to read. Only types 0 through 3 are accepted.
            section: Where to put the values.
        """
        section.group('open').parameter('flow-type').value = str(flow_type)
        if flow_type in [0, 1]:
            pass  # Nothing to read
        elif flow_type == 2:
            self._read_double_parameter('flow-constant', section)
        elif flow_type == 3:
            self._read_flow_3(boundary_index)
        else:
            raise self._error(f'Unknown flow_type {flow_type}', line_number=self._boundary_start_line)

    def _read_flow_3(self, boundary_index: int):
        """
        Read a type 3 flow section of a boundary.

        For each forcing frequency the file has a name line followed by one line per boundary node holding that
        node's x amplitude, x phase, y amplitude, and y phase. The resulting dataset is queued for concatenation in
        `read()`.

        Args:
            boundary_index: Index of the boundary being read.
        """
        nodes = self._boundary_nodes(boundary_index)
        names = self.bctides.forcing_frequencies['name']

        ds = _make_flow_dataset(nodes, names)

        for _ in range(len(names)):
            # Values are stored by the name found in the file, not by the order the names appear in.
            name = self._parse_line(str)
            for node in nodes:
                ax, px, ay, py = self._parse_line(float, float, float, float)
                selector = {'name': name, 'node': node}
                ds['amp_x'].loc[selector] = ax
                ds['phase_x'].loc[selector] = px
                ds['amp_y'].loc[selector] = ay
                ds['phase_y'].loc[selector] = py

        self._flow_datasets.append(ds)

    def _read_temperature(self, temperature_type: int, section: Section):
        """
        Read the temperature section of a boundary.

        Args:
            temperature_type: Which type of temperature to read. Only types 0, 1, and 3 are accepted.
            section: Where to put the values.
        """
        section.group('open').parameter('temperature-type').value = str(temperature_type)
        if temperature_type == 0:
            pass  # Nothing to read
        elif temperature_type in [1, 3]:
            # Both types carry a single value, which is stored as the nudging parameter.
            self._read_double_parameter('temperature-nudging', section)
        else:
            raise self._error(f'Unknown temperature_type {temperature_type}', line_number=self._boundary_start_line)

    def _read_salinity(self, salinity_type: int, section: Section):
        """
        Read the salinity section of a boundary.

        Args:
            salinity_type: Which type of salinity to read. Only types 0 and 3 are accepted.
            section: Where to put the values.
        """
        section.group('open').parameter('salinity-type').value = str(salinity_type)
        if salinity_type == 0:
            pass  # Nothing to read
        elif salinity_type == 3:
            self._read_double_parameter('salinity-nudging', section)
        else:
            raise self._error(f'Unknown salinity_type {salinity_type}', line_number=self._boundary_start_line)

    def _read_double_parameter(self, name: str, section: Section):
        """
        Read a double and put it into the section's 'open' group.

        Args:
            name: Name of the parameter to apply the value to.
            section: Section to put the value in.
        """
        value = self._parse_line(float)
        section.group('open').parameter(name).value = value


def _make_elevation_dataset(nodes: Sequence[int], names: Sequence[str]) -> xr.Dataset:
    """
    Make an elevation dataset.

    Dataset has coords name and node, with values from names and nodes, respectively. It contains two variables,
    amplitude and phase, each two-dimensional with dims (name, node).

    Every value starts out as NaN, so an entry that is never assigned is recognizable as missing instead of holding
    whatever happened to be in uninitialized memory.

    Args:
        nodes: Nodes in the dataset.
        names: Names of tides in the dataset.

    Returns:
        The elevation dataset.
    """
    dims = ('name', 'node')
    shape = (len(names), len(nodes))
    # Each variable gets its own array; sharing one would make writes to one variable show up in the other.
    data = {
        'amplitude': (dims, np.full(shape, np.nan, dtype=float)),
        'phase': (dims, np.full(shape, np.nan, dtype=float)),
    }
    coords = {'name': names, 'node': nodes}
    return xr.Dataset(coords=coords, data_vars=data)


def _make_flow_dataset(nodes: Sequence[int], names: Sequence[str]) -> xr.Dataset:
    """
    Make a flow dataset.

    Dataset contains amp_x, phase_x, amp_y, and phase_y variables. Each one is two-dimensional, with
    coords name and node.

    Every value starts out as NaN, so an entry that is never assigned is recognizable as missing instead of holding
    whatever happened to be in uninitialized memory.

    Args:
        nodes: Nodes in the dataset.
        names: Tide names in the dataset.

    Returns:
        The flow dataset.
    """
    dims = ('name', 'node')
    shape = (len(names), len(nodes))
    # A fresh array is allocated per variable so the variables never alias one another.
    data = {
        variable: (dims, np.full(shape, np.nan, dtype=float))
        for variable in ('amp_x', 'phase_x', 'amp_y', 'phase_y')
    }
    coords = {'name': names, 'node': nodes}
    return xr.Dataset(coords=coords, data_vars=data)
