r"""Writes the `STREAM_CELL <ci_>`_ (.gst) file, which tells GSSHA how the streams link to the grid.

.. _ci: https://www.gsshawiki.com/Surface_Water_Routing:Channel_Routing

::

    GRIDSTREAMFILE
    STREAMCELLS 108
    CELLIJ     1      31
    NUMNODES   2
    LINKNODE   5      77     1.0
    LINKNODE   5      76     0.3020346471717026
    CELLIJ     2      31
    NUMNODES   5
    LINKNODE   5      76     0.6979653528282974
    LINKNODE   5      75     1.0
    LINKNODE   5      74     1.0
    LINKNODE   5      73     1.0
    LINKNODE   5      72     0.7556114331203649
    .
    .
    .
"""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import logging
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.gssha.data import data_util
from xms.gssha.data.bc_util import BcData
from xms.gssha.file_io.io_util import INT_WIDTH
from xms.gssha.mapping import map_util

# Constants
_ALIGN = 10  # Width keywords (e.g. CELLIJ, NUMNODES) are left-justified to; values follow after one more space


def write(gssha_file_path: Path, stream_data: BcData, co_grid, ugrid) -> Path | None:
    """Writes the STREAM_CELL (.gst) file, if there is anything to write.

    Nothing is written when there is no stream data, the stream data has no features,
    or there is no grid.

    Args:
        gssha_file_path: .gssha file path; the .gst file is written next to it.
        stream_data: Data about the streams.
        co_grid: The grid.
        ugrid: The UGrid.

    Returns:
        The path of the .gst file that was written, or None if nothing was written.
    """
    have_streams = bool(stream_data) and bool(stream_data.feature_bc)
    if have_streams and co_grid:
        return StreamCellFileWriter(gssha_file_path, stream_data, co_grid, ugrid).write()
    return None


class StreamCellFileWriter:
    """Writes the STREAM_CELL (.gst) file.

    For every grid cell crossed by a stream, the file lists the cell's I/J, the number of
    link nodes in the cell, and one LINKNODE line (link, node, portion) per node.
    """
    def __init__(self, gssha_file_path: str | Path, stream_data: BcData, co_grid, ugrid) -> None:
        """Initializes the class.

        Args:
            gssha_file_path (str | Path): .gssha file path. A str is converted to a Path.
            stream_data: Data about the streams.
            co_grid: The grid.
            ugrid: The UGrid.
        """
        super().__init__()
        # Wrap in Path() so a plain str (as documented) works with with_suffix().
        self._gst_file_path: Path = Path(gssha_file_path).with_suffix('.gst')
        self._stream_data = stream_data
        self._co_grid = co_grid
        self._ugrid = ugrid

        self._log = logging.getLogger('xms.gssha')

    def write(self) -> Path | None:
        """Writes the STREAM_CELL (.gst) file and returns the file path.

        Returns:
            The path of the .gst file that was written.
        """
        self._log.info('Writing .gst file...')
        on_off_cells = data_util.get_on_off_cells(self._co_grid, self._ugrid)
        # Maps cell index -> list of (group, segment_idx, t_val1, t_val2, arc) tuples
        # (structure as unpacked below; arc[0] is the key into stream_data.arc_links).
        arc_ix = map_util.intersect_arcs_with_grid(self._ugrid, on_off_cells, self._stream_data)
        with open(self._gst_file_path, 'w') as file:
            file.write('GRIDSTREAMFILE\n')
            file.write(f'STREAMCELLS {len(arc_ix)}\n')

            # Sort by cell index to match WMS. Keys are unique, so only keys are compared.
            for cell_idx, nodes in sorted(arc_ix.items()):
                i, j = self._co_grid.get_cell_ij_from_index(cell_idx)
                file.write(f'{"CELLIJ":<{_ALIGN}} {i:<{INT_WIDTH}} {j:<{INT_WIDTH}}\n')
                file.write(f'{"NUMNODES":<{_ALIGN}} {len(nodes)}\n')
                # Reverse the order so we go from downstream to upstream
                for _group, segment_idx, t_val1, t_val2, arc in reversed(nodes):
                    link = self._stream_data.arc_links[arc[0]]
                    # Portion of the segment lying in this cell (difference of the two t values).
                    portion = t_val2 - t_val1
                    file.write(f'{"LINKNODE":<{_ALIGN}} {link:<{INT_WIDTH}} {segment_idx:<{INT_WIDTH}} {portion}\n')
        return self._gst_file_path
