"""SolutionReader class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from dataclasses import dataclass
from datetime import datetime
import logging
from pathlib import Path
import tempfile
from typing import TypedDict

# 2. Third party modules
import numpy as np
from numpy import ndarray
from rtree import index

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment
from xms.coverage.xy.xy_series import XySeries
from xms.datasets.dataset_writer import DatasetWriter
from xms.guipy import file_io_util
from xms.guipy.dialogs import dialog_util
from xms.guipy.dialogs.process_feedback_dlg import ProcessFeedbackDlg
from xms.guipy.testing import testing_tools
from xms.guipy.time_format import string_to_datetime

# 4. Local modules
from xms.hgs.data.domains import domain_abbreviation, Domains
from xms.hgs.file_io import grok_writer
from xms.hgs.file_io.binary_file_reader import BinaryFileReader
from xms.hgs.file_io.tecplot_fem_block_reader import TecplotFemBlockReader
from xms.hgs.file_io.tecplot_ordered_block_reader import TecplotOrderedBlockReader
from xms.hgs.file_io.tecplot_point_reader import TecplotPointReader
from xms.hgs.misc import util

# Type aliases
Locations = list[tuple[float, float, float]]  # list of (x, y, z) locations, e.g. the grid locations given to the rtree


def run(grok_filepath: Path,
        co_grid_3d,
        co_grid_2d,
        win_cont,
        testing: bool = False) -> tuple[list[tuple[DatasetWriter, str]], list[XySeries]]:
    """Runs SolutionReader.do_work via a ProcessFeedbackThread while a ProcessFeedbackDlg shows the 'xms.hgs' log.

    Args:
        grok_filepath (Path): Path to the .grok file.
        co_grid_3d: The 3D grid.
        co_grid_2d: The 2D grid.
        win_cont (QWidget): Parent window.
        testing (bool): Set to true if testing so dialog closes.

    Returns:
        (tuple[list[tuple[DatasetWriter, str]], list[XySeries]]): The DatasetWriters paired with their subfolder,
        and the XY series, as collected by the SolutionReader.
    """
    reader = SolutionReader(grok_filepath, co_grid_3d, co_grid_2d)
    from xms.guipy.dialogs.process_feedback_thread import ProcessFeedbackThread
    worker = ProcessFeedbackThread(do_work=reader.do_work, parent=win_cont)

    solution_folder = grok_filepath.parent
    output_filepath = hgs_output_files_txt_path(grok_filepath)
    display_text = dict(
        title='Read Solution',
        working_prompt=f'Reading HydroGeoSphere solution at "{solution_folder}".',  # noqa: B028
        error_prompt='Error(s) encountered while reading the solution.',
        warning_prompt='Warning(s) encountered while reading the solution.',
        success_prompt=f'Successfully read solution at "{solution_folder}".',  # noqa: B028
        note=f'Using: {output_filepath}',
        auto_load='Close this dialog automatically when reading is finished.',
        log_format='%(asctime)s - %(message)s',
        use_colors=True,
    )

    dialog_util.ensure_qapplication_exists()
    feedback_dlg = ProcessFeedbackDlg(display_text=display_text, logger_name='xms.hgs', worker=worker, parent=win_cont)
    feedback_dlg.testing = testing  # The dialog closes itself only when testing
    feedback_dlg.exec()
    return reader.get_results()


def hgs_output_files_txt_path(grok_filepath: Path) -> Path:
    """Returns the path of the "<prefix>o.hgs_output_files.txt" file that sits next to the .grok file.

    For example, "C:/.../R5.grok" gives "C:/.../R5o.hgs_output_files.txt".

    Args:
        grok_filepath (Path): Path to the .grok file.

    Returns:
        (Path): See description.
    """
    txt_name = grok_filepath.stem + 'o.hgs_output_files.txt'
    return grok_filepath.with_name(txt_name)


@dataclass()
class FileSpec:
    """Specification (description) of the file from the o.hgs_output_files.txt file."""
    file_type: str = ''  # e.g. 'binary', 'tecplot ascii'
    data_location: str = ''  # 'cell-centered', 'nodal'
    data_type: str = ''  # 'real4', 'real8'
    components_count: int = 0  # Usually 1 or 3
    data_count: int = 0  # Number of values in the file
    ordering: str = ''  # not sure what this tecplot thing is yet
    point_or_block: str = ''  # not sure what this tecplot thing is yet
    units: str = ''

    @staticmethod
    def from_strings(words: list[str]) -> 'FileSpec':
        """Returns a FileSpec initialized with the list of words."""
        if words[0] == 'binary':
            spec = FileSpec(
                file_type=words[0],
                data_location=words[1],
                data_type=words[2],
                components_count=int(words[3]),
                data_count=int(words[4]),
                units=words[5].strip('"()') if len(words) > 5 else ''
            )
            return spec
        elif words[0] == 'tecplot ascii':
            return FileSpec(file_type=words[0], ordering=words[1], point_or_block=words[2])
        else:
            raise RuntimeError(f'Unsupported file type: "{words[0]}".')


class FileInfo(TypedDict):
    """Info about an output file."""
    spec: FileSpec
    files: list[Path]  # All the files associated with this file (other timesteps)


class SolutionReader:
    """Reads the solution."""
    def __init__(self, grok_filepath: Path, co_grid_3d, co_grid_2d) -> None:
        """Initializes the file.

        I've tried to be consistent with variable names:
        - spec is a FileSpec class.
        - file_info is a FileInfo TypedDict containing FileSpec and a list of files.

        Args:
            grok_filepath (Path): Path to the .grok file.
            co_grid_3d: The 3D grid.
            co_grid_2d: The 2D grid.
        """
        self._grok_filepath = grok_filepath
        self._co_grid_3d = co_grid_3d
        self._ugrid_3d = co_grid_3d.ugrid if co_grid_3d else None  # Do 'co_grid_3d.ugrid' only once as it's costly
        self._co_grid_2d = co_grid_2d
        self._ugrid_2d = co_grid_2d.ugrid if co_grid_2d else None  # Do 'co_grid_2d.ugrid' only once as it's costly
        self._rtree2d = None
        self._rtree3d = None
        self._log = logging.getLogger('xms.hgs')

        self._output_files: dict[str, FileInfo] = {}
        self._gms_data = {}
        self._dset_time_units = ''
        self._datasets_and_folders: list[tuple[DatasetWriter, str]] = []  # DatasetWriters and their subfolder
        self._xy_series: list[XySeries] = []
        self._ref_time: datetime | None = None

    def do_work(self) -> None:
        """Called by ProcessFeedbackThread to do the stuff."""
        self.read()
        self._log.info('Solution read complete.\n')

    def read(self) -> tuple[list[tuple[DatasetWriter, str]], list[XySeries]]:
        """Reads the solution.

        Returns:
            (list[DatasetWriter]): List of the DatasetWriter classes.
        """
        self._find_time_units()
        self._read_hgs_output_file()
        self._read_gms_data()
        self._set_reference_time()
        for file_name_base, file_info in self._output_files.items():
            if file_name_base.lower() == 'ElemK_pm'.lower():  # Too different (see hgs_binary_formats.txt)
                continue
            self._read_file(file_name_base, file_info)
        self._assign_ids_to_xy_series()
        return self._datasets_and_folders, self._xy_series

    def get_results(self) -> tuple[list[tuple[DatasetWriter, str]], list[XySeries]]:
        """Returns the results."""
        return self._datasets_and_folders, self._xy_series

    def _assign_ids_to_xy_series(self) -> None:
        """Assign unique IDs to xy series."""
        for i, xy_series in enumerate(self._xy_series):
            xy_series.series_id = i + 1

    def _find_time_units(self) -> None:
        """Reads the .grok file and finds and saves the time units."""
        with self._grok_filepath.open('r') as file:
            for line in file:
                pos = line.lower().find('units: kilogram-')  # Hope we don't need a regex for this
                if pos >= 0:
                    pos2 = line.find('k', pos)
                    pos3 = line.find(' ', pos2)
                    units = line[pos2:pos3]
                    words = units.split('-')
                    # Last part is the time unit
                    self._dset_time_units = self._dset_time_units_from_hgs_time_units(words[-1])
                    break

    def _read_hgs_output_file(self) -> None:
        """Reads the []o.hgs_output_files.txt file and sorts it into a dict."""
        output_filepath = hgs_output_files_txt_path(self._grok_filepath)
        if not output_filepath.is_file():
            self._log.error(
                f'Could not find "{str(output_filepath)}". Use "track hgs output files" command in .grok file'
                f' in order to create this file when phgs.exe is run.'
            )
            return

        with output_filepath.open('r') as output_file:
            for line in (x.strip() for x in output_file):  # Strip \n from lines as they're read
                words = line.split(',')
                file_path = self._grok_filepath.with_name(words[0])
                if file_path.is_file():
                    spec = FileSpec.from_strings(words[1:])
                    file_name_base = self._get_file_name_base(spec.file_type, file_path)
                    if file_name_base not in self._output_files:
                        self._output_files[file_name_base] = {'spec': spec, 'files': [file_path]}
                    else:
                        self._output_files[file_name_base]['files'].append(file_path)

    def _read_gms_data(self):
        """Read gms data file."""
        filepath = self._grok_filepath.with_name(grok_writer.GMS_DATA_FILE)
        self._gms_data = file_io_util.read_json_file(filepath)

    def _set_reference_time(self):
        """Sets the reference time from the start date/time, if there is one."""
        start_date_time = self._gms_data.get('start_date_time', '')
        if start_date_time:
            self._ref_time = string_to_datetime(start_date_time, None)

    def _get_file_name_base(self, file_type: str, file_path: Path) -> str:
        """Returns the base name of the file which will be the same for all files of that group of files."""
        file_name_parts = file_path.name.split('.')
        file_name_base = ''
        if file_type == 'binary':
            file_name_base = '.'.join(file_name_parts[1:2])
        elif file_type == 'tecplot ascii' and len(file_name_parts) == 4:
            file_name_base = '.'.join(file_name_parts[1:2])
        elif file_type == 'tecplot ascii' and len(file_name_parts) == 3:
            file_name_base = '.'.join(file_name_parts[1:])
        return file_name_base

    def _read_file(self, file_name_base: str, file_info: FileInfo) -> None:
        """Reads a file."""
        self._log.info(f'Reading file "{file_name_base}"')
        file_type = file_info['spec'].file_type
        match file_type:
            case 'binary':
                self._read_binary_file(file_name_base, file_info)
            case 'tecplot ascii':
                self._read_tecplot_ascii_file(file_info)
            case _:
                self._log.warning(f'Unsupported file type: "{file_type}"')

    @staticmethod
    def _dset_location_from_hgs_location(data_location):
        """Returns the data location string that DatasetWriter expects, given HGS location."""
        match data_location:
            case 'cell-centered':
                location = 'cells'
            case 'nodal':
                location = 'points'
            case _:
                raise RuntimeError(f'Unsupported data location: "{data_location}".')
        return location

    @staticmethod
    def _dset_time_units_from_hgs_time_units(hgs_time_units) -> str:
        """Returns the time units string that DatasetWriter expects, given HGS time units."""
        match hgs_time_units:
            case 'second':
                time_units = 'Seconds'
            case 'minute':
                time_units = 'Minutes'
            case 'hour':
                time_units = 'Hours'
            case 'day':
                time_units = 'Days'
            case 'year':
                time_units = 'Years'
            case _:
                raise RuntimeError(f'Unsupported time units: "{hgs_time_units}".')
        return time_units

    def _init_dataset_writer(self, two_d: bool, name: str, spec: FileSpec) -> 'DatasetWriter | None':
        """Returns a DatasetWriter object.

        Args:
            two_d (bool): True if we're reading a 2D file.
            name (str): Name for the dataset.
            spec (FileSpec): Information about the file.

        Returns:
            (DatasetWriter | None): See description.
        """
        if (two_d and not self._co_grid_2d) or (not two_d and not self._co_grid_3d):
            return None

        location = SolutionReader._dset_location_from_hgs_location(spec.data_location)
        h5_file = tempfile.NamedTemporaryFile(
            prefix=name, suffix='.h5', delete=False, dir=XmsEnvironment.xms_environ_temp_directory()
        )
        dataset_writer = DatasetWriter(
            h5_filename=h5_file.name,
            name=name,
            dset_uuid=testing_tools.new_uuid(),
            geom_uuid=self._co_grid_2d.uuid if two_d else self._co_grid_3d.uuid,
            num_components=spec.components_count,
            ref_time=self._ref_time,
            time_units=self._dset_time_units,
            units=spec.units,
            # use_activity_as_null=True,
            location=location,
        )
        return dataset_writer

    def _read_binary_file(self, file_name_base: str, file_info: FileInfo) -> Path | None:
        """Reads a binary file and returns the filepath to the .h5 that was created.

        Args:
            file_name_base (str): Base part of file name.
            file_info (FileInfo): Information about the file.

        Returns:
            (Path|None): See description.
        """
        spec = file_info['spec']
        real_size = int(spec.data_type[-1])  # E.g. gets the 8 from 'real8'
        domain_subfolder, two_d = SolutionReader._binary_file_domain(file_name_base)
        dataset_writer = self._init_dataset_writer(two_d, file_name_base, spec)
        if not dataset_writer:
            return None

        for file_path in file_info['files']:
            reader = BinaryFileReader(file_path, real_size, spec.data_count, spec.components_count)
            time_stamp, values = reader.read()
            if values is None or len(values) != spec.data_count:
                self._log_num_values_error(file_path, values, spec.data_count)  # pragma no cover - shouldn't happen
                continue  # pragma no cover - shouldn't happen
            dataset_writer.append_timestep(time=time_stamp, data=values)
        dataset_writer.active_timestep = 0
        dataset_writer.appending_finished()
        self._datasets_and_folders.append((dataset_writer, domain_subfolder))
        return Path(dataset_writer.h5_filename)

    def _log_num_values_error(self, file_path, values, expected_values_count):  # pragma no cover - shouldn't happen
        """Logs an error and is easier to put 'pragma no cover' on."""
        values_read = len(values) if values is not None else 0
        self._log.error(
            f'Number of values read ({values_read}) did not match the number expected '
            f'({expected_values_count}) in file "{str(file_path.name)}".'
        )

    def _read_tecplot_ascii_file(self, file_info: FileInfo):
        """Reads a tecplot ASCII file."""
        spec = file_info['spec']
        for file in file_info['files']:
            if spec.point_or_block == 'block':
                if spec.ordering == 'fem':
                    self._read_tecplot_fem_block_file(file)
                else:
                    self._read_tecplot_ordered_block_file(file)
            elif spec.point_or_block == 'point':
                self._read_tecplot_point_file(file)

    def _read_tecplot_fem_block_file(self, filepath: Path) -> None:
        """Reads a tecplot ascii fem file and creates datasets.

        Args:
            filepath (path): Full path to the file.
        """
        if not self._co_grid_2d:
            return
        reader = TecplotFemBlockReader(filepath, self._co_grid_2d.uuid, self._dset_time_units)
        dataset_writers = reader.read()
        self._datasets_and_folders.extend([(dataset_writer, 'olf') for dataset_writer in dataset_writers])

    def _read_tecplot_ordered_block_file(self, filepath: Path) -> None:
        """Reads a tecplot ascii block file and creates datasets.

        Args:
            filepath (path): Full path to the file.
        """
        domain, two_d = self._tecplot_file_domain(filepath.name)
        if (two_d and not self._co_grid_2d) or (not two_d and not self._co_grid_3d):
            return

        if two_d and not self._rtree2d:
            self._rtree2d = _build_rtree2(self._ugrid_2d.locations)
        elif not two_d and not self._rtree3d:
            self._rtree3d = _build_rtree2(self._ugrid_3d.locations)  # pragma no cover - may not need this but maybe
        rtree = self._rtree2d if two_d else self._rtree3d
        ugrid = self._ugrid_2d if two_d else self._ugrid_3d
        uuid = self._co_grid_2d.uuid if two_d else self._co_grid_3d.uuid
        folder = f'{domain}/BCs'

        reader = TecplotOrderedBlockReader(filepath, ugrid, uuid, rtree, self._dset_time_units)

        dataset_writers = reader.read()
        self._datasets_and_folders.extend([(dataset_writer, folder) for dataset_writer in dataset_writers])

    def _read_tecplot_point_file(self, filepath: Path) -> None:
        """Reads a tecplot point file and creates XY series.

        Args:
            filepath (path): Full path to the file.
        """
        reader = TecplotPointReader(filepath)
        xy_series_list = reader.read()
        self._xy_series.extend(xy_series_list)

    @staticmethod
    def _compute_insert_count(grid_values_count: int, components_count: int, len_values: int) -> int:
        """Returns the number of values that need to be inserted to make the size of the values match the grid.

        Args:
            grid_values_count (int): Number of grid points or cells.
            components_count (int): Number of data components (e.g. 1 if scalar, 3 if vector).
            len_values (int): Length of the values.

        Returns:
            (int): See description.
        """
        return (grid_values_count * components_count) - len_values

    @staticmethod
    def _insert_nodata_values(num_to_insert: int, values: ndarray) -> ndarray:
        """Inserts nodata values at the beginning of values array so that it becomes size of count."""
        new_values = np.full(num_to_insert, util.nodata)
        return np.insert(values, 0, new_values)

    def _find_grid_values_count(self, two_d: bool, data_location: str) -> int:
        """Returns the number of grid points or the number of cells.

        Args:
            two_d (bool): True if we're reading a 2D file.
            data_location (str): 'cell-centered', 'nodal' - where the values are located.

        Returns:
            (int): Either the number of grid points or the number of cells.
        """
        ugrid = self._ugrid_2d if two_d else self._ugrid_3d
        return ugrid.point_count if data_location == 'nodal' else ugrid.cell_count

    @staticmethod
    def _binary_file_domain(file_name_base: str) -> tuple[str, bool]:
        """Returns True if the file contains a 2D dataset (olf domain)."""
        pos0 = file_name_base.rfind('_')
        domain_subfolder = file_name_base[pos0 + 1:].lower()
        return domain_subfolder, domain_subfolder == 'olf'

    def _tecplot_file_domain(self, file_name_base: str) -> tuple[str, bool]:
        """Returns True if the file contains a 2D dataset (olf domain)."""
        words = file_name_base.split('.')
        if len(words) > 2:
            bc_name = words[2]  # e.g. From 'R5o.Bc.Well.dat' we'll get 'Well'
            bc_names_and_domains = self._gms_data.get('bc_names_and_domains')
            if bc_names_and_domains:
                domain = bc_names_and_domains.get(bc_name)
                return domain, domain != domain_abbreviation(Domains.PM)
        return '', False

    @staticmethod
    def _make_activity(components_count: int, values: ndarray) -> ndarray:
        """Returns an activity array computed by looking for nodata values.

        Expects the values array to be reshaped by now if it represents vector data.

        Args:
            components_count (int):
            values (ndarray):

        Returns:
            (ndarray): The activity array.
        """
        if components_count == 1:
            activity = np.where(values == util.nodata, 0, 1)
        else:
            activity = np.all(values != util.nodata, axis=1)
        return activity


def _rtree_insert_generator_function(grid_locations: Locations):
    """Yields one (id, bounds, object) record per location, for rtree stream loading.

    The id is the location's index in grid_locations. The bounds are the point repeated as both corners,
    (x, y, z, x, y, z), and the location itself is stored as the record's object.

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    for point_index, point in enumerate(grid_locations):
        x, y, z = point[0], point[1], point[2]
        yield point_index, (x, y, z, x, y, z), point


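# Reference version of _build_rtree2 that inserts one record at a time (presumably slower than stream loading).
# def _build_rtree(grid_locations: Locations):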
#     """Builds an rtree.
#
#     https://rtree.readthedocs.io/en/latest/tutorial.html#insert-records-into-the-index
#     """
#     p = index.Property()
#     p.dimension = 3
#     rtree = index.Index(properties=p)
#     for i, location in enumerate(grid_locations):
#         rtree.insert(i, (location[0], location[1], location[2], location[0], location[1], location[2]))
#     return rtree


def _build_rtree2(grid_locations: Locations):
    """Returns a 3D rtree index of grid_locations, bulk-loaded from a generator (stream loading is supposed to be faster).

    Record ids are the indices into grid_locations.

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    props = index.Property()
    props.dimension = 3  # x, y, z
    records = _rtree_insert_generator_function(grid_locations)
    return index.Index(records, properties=props)
