r"""Writes the `CHAN_POINT_INPUT <cpi_>`_ (.ihg) file, which defines point source hydrograph inputs to the streams.

.. _cpi: https://www.gsshawiki.com/Surface_Water_Routing:Introducing_Dischage/Constituent_Hydrographs

::

    NUMPT 2
    POINT 1 2 0.0
    POINT 1 1 0.0
    NRPDS 3
    INPUT 2024 01 30 13 20 11.000000 77.000000
    INPUT 2024 01 30 13 21 22.000000 88.000000
    INPUT 2024 01 30 13 22 33.000000 99.000000
"""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
import logging
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries
from xms.gmi.data.coverage_data import CoverageData

# 4. Local modules
from xms.gssha.data.bc_util import BcData
from xms.gssha.file_io import io_util
from xms.gssha.file_io.io_util import DictIntXy, INT_WIDTH

# Type aliases
# One entry per point source location written to the .ihg file.
LinkNodes = list[tuple[int, int, int, int]]  # list[tuple(arc id, link #, GSSHA node # on the link, xy series ID)]


def write(
    gssha_file_path: Path, stream_data: BcData, coverage_data: CoverageData, start_date_time: datetime
) -> Path | None:
    """Writes the .ihg (channel point input) file when there are stream BCs to consider.

    See https://www.gsshawiki.com/Surface_Water_Routing:Introducing_Dischage/Constituent_Hydrographs

    Args:
        gssha_file_path: .gssha file path.
        stream_data: Data about the streams.
        coverage_data: BC coverage component data.
        start_date_time: The starting date/time.

    Returns:
        The path of the written .ihg file, or None if there is no stream data, no stream BCs, or the writer
        found nothing to write.
    """
    has_stream_bcs = bool(stream_data) and bool(stream_data.feature_bc)
    if has_stream_bcs:
        ihg_writer = ChanPointInputFileWriter(gssha_file_path, stream_data, coverage_data, start_date_time)
        return ihg_writer.write()
    return None


class ChanPointInputFileWriter:
    """Writes the .ihg (CHAN_POINT_INPUT) file and returns the file path.

    The file has one POINT line per stream link/node that has a point source hydrograph. It is followed by
    INPUT lines giving every hydrograph's value at each distinct time found in any of the series.
    """
    def __init__(
        self, gssha_file_path: Path, stream_data: BcData, coverage_data: CoverageData, start_date_time: datetime
    ) -> None:
        """Initializes the class.

        Args:
            gssha_file_path: .gssha file path. The .ihg file uses the same directory and stem.
            stream_data: Data about the streams.
            coverage_data: BC coverage component data.
            start_date_time: The starting date/time.
        """
        super().__init__()
        self._ihg_file_path: Path = gssha_file_path.with_suffix('.ihg')
        self._stream_data = stream_data
        self._coverage_data = coverage_data
        self._start_date_time = start_date_time

        self._log = logging.getLogger('xms.gssha')
        self._nodes: set[int] = set()  # Start nodes already used, so we don't get duplicate link-node locations
        self._link_nodes: LinkNodes = []

    def write(self) -> Path | None:
        """Writes the .ihg file and returns the file path.

        See https://www.gsshawiki.com/Surface_Water_Routing:Introducing_Dischage/Constituent_Hydrographs

        Returns:
            The .ihg file path, or None (and no file is written) if no stream has a valid point source
            hydrograph.
        """
        self._log.info('Writing .ihg file...')

        # Collect the data we'll need
        self._find_input_hydrographs()
        if not self._link_nodes:
            return None

        xy_values = self._get_xy_series_values()
        all_x = self._get_all_x(xy_values)
        all_y_for_each = self._get_all_y_for_each(xy_values, all_x)
        time_strings = self._get_time_strings(all_x)
        series_ids = [xy_series_id for _arc_id, _link, _node, xy_series_id in self._link_nodes]

        with open(self._ihg_file_path, 'w') as file:
            file.write(f'NUMPT {len(self._link_nodes)}\n')

            # Write POINT lines. The trailing value is each hydrograph's value at the first time, all_x[0]
            # (0.0 unless a series has negative times). That time is therefore not repeated as an INPUT line.
            for _arc_id, link, node, xy_series_id in self._link_nodes:
                y0 = all_y_for_each[xy_series_id][0]
                file.write(f'POINT {node:<{INT_WIDTH}} {link:<{INT_WIDTH}} {y0}\n')

            # Write NRPDS: the number of INPUT lines that follow (every time except all_x[0])
            file.write(f'NRPDS {len(all_x) - 1}\n')

            # Write INPUT lines: the time, then one value per point, in the same order as the POINT lines
            for i in range(1, len(time_strings)):
                values = ' '.join(f'{all_y_for_each[series_id][i]}' for series_id in series_ids)
                file.write(f'INPUT {time_strings[i]} {values}\n')
        return self._ihg_file_path

    def _get_time_strings(self, all_x: list[float]) -> list[str]:
        """Returns the INPUT-line time string for each value in all_x, relative to the start date/time."""
        return [io_util.get_time_string(self._start_date_time, x) for x in all_x]

    def _get_xy_series_values(self) -> DictIntXy:
        """Returns a dict of xy series ID -> (x values, y values) for every point source series.

        A series whose time values decrease is truncated before the first decrease (see
        _prevent_decreasing_time).
        """
        xy_values = {}
        for arc_id, _link, node, xy_series_id in self._link_nodes:
            xy_values[xy_series_id] = self._coverage_data.get_curve(xy_series_id, False)
            self._prevent_decreasing_time(arc_id, node, xy_series_id, xy_values)
        return xy_values

    def _prevent_decreasing_time(self, arc_id: int, node: int, xy_series_id: int, xy_values: dict) -> None:
        """Don't allow decreasing time values (x).

        At the first row whose time is less than the previous row's time, this method:
          - logs an error;
          - truncates the series just before that row.
        Rows after that point are discarded. Equal consecutive times are allowed.

        Args:
            arc_id: Arc ID (used in the log message).
            node: Node ID (used in the log message).
            xy_series_id: XY series ID.
            xy_values: dict of xy series id -> (list[x], list[y]). The entry for xy_series_id is replaced if
                the series is truncated.
        """
        x, y = xy_values[xy_series_id]
        template = (
            'Point source hydrograph for arc {0}, node {1}, was truncated at row {2} due to decreasing time'
            ' values.'
        )
        for i in range(1, len(x)):
            if x[i] < x[i - 1]:
                self._log.error(template.format(arc_id, node, i))
                xy_values[xy_series_id] = (x[:i], y[:i])
                # Stop at the FIRST decrease. Continuing would move the truncation point to a later decrease,
                # which leaves earlier decreasing rows in the series, and would log duplicate errors.
                break

    def _get_all_x(self, xy_values: DictIntXy) -> list[float]:
        """Returns the sorted, de-duplicated x values of all the xy series, always including 0.0."""
        set_x = {0.0}  # Make sure time 0.0 is in the set of all x
        for _arc_id, _link, _node, xy_series_id in self._link_nodes:
            set_x.update(xy_values[xy_series_id][0])
        return sorted(set_x)

    def _get_all_y_for_each(self, xy_values: DictIntXy, all_x: list[float]) -> dict[int, list[float]]:
        """Gets the y value at every x in all_x for each xy series.

        Values are computed with xy_util.y_from_x. Any x outside a series' own x range is set to 0.0 instead
        of being extrapolated. A series with no values yields 0.0 everywhere.

        Args:
            xy_values: dict of xy series id -> (list[x], list[y]).
            all_x: All x values, sorted.

        Returns:
            dict of xy series id -> y values, one per entry of all_x.
        """
        all_y_for_each: dict[int, list[float]] = {}  # xy series id -> y values at all x values
        for _arc_id, _link, _node, xy_series_id in self._link_nodes:
            series_x, series_y = xy_values[xy_series_id]
            if len(series_x) == 0:  # len() rather than truthiness in case the curve is an array
                # Nothing to interpolate, and series_x[0] below would raise IndexError
                self._log.warning(f'Point source hydrograph XY series {xy_series_id} has no values; using 0.0.')
                all_y_for_each[xy_series_id] = [0.0] * len(all_x)
                continue

            # Create an XySeries so we can use xy_util.y_from_x
            xy_series = XySeries(series_x, series_y, series_id=xy_series_id)
            all_y_for_series = xy_util.y_from_x(xy_series, all_x)

            # Don't extrapolate. Set any extrapolated y values to 0.0
            first_x, last_x = series_x[0], series_x[-1]
            for i, x_all in enumerate(all_x):
                if x_all < first_x or x_all > last_x:
                    all_y_for_series[i] = 0.0
            all_y_for_each[xy_series_id] = all_y_for_series

        return all_y_for_each

    def _find_input_hydrographs(self) -> None:
        """Searches the streams for where input hydrographs are specified and saves locations to a list.

        Each arc with 'specify_point_source' set contributes its start node (GSSHA node 1 on the arc's link).
        Each start node is used at most once. Arcs without a valid XY series are logged and skipped.
        """
        for arc, gmi_group in self._stream_data.feature_bc.items():
            if not gmi_group.parameter('specify_point_source').value:
                continue

            # Make sure the node is only handled once by using the _nodes set
            feature_arc = self._stream_data.feature_from_id(arc[0], arc[1])
            if feature_arc.start_node in self._nodes:
                continue

            link = self._stream_data.arc_links[feature_arc.id]
            xy_series_id = gmi_group.parameter('point_source_xy').value
            if xy_series_id is None or xy_series_id <= 0:
                msg = f'Point source hydrograph XY series not defined for arc ID={feature_arc.id}, link={link}.'
                self._log.error(msg)
                continue

            # Only claim the node once it really gets a point source. Otherwise an arc with no XY series
            # would block a valid arc sharing the same start node.
            self._nodes.add(feature_arc.start_node)
            gssha_node = 1  # The number of the point on the arc (starting node is always 1)
            self._link_nodes.append((feature_arc.id, link, gssha_node, xy_series_id))