r"""Writes the `CHAN_INPUT <ci_>`_ (.cif) file, which describes the stream network.

.. _ci: https://www.gsshawiki.com/Surface_Water_Routing:Channel_Routing

::

    GSSHA_CHAN
    ALPHA       1.000000
    BETA        1.000000
    THETA       1.000000
    LINKS       5
    MAXNODES    89
    CONNECT    4    5    0
    CONNECT    1    3    0
    CONNECT    2    3    0
    CONNECT    3    5    2    1    2
    CONNECT    5    0    2    4    3

    LINK           4
    DX             27.545402
    TRAPEZOID
    NODES          68
    NODE 1
    X_Y  459545.266545 4499466.101199
    ELEV 2167.939697
    XSEC
    MANNINGS_N     0.040000
    BOTTOM_WIDTH   1.000000
    BANKFULL_DEPTH 1.000000
    SIDE_SLOPE     1.000000
    NODE 2
    X_Y  459525.449644 4499453.475290
    ELEV 2168.067627
    .
    .
    .
"""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import logging
from pathlib import Path
from typing import NamedTuple

# 2. Third party modules

# 3. Aquaveo modules
from xms.gmi.data.coverage_data import CoverageData
from xms.gmi.data.generic_model import Group
from xms.grid.geometry import geometry

# 4. Local modules
from xms.gssha.data.bc_generic_model import ChannelType
from xms.gssha.data.bc_util import BcData, FeatureBc
from xms.gssha.file_io.block import Block

# Constants
_ALIGN = 15  # Alignment width passed to every Block: values start at this column + 1, after the card name


def write(gssha_file_path: Path, stream_data: BcData, coverage_data: CoverageData) -> Path | None:
    """Write the CHAN_INPUT (.cif) file for the stream network.

    Nothing is written when there are no stream arcs or no coverage data.

    Args:
        gssha_file_path: .gssha file path; the .cif file is written next to it.
        stream_data: Data about the streams.
        coverage_data: BC coverage component data.

    Returns:
        The path of the .cif file that was written, or None if nothing was written.
    """
    has_streams = bool(stream_data) and bool(stream_data.feature_bc)
    if has_streams and coverage_data:
        return ChannelInputFileWriter(gssha_file_path, stream_data, coverage_data).write()
    return None


class ChannelInputFileWriter:
    """Writes the channel input (.cif) file."""
    def __init__(self, gssha_file_path: str | Path, stream_data: BcData, coverage_data: CoverageData) -> None:
        """Initializes the class.

        Args:
            gssha_file_path: .gssha file path. The .cif file uses the same path with a '.cif' suffix.
            stream_data: Data about the streams.
            coverage_data: BC coverage component data.
        """
        super().__init__()
        # Path() so that a str path works too (str has no with_suffix()).
        self._cif_file_path: Path = Path(gssha_file_path).with_suffix('.cif')
        self._stream_data = stream_data
        self._coverage_data = coverage_data

        self._log = logging.getLogger('xms.gssha')
        self._file = None  # The .cif file object while write() is running
        self._sorted_bcs: FeatureBc = {}  # stream_data.feature_bc items ordered by link number
        self._links_count = 0  # Value written on the LINKS card
        self._max_nodes = 0  # Value written on the MAXNODES card (most points on any arc)

    def write(self) -> Path | None:
        """Writes the CHAN_INPUT (.cif) file and returns the file path."""
        self._log.info('Writing .cif file...')
        self._examine_arcs()
        with open(self._cif_file_path, 'w') as self._file:
            self._file.write('GSSHA_CHAN\n')
            self._write_alpha_beta_theta()
            self._write_dimensions()
            self._write_connections()
            self._write_links()

        return self._cif_file_path

    def _write_alpha_beta_theta(self) -> None:
        """Writes the ALPHA, BETA and THETA cards, each with the value 1.0."""
        with Block(self._file, self._log, '', _ALIGN) as block:
            block.write(group=None, name='ALPHA', value=1.0)
            block.write(group=None, name='BETA', value=1.0)
            block.write(group=None, name='THETA', value=1.0)

    def _write_dimensions(self) -> None:
        """Writes the LINKS and MAXNODES cards computed by _examine_arcs()."""
        with Block(self._file, self._log, '', _ALIGN) as block:
            block.write(group=None, name='LINKS', value=self._links_count)
            block.write(group=None, name='MAXNODES', value=self._max_nodes)

    def _link_number(self, dict_item: tuple[tuple, Group]) -> int:
        """Used for sorting by returning the link number given an item in the dict.

        Args:
            dict_item: A (key, group) item from feature_bc, where key[0] is the arc id.

        Returns:
            The link number of the arc.
        """
        return self._stream_data.arc_links[dict_item[0][0]]

    def _examine_arcs(self) -> None:
        """Sorts the stream arcs by link number and computes the LINKS and MAXNODES values.

        The counters are recomputed from scratch so that calling write() more than once
        does not accumulate them.
        """
        self._sorted_bcs = dict(sorted(self._stream_data.feature_bc.items(), key=self._link_number))
        self._links_count = len(self._sorted_bcs)
        self._max_nodes = 0
        for arc in self._sorted_bcs:
            feature_arc = self._stream_data.feature_from_id(arc[0], arc[1])
            self._max_nodes = max(self._max_nodes, len(feature_arc.geometry.coords))

    def _is_stream_arc(self, arc: NamedTuple) -> bool:
        """Returns True if the arc is one of the links being written (has a stream BC).

        Args:
            arc: The arc (must have 'id' and 'geometry_types' attributes).

        Returns:
            See description.
        """
        return (arc.id, arc.geometry_types) in self._sorted_bcs

    def _write_connections(self) -> None:
        """Writes the CONNECT lines, one per link, in link number order."""
        with Block(self._file, self._log, '', _ALIGN) as block:
            for arc in self._sorted_bcs:
                downstream_arc = self._find_downstream_arc(arc)
                upstream_arcs = self._find_upstream_arcs(arc)
                link = self._stream_data.arc_links[arc[0]]
                connect_str = self._build_connect_string(link, downstream_arc, upstream_arcs)
                block.write(group=None, name='CONNECT', value=connect_str)

    def _find_downstream_arc(self, arc: tuple) -> NamedTuple:
        """Returns the arc that is downstream from the arc.

        The downstream arc is the first arc at the arc's end node that starts at that node.

        Args:
            arc: The arc tuple.

        Returns:
            The downstream arc, or None if no arc starts at the arc's end node.
        """
        feature_arc = self._stream_data.feature_from_id(arc[0], arc[1])
        for node_arc in self._stream_data.node_arcs[feature_arc.end_node]:
            if node_arc.start_node == feature_arc.end_node:
                return node_arc
        return None

    def _find_upstream_arcs(self, arc: tuple) -> list[NamedTuple]:
        """Returns the arcs that are upstream from the arc.

        These are all arcs at the arc's start node that end at that node.

        Args:
            arc: The arc tuple.

        Returns:
            See description.
        """
        feature_arc = self._stream_data.feature_from_id(arc[0], arc[1])
        start_node = feature_arc.start_node
        return [node_arc for node_arc in self._stream_data.node_arcs[start_node] if node_arc.end_node == start_node]

    def _write_links(self) -> None:
        """Writes the LINK blocks (link header, node 1 with its cross section, remaining nodes)."""
        channel_type_map = {ChannelType.TRAPEZOIDAL: 'TRAPEZOID', ChannelType.CROSS_SECTION: 'BREAKPOINT'}
        with Block(self._file, self._log, '\n', _ALIGN) as _:
            for arc, group in self._sorted_bcs.items():
                with Block(self._file, self._log, '\n', _ALIGN) as block:
                    channel_type = group.parameter('channel_type').value
                    feature_arc = self._stream_data.feature_from_id(arc[0], arc[1])
                    points = list(feature_arc.geometry.coords)
                    link = self._stream_data.arc_links[feature_arc.id]
                    _write_link_start(link, channel_type_map[channel_type], points, block)
                    self._write_node_1(group, channel_type, points, block)
                    _write_remaining_nodes(points, block)

    def _write_node_1(self, group: Group, channel_type: str, points: list[tuple], block: Block) -> None:
        """Writes the first node, including its XSEC (cross section) data.

        Args:
            group: The generic model group
            channel_type: The channel type
            points: The arc points
            block: The block.
        """
        point = points[0]
        block.write(group=None, name='NODE', value=1)
        block.write(group=None, name='X_Y', value=f'{point[0]} {point[1]}')
        block.write(group=None, name='ELEV', value=f'{point[2]}')
        block.write(group=None, name='XSEC', value='')
        block.write(group=group, name='mannings_n')
        if channel_type == ChannelType.TRAPEZOIDAL:
            block.write(group=group, name='bottom_width')
            block.write(group=group, name='bankfull_depth')
            block.write(group=group, name='side_slope')
        else:
            # Breakpoint cross section: write the XY series pairs from the coverage curve.
            xy_id = group.parameter('cross_section').value
            x_values, y_values = self._coverage_data.get_curve(xy_id, False)
            block.write(group=None, name='NPAIRS', value=len(x_values))
            for x, y in zip(x_values, y_values):
                block.write(group=None, name='X1', value=f'{x} {y}')

    def _build_connect_string(self, link: int, downstream_arc: NamedTuple, upstream_arcs: list[NamedTuple]) -> str:
        """Builds and returns the CONNECT string.

        Format: link, downstream link (0 if none), upstream link count, then the sorted upstream
        links, each left-aligned in a 5-character field. Only arcs that are themselves stream
        links are counted, for both the downstream and the upstream side.

        Args:
            link: The link number
            downstream_arc: The downstream arc, or None.
            upstream_arcs: The upstream arcs.

        Returns:
            See description.
        """
        downstream_arc_link = 0
        if downstream_arc is not None and self._is_stream_arc(downstream_arc):
            downstream_arc_link = self._stream_data.arc_links[downstream_arc.id]
        # Filter upstream arcs the same way as the downstream arc so the count matches the listed links.
        upstream_links = sorted(
            self._stream_data.arc_links[upstream_arc.id] for upstream_arc in upstream_arcs
            if self._is_stream_arc(upstream_arc)
        )
        connect_str = f'{link:<5}{downstream_arc_link:<5}{len(upstream_links):<5}'
        connect_str += ''.join(f'{upstream_link:<5}' for upstream_link in upstream_links)
        return connect_str


def _write_remaining_nodes(points: list[tuple], block: Block) -> None:
    """Writes the NODE, X_Y and ELEV cards for every point after the first.

    The first point is written by ChannelInputFileWriter._write_node_1 (together with its
    cross section), so node numbering here starts at 2.

    Args:
        points: The arc points, indexed as (x, y, z).
        block: The block being written to.
    """
    for node_number, point in enumerate(points[1:], start=2):
        block.write(group=None, name='NODE', value=node_number)
        block.write(group=None, name='X_Y', value=f'{point[0]} {point[1]}')
        block.write(group=None, name='ELEV', value=f'{point[2]}')


def _write_link_start(link: int, channel_type_str: str, points: list[tuple], block: Block) -> None:
    """Write the header cards that open a LINK (one arc).

    Writes LINK, DX (average segment length), the bare channel type keyword and NODES,
    in that order.

    Args:
        link: The link number
        channel_type_str: Channel type string ('TRAPEZOID', 'BREAKPOINT').
        points: The arc points
        block: The block being written to
    """
    def card(card_name, card_value):
        block.write(group=None, name=card_name, value=card_value)

    card('LINK', link)
    card('DX', _compute_average_arc_segment_length(points))
    card('', channel_type_str)  # no card name: the channel type keyword stands alone
    card('NODES', len(points))


def _compute_average_arc_segment_length(points: list[tuple]) -> float:
    """Computes and returns the average length of all the segments on the arc.

    Args:
        points: The arc points.

    Returns:
        See description.
    """
    total_distance = 0.0
    for i in range(len(points) - 1):
        total_distance += geometry.distance_2d((points[i][0], points[i][1]), (points[i + 1][0], points[i + 1][1]))
    return total_distance / (len(points) - 1)
