"""Module for writing bctides.in file."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
from typing import TextIO

# 2. Third party modules

# 3. Aquaveo modules
from xms.gmi.data.generic_model import Section

# 4. Local modules
from xms.schism.data.model import elevation_flag, get_model, salinity_flag, temperature_flag, velocity_flag
from .bctides_file import BcTidesFile


def write_bctides(bctides: BcTidesFile, path: Path | str):
    """
    Write ``bctides`` out as a bctides.in file located at ``path``.

    The file body is produced in three parts, in this order: the tidal constituents, the forcing frequencies, and
    the open boundary conditions.

    Args:
        bctides: File contents to write.
        path: Where to write the file to.
    """
    with BcTidesWriter(bctides, path) as bctides_writer:
        sections_in_file_order = (
            bctides_writer.write_constituents,
            bctides_writer.write_forcing_frequencies,
            bctides_writer.write_boundaries,
        )
        for write_section in sections_in_file_order:
            write_section()


class BcTidesWriter:
    """
    Helper class for writing a bctides.in file.

    Use as a context manager: entering opens the output file and writes its first line (the file's name), and exiting
    closes it. The public ``write_*`` methods should be called in file order: ``write_constituents``,
    ``write_forcing_frequencies``, then ``write_boundaries``.
    """
    def __init__(self, bctides: BcTidesFile, path: str | Path):
        """
        Constructor.

        Nothing is opened or written here; that happens in ``__enter__``.

        Args:
            bctides: File contents to write.
            path: Where to write the file to.
        """
        self._path = Path(path)
        self._file: TextIO  # Assigned in __enter__.
        self._cutoff_depth = bctides.cutoff_depth
        self._boundaries = bctides.open_boundaries
        self._values = bctides.values
        self._properties = bctides.forcing_frequencies
        self._elevation = bctides.elevation
        self._velocity = bctides.velocity

    def __enter__(self):
        """
        Context manager enter method.

        Opens the output file for writing (truncating any existing file) and writes the header line.

        Returns:
            This writer.
        """
        self._file = open(self._path, 'w')
        try:
            self._file.write(f'{self._path.name}\n')
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release the handle here before propagating.
            self._file.close()
            raise
        return self

    def __exit__(self, _exc_type, _exc_value, _exc_tb):
        """Context manager exit method. Closes the output file; exceptions are not suppressed."""
        self._file.close()

    def write_constituents(self):
        """
        Write the tidal constituents.

        Currently unimplemented: always writes a constituent count (``ntip``) of zero followed by the cutoff depth.
        """
        self._file.write(f'0 {self._cutoff_depth} ! ntip\n')

    def write_forcing_frequencies(self):
        """
        Write the forcing frequencies.

        Frequencies include the tide's name, angular frequency, nodal factor, and earth equilibrium argument. This is
        the second section of the file. Entries are written sorted by name.
        """
        properties = self._properties

        # These don't come through in a stable order, which makes testing inconvenient.
        # We sort them here so the output is stable. Tuples compare by name first.
        items = sorted(
            zip(
                properties['name'].values,
                properties['frequency'].values,
                properties['factor'].values,
                properties['argument'].values,
            )
        )

        # The count must match the number of entries actually written below.
        self._file.write(f'{len(items)}  ! nbfr\n')

        for name, frequency, factor, argument in items:
            self._file.write(f'{name}\n')
            self._file.write(f'  {frequency} {factor} {argument}\n')

    def write_boundaries(self):
        """
        Write the open boundaries.

        This is basically all the boundary conditions. It's the third section of the file. For each boundary, writes
        one line containing the boundary's node count and its elevation, velocity, temperature, and salinity flags,
        followed by whatever extra data those flags require.
        """
        section = get_model().arc_parameters
        boundary_count = len(self._boundaries)
        self._file.write(f'{boundary_count} ! nope\n')

        for arc_id in range(boundary_count):
            # Load this arc's parameter values so the flag helpers and section writers below see them.
            # NOTE(review): this mutates the section returned by get_model(); confirm that is not shared state.
            section.restore_values(self._values[arc_id])

            flags = [
                len(self._boundaries[arc_id]),
                elevation_flag(section),
                velocity_flag(section),
                temperature_flag(section),
                salinity_flag(section),
            ]
            self._file.write(' '.join(str(flag) for flag in flags) + '\n')

            self._write_elevation_section(arc_id, section)
            self._write_velocity_section(arc_id, section)
            self._write_temperature_section(arc_id, section)
            self._write_salinity_section(arc_id, section)

    def _write_elevation_section(self, arc_id: int, section: Section):
        """
        Write the elevation section for a boundary condition, if necessary.

        Args:
            arc_id: The ID (index) of the arc to write.
            section: Values for the arc.

        Raises:
            AssertionError: If the arc's ``wse-type`` is not one this writer supports.
        """
        wse_type = section.group('open').parameter('wse-type').value
        if wse_type == '0':
            pass  # Nothing to write for type 0
        elif wse_type == '3':
            self._write_elevation_3(arc_id)
        else:
            raise AssertionError(f'Unknown wse_type: {wse_type}')  # pragma: nocover

    def _write_elevation_3(self, arc_id: int):
        """
        Write a type 3 elevation section.

        For each constituent, sorted by name to match the order of the forcing frequencies, writes the constituent's
        name and then one ``amplitude phase`` line per node on the arc.

        Args:
            arc_id: ID (index) of the arc to write.
        """
        # Assumes self._elevation is an xarray-like Dataset indexed by 'name' and 'node' -- TODO confirm.
        nodes = self._boundaries[arc_id]
        for tide_name in sorted(self._elevation['name'].values):
            self._file.write(f'{tide_name} !elevation\n')
            for node in nodes:
                ds = self._elevation.sel({'name': tide_name, 'node': node})
                amplitude = ds['amplitude'].item(0)
                phase = ds['phase'].item(0)
                self._file.write(f'  {amplitude}  {phase}\n')

    def _write_velocity_section(self, arc_id: int, section: Section):
        """
        Write the velocity section for a boundary condition, if necessary.

        Args:
            arc_id: The ID (index) of the arc to write.
            section: Values for the arc.

        Raises:
            AssertionError: If the arc's ``flow-type`` is not one this writer supports.
        """
        flow_type = section.group('open').parameter('flow-type').value
        if flow_type in ['0', '1']:
            pass  # Nothing to write for these types
        elif flow_type == '2':
            self._write_flow_2(section)
        elif flow_type == '3':
            self._write_flow_3(arc_id)
        else:
            raise AssertionError(f'Unknown flow type: {flow_type}')  # pragma: nocover

    def _write_flow_2(self, section: Section):
        """
        Write a type 2 flow section, which is a single constant flow value.

        Args:
            section: Parameter values for this arc.
        """
        flow = section.group('open').parameter('flow-constant').value
        self._file.write(f'{flow}\n')

    def _write_flow_3(self, arc_id: int):
        """
        Write a type 3 flow section.

        For each constituent, sorted by name like the forcing frequencies and elevation constituents, writes the
        constituent's name and then one ``amp_x phase_x amp_y phase_y`` line per node on the arc.

        Args:
            arc_id: ID (index) of the arc to write.
        """
        nodes = self._boundaries[arc_id]
        # Sorted so the constituent order is stable and agrees with the other tidal sections of the file.
        # NOTE(review): presumably the reader pairs these blocks with the nbfr frequencies by position -- verify.
        for tide_name in sorted(self._velocity['name'].values):
            self._file.write(f'{tide_name} !velocity\n')
            for node in nodes:
                ds = self._velocity.sel({'name': tide_name, 'node': node})
                phase_x = ds['phase_x'].item(0)
                phase_y = ds['phase_y'].item(0)
                amp_x = ds['amp_x'].item(0)
                amp_y = ds['amp_y'].item(0)
                # Amplitudes are limited to 7 significant digits; phases are written as-is.
                self._file.write(f'  {amp_x:.7} {phase_x} {amp_y:.7} {phase_y}\n')

    def _write_temperature_section(self, arc_id: int, section: Section):
        """
        Write the temperature section for a boundary condition, if necessary.

        Args:
            arc_id: The ID (index) of the arc to write.
            section: Values for the arc.

        Raises:
            AssertionError: If the arc's ``temperature-type`` is not one this writer supports.
        """
        temperature_type = section.group('open').parameter('temperature-type').value
        if temperature_type == '0':
            pass  # Nothing to write for TYPE_0
        elif temperature_type in ['1', '3']:
            self._write_temperature_nudging(section)
        else:
            raise AssertionError(f'Unknown temperature type: {temperature_type}')  # pragma: nocover

    def _write_temperature_nudging(self, section: Section):
        """
        Write temperature nudging.

        Args:
            section: Parameter values for this arc.
        """
        nudge = section.group('open').parameter('temperature-nudging').value
        self._file.write(f'{nudge} !temperature nudging\n')

    def _write_salinity_section(self, arc_id: int, section: Section):
        """
        Write the salinity section for a boundary condition, if necessary.

        Args:
            arc_id: The ID (index) of the arc to write.
            section: Values for the arc.

        Raises:
            AssertionError: If the arc's ``salinity-type`` is not one this writer supports.
        """
        salinity_type = section.group('open').parameter('salinity-type').value
        if salinity_type == '0':
            pass  # Nothing to write for TYPE_0
        elif salinity_type == '3':
            self._write_salinity_nudging(section)
        else:
            raise AssertionError(f'Unknown salinity type: {salinity_type}')  # pragma: nocover

    def _write_salinity_nudging(self, section: Section):
        """
        Write salinity nudging.

        Args:
            section: Parameter values for this arc.
        """
        nudge = section.group('open').parameter('salinity-nudging').value
        self._file.write(f'{nudge} !salinity nudging\n')
