"""SolutionReader class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from datetime import datetime
import logging
from pathlib import Path

# 2. Third party modules
import numpy as np
import numpy.typing as npt
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment
from xms.constraint import Grid, UGridBuilder
from xms.data_objects.parameters import FilterLocation, Point, UGrid as DoUGrid
from xms.datasets import dat_reader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid
from xms.guipy.dialogs import process_feedback_dlg
from xms.guipy.dialogs.feedback_thread import FeedbackThread
from xms.guipy.testing import testing_tools

# 4. Local modules
from xms.gssha.components import dmi_util
from xms.gssha.file_io import cif_file_reader, io_util
from xms.gssha.misc.type_aliases import ActionRv

# Constants
# Card names looked up in the .gssha file by SolutionReader._read_gridded_datasets. Each card's value is the
# path of a dataset file. NOTE: the bare string below is a no-op expression statement, not an attached docstring.
"""Cards for gridded dataset output"""
GRIDDED_DATASETS = {
    'DIS_RAIN',
    'DEPTH',
    'INF_DEPTH',
    'RATE_OF_INFIL',
    'SURF_MOIST',
    'GW_OUTPUT',
    'GW_RECHARGE_INC',
    'GW_RECHARGE_CUM',
    'VOL_SED_SUSP',
    'MAX_SED_FLUX',
    'NET_SED_VOLUME',
    'FLOOD_GRID',
}
# Card names looked up in the .gssha file by SolutionReader._read_link_node_datasets. Each card's value is the
# path of a link/node dataset file. These are sets, so iteration order is not defined.
"""Cards for link/node dataset output"""
LINK_NODE_DATASETS = {
    'CHAN_DEPTH',
    'CHAN_DISCHARGE',
    'CHAN_VELOCITY',
    'CHAN_SED_FLUX',
    'FLOOD_STREAM',
    'CHAN_STAGE',
    'PIPE_FLOW',
    'PIPE_HEAD',
    'PIPE_TILE',
    'SUPERLINK_JUNC_FLOW',
    'SUPERLINK_NODE_FLOW',
}


def read(gssha_file_path: str | Path, query: Query, win_cont: QWidget) -> ActionRv:
    """Runs the solution reader in a feedback thread while showing the feedback dialog.

    Args:
        gssha_file_path: Path to the .gssha file. May be a str or a Path.
        query: Object for communicating with XMS
        win_cont: The window container.

    Returns:
        A pair of empty lists.
    """
    # For testing
    # from xms.guipy import debug_pause
    # debug_pause()
    # _set_test_environment()

    # The thread and reader use Path attributes (.parent, .name, .stem), so normalize a str path here.
    thread = ReadSolutionFeedbackThread(query, Path(gssha_file_path))
    process_feedback_dlg.run_feedback_dialog(thread, win_cont)
    return [], []


class ReadSolutionFeedbackThread(FeedbackThread):
    """Feedback thread that runs a SolutionReader on a .gssha file."""
    def __init__(self, query: Query, gssha_file_path: Path):
        """Stores the file path and sets the text shown by the feedback dialog.

        Args:
            query: Object for communicating with XMS
            gssha_file_path: Path to the .gssha file.
        """
        super().__init__(query)
        self._gssha_file_path = gssha_file_path
        solution_dir = str(gssha_file_path.parent)
        dialog_text = {
            'title': 'Read Solution',
            'working_prompt': f'Reading GSSHA solution at \"{solution_dir}\".',
            'success_prompt': f'Successfully read solution at \"{solution_dir}\".',
            # 'auto_load': 'Automatically close this dialog when done'
            'auto_load': '',  # This makes it stay open
        }
        self.display_text |= dialog_text

    def _run(self):
        """Builds the reader and reads the solution."""
        SolutionReader(self._gssha_file_path, self._query).read()


class SolutionReader:
    """Reads the solution files named in a .gssha file and adds them to XMS through the Query."""
    def __init__(self, gssha_file_path: str | Path, query: Query) -> None:
        """Initializes the class.

        Args:
            gssha_file_path: Path to the .gssha file. May be a str or a Path.
            query: Object for communicating with XMS
        """
        # Path attributes (.name, .stem) are used by the methods below, so normalize a str path here.
        self._gssha_file_path = Path(gssha_file_path)
        self._project_path = ''  # PROJECT_PATH card value
        self._query = query
        self._log = logging.getLogger('xms.gssha')
        self._geom_uuid = ''
        self._sim_uuid = ''  # Always defined so later reads don't raise AttributeError when query is None
        self._gssha_dict: dict[str, str] = {}  # .gssha file in a dict. card -> value
        self._start_date_time: datetime | None = None
        self._link_node_co_grid: 'Grid | None' = None  # Built on demand by _make_link_node_ugrid

        if self._query:  # This can be None when testing
            self._sim_uuid = self._query.parent_item_uuid()  # Do this only once because it doesn't work the 2nd time

    def read(self) -> None:
        """Reads the solution."""
        self._find_geom_uuid()
        self._read_gssha_file_to_dict()
        self._read_start_date_time()
        self._read_summary()
        self._read_gridded_datasets()
        self._read_link_node_datasets()
        self._read_outlet_hydrograph()
        self._log.info('Solution read complete.')

    def _find_geom_uuid(self) -> None:
        """Finds the linked grid uuid."""
        # NOTE(review): if get_ugrid_node_or_warn can return None, the next line raises AttributeError - confirm.
        ugrid_node = dmi_util.get_ugrid_node_or_warn(self._query.project_tree, None, query=self._query)
        self._geom_uuid = ugrid_node.uuid

    def _read_gssha_file_to_dict(self) -> None:
        """Reads the .gssha file into a dict of card -> value and stores the PROJECT_PATH value if present."""
        self._log.info(f'Reading {self._gssha_file_path.name} to find outputs.')
        self._gssha_dict = io_util.read_gssha_file_to_dict(self._gssha_file_path)
        if 'PROJECT_PATH' in self._gssha_dict:
            self._project_path = self._gssha_dict['PROJECT_PATH'].strip('"')

    def _card_full_path(self, card: str) -> Path:
        """Returns the full path of the file named by a card in the .gssha file.

        Surrounding double quotes are stripped from the card value so every card is handled the same way.

        Args:
            card: Card whose value is a file path. Must be a key in self._gssha_dict.

        Returns:
            The result of io_util.get_full_path for the card's value.
        """
        file_path = self._gssha_dict[card].strip('"')
        return io_util.get_full_path(self._gssha_file_path, self._project_path, file_path)

    def _read_start_date_time(self) -> None:
        """Reads and stores the starting date/time, either from the .gssha file, or the .gag file."""
        if 'START_DATE' in self._gssha_dict:
            start_date = self._gssha_dict['START_DATE']
            start_time = self._gssha_dict.get('START_TIME', '')
            self._start_date_time = io_util.datetime_from_cards(start_date, start_time)
        elif 'PRECIP_FILE' in self._gssha_dict:
            self._start_date_time = _start_date_time_from_gag_file(self._card_full_path('PRECIP_FILE'))

    def _add_text_solution(self, card: str, solution_type: str, description: str) -> None:
        """Adds a solution component for the text file named by a card, if the card and file exist.

        Args:
            card: Card in the .gssha file whose value is the file path.
            solution_type: Identifier passed on to dmi_util.build_solution_component ('TXT_SOL', 'OTL_SOL').
            description: Words describing the file, used in the log message.
        """
        if card not in self._gssha_dict:
            return

        full_path = self._card_full_path(card)
        if full_path.is_file():
            self._log.info(f'Reading {description} "{str(full_path.name)}".')
            txt_comp = dmi_util.build_solution_component(full_path, self._sim_uuid, solution_type, False)
            self._query.add_component(txt_comp)

    def _read_summary(self) -> None:
        """Reads the summary file."""
        self._add_text_solution('SUMMARY', 'TXT_SOL', 'Summary file')

    def _read_outlet_hydrograph(self) -> None:
        """Reads the outlet hydrograph."""
        self._add_text_solution('OUTLET_HYDRO', 'OTL_SOL', 'the outlet hydrograph')

    def _read_gridded_datasets(self) -> None:
        """Reads the gridded datasets and adds each one that parses to the Query."""
        self._log.info('Reading gridded datasets.')

        # Sorted because set iteration order of strings can differ between runs
        for card in sorted(GRIDDED_DATASETS):
            if card not in self._gssha_dict:
                continue

            full_path = self._card_full_path(card)
            if not full_path.is_file():
                continue

            self._log.info(f'Reading "{full_path.name}".')
            try:
                dataset_reader = dat_reader.parse_dat_file(
                    full_path,
                    geom_uuid=self._geom_uuid,
                    ref_time=self._start_date_time,
                    location='cells',
                    time_units='Minutes',
                    units='meters'
                )
            except Exception as error:  # Best effort: log the failure and keep reading the other datasets
                self._log.error(f'Error reading "{str(full_path)}": "{error}"')
                continue
            self._query.add_dataset(dataset_reader)

    def _read_link_node_datasets(self) -> None:
        """Reads the link/node datasets, building the link/node grid the first time one is found."""
        self._log.info('Reading link/node datasets.')

        # Sorted because set iteration order of strings can differ between runs
        for card in sorted(LINK_NODE_DATASETS):
            if card not in self._gssha_dict:
                continue

            full_path = self._card_full_path(card)
            if not full_path.is_file():
                continue

            self._log.info(f'Reading "{full_path.name}".')
            if self._link_node_co_grid is None:
                self._make_link_node_ugrid(full_path)
            if self._link_node_co_grid is None:
                # The grid could not be built (a warning was logged), so this dataset has nothing to attach to
                continue

            num_points = len(self._link_node_co_grid.point_elevations)
            dataset_writer = _read_link_node_dataset_file(
                full_path, card, num_points, self._link_node_co_grid.uuid, self._start_date_time
            )
            self._query.add_dataset(dataset_writer)

    def _make_link_node_ugrid(self, full_path: Path) -> None:
        """Creates a UGrid of points at the midpoints of each stream arc segment and adds it to the Query.

        On failure a warning is logged and self._link_node_co_grid is left as None.

        Args:
            full_path: Path of the link/node dataset file that needs the grid. Only used in warning messages.
        """
        self._log.info('Creating link/node grid.')

        # Check that there's a CHANNEL_INPUT
        if 'CHANNEL_INPUT' not in self._gssha_dict:
            msg = f'Cannot read "{str(full_path)}" because no CHANNEL_INPUT file found in "{self._gssha_file_path}".'
            self._log.warning(msg)
            return

        # Read the CHANNEL_INPUT file to get the stream arcs. Kept in its own variable so the warning below
        # can still name the dataset file being skipped.
        cif_path = self._card_full_path('CHANNEL_INPUT')
        coverage, _links, _xy_dict = cif_file_reader.read(cif_path)
        if not coverage:
            self._log.warning(f'Could not read "{str(cif_path)}". Skipping "{full_path}".')
            return

        # Get the midpoints of all arc segments. Assumes arcs were created in same order as .cif and .cdp files
        midpoints = []
        for arc in coverage.arcs:
            pts = arc.get_points(FilterLocation.PT_LOC_ALL)
            midpoints.extend(_avg_xyz(pts[i], pts[i - 1]) for i in range(1, len(pts)))

        # Create the grid
        builder = UGridBuilder()
        builder.set_is_2d()
        xm_ugrid = UGrid(midpoints)
        builder.set_ugrid(xm_ugrid)
        co_grid = builder.build_grid()
        co_grid.uuid = testing_tools.new_uuid()
        self._link_node_co_grid = co_grid

        # Add grid
        self._log.info('Adding link/node grid.')
        xmc_file = Path(XmsEnvironment.xms_environ_process_temp_directory()) / f'{co_grid.uuid}.xmc'
        co_grid.write_to_file(str(xmc_file), binary_arrays=True)
        name = f'{self._gssha_file_path.stem} (GSSHA) linknode'
        do_ugrid = DoUGrid(str(xmc_file), name=name, uuid=co_grid.uuid)
        do_ugrid.force_ugrid = True
        self._query.add_ugrid(do_ugrid)


def _set_test_environment():  # pragma no cover - This is only called when manually debugging
    """Sets the environment variable named by XmsEnvironment.ENVIRON_RUNNING_TESTS to 'TRUE'."""
    import os
    from xms.api.dmi import XmsEnvironment
    flag_name = XmsEnvironment.ENVIRON_RUNNING_TESTS
    os.environ[flag_name] = 'TRUE'


def _avg_xyz(pt1: Point, pt2: Point) -> tuple[float, float, float]:
    """Returns the midpoint between two points as an (x, y, z) tuple.

    Args:
        pt1: First point. Its x, y and z attributes are read.
        pt2: Second point. Its x, y and z attributes are read.

    Returns:
        The average of the two x values, the two y values and the two z values.
    """
    mid_x = (pt1.x + pt2.x) / 2.
    mid_y = (pt1.y + pt2.y) / 2.
    mid_z = (pt1.z + pt2.z) / 2.
    return mid_x, mid_y, mid_z


def _read_link_node_dataset_file(
    file_path: Path, card: str, num_points: int, geom_uuid: str, start_date_time: datetime | None
) -> DatasetWriter:
    """Reads a link/node dataset file and returns a DatasetWriter.

    Example files: GSSHA_LINKNODE_STREAM_FLOW (.cdq), GSSHA_LINKNODE_STREAM_DEPTH (.cdp) etc.

    The reference time comes from the file's START_TIME card. If that card is missing or cannot be parsed,
    start_date_time is used instead.

    Args:
        file_path: The file path.
        card: CHAN_DEPTH etc. The lowercased card is used as the dataset name.
        num_points: Number of grid points.
        geom_uuid: Uuid of the grid.
        start_date_time: START_DATE/START_TIME from the .gssha file. May be None.

    Returns:
        See description.
    """
    dataset_writer = DatasetWriter(
        name=card.lower(), geom_uuid=geom_uuid, time_units='Minutes', location='points', use_activity_as_null=True
    )
    if start_date_time is not None:
        # Default reference time, so a file with no START_TIME card still gets the fallback
        dataset_writer.ref_time = start_date_time
    num_links: int = 0
    time_step: float = 0.0  # Length of each time step?
    with open(file_path, 'r') as file:
        for line in file:
            # Named file_card so the 'card' argument is not overwritten
            file_card, card_value = io_util.get_card_and_value(line.rstrip('\n'))
            if file_card == 'NUM_LINKS':
                num_links = int(card_value)
            elif file_card == 'TIME_STEP':
                time_step = float(card_value)
            elif file_card == 'START_TIME':
                # time string is month day year hour minute second - not year month day - like everywhere else!!!
                words = card_value.split()
                ref_time = None
                if len(words) >= 5:  # Too few words would otherwise raise IndexError instead of falling back
                    time_str = ' '.join([words[2], words[0], words[1], words[3], words[4]])
                    ref_time = io_util.datetime_from_string(time_str)
                dataset_writer.ref_time = ref_time if ref_time is not None else start_date_time
            elif file_card == 'TS':
                ts_idx = int(card_value)
                time = ts_idx * time_step
                # _read_time_step consumes the next num_links lines from the same file iterator
                values, activity = _read_time_step(num_links, num_points, file)
                dataset_writer.append_timestep(time, values, activity)

    dataset_writer.appending_finished()
    return dataset_writer


def _read_time_step(num_links: int, num_points: int, file) -> 'tuple[npt.NDArray[float], npt.NDArray[int]]':
    """Reads one link/node time step from the file and returns its values and activity.

    One line is read per link. The first item on the line is a count and the rest are pairs of
    (activity, value). Pairs are stored one after another starting at the link's offset, and the
    offset then moves forward by the count.

    Args:
        num_links: Number of links, which is the number of lines read from file.
        num_points: Number of grid points, which is the length of the returned arrays.
        file: The file being read from. Lines are pulled from it with next().

    Returns:
        Tuple of (values, activity). Entries not set from the file stay 0.0 in values and 1 in activity.
    """
    values = np.zeros(num_points)
    activity = np.ones(num_points, dtype=int)
    link_start = 0
    for _link in range(num_links):
        text = next(file).rstrip('\n')
        count_text, pairs_text = io_util.get_card_and_value(text)
        pair_count = int(count_text)
        tokens = pairs_text.split()
        slot = link_start
        for first in range(0, len(tokens), 2):
            activity[slot] = int(tokens[first])
            values[slot] = float(tokens[first + 1])
            slot += 1
        link_start += pair_count
    return values, activity


def _start_date_time_from_gag_file(file_path: Path) -> datetime | None:
    """Reads and returns the first time from the .gag file."""
    try:
        with open(file_path, 'r') as file:
            start_date_time = None
            for line in file:
                if line.startswith('ACCUM'):
                    line = line.rstrip('\n')
                    card, card_value = io_util.get_card_and_value(line)
                    date_str, value_str = card_value.rsplit(maxsplit=1)  # Exclude last word when reading the time
                    start_date_time = io_util.datetime_from_string(date_str)
                    break
            return start_date_time
    except FileNotFoundError:
        return None
