"""Reads the PEST obs results."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import nullcontext
from datetime import datetime
import glob
import os
from pathlib import Path
import shlex
import sys

# 2. Third party modules
import orjson

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.file_io.pest import pest_obs_data_generator
from xms.mf6.file_io.pest.obs_results import B2Map, Features, ObjGeom, Obs, ObsCovData, Observation, ObsResults, ObsVals
from xms.mf6.file_io.pest.pest_obs_data_generator import PEST_STRFTIME


def read(pest_obs_dir: str, times: list[float]) -> ObsResults | None:
    """Read the PEST observation results from a directory.

    Convenience wrapper that creates a PestObsResultsReader and reads with it.

    Args:
        pest_obs_dir: Path to directory containing all the PEST files.
        times: list of timestep times.

    Returns:
        The results, or None if the PEST files could not be found.
    """
    results_reader = PestObsResultsReader()
    results = results_reader.read(pest_obs_dir, times)
    return results


class PestObsResultsReader:
    """Reads the PEST output files into an ObsResults.

    read() does not reset the accumulated data, so use a fresh instance for each read.
    """

    # Maps the feature type found in a b2map 'map_id' to the key used in the Features dict.
    _FEATURE_TYPE_KEYS = {
        'point': 'points',
        'arc': 'arcs',
        'arc_group': 'arc_groups',
        'polygon': 'polygons',
    }

    def __init__(self):
        """Initializer."""
        self._obs_results = ObsResults()  # All the data including the observation data ('obs') and b2map data ('b2map')
        self._obs: Obs = {}  # The observation data part of self._obs_results
        self._times = []
        self._steady_state_time = None

    def read(self, pest_obs_dir: str, times: list[float]) -> ObsResults | None:
        """Reads the PEST output files.

        Args:
            pest_obs_dir: Path to directory containing all the PEST files.
            times: list of timestep times.

        Returns:
            The results, or None if there was a problem.
        """
        self._times = times
        file_prefix = pest_files_prefix(pest_obs_dir)
        if not file_prefix:
            return None
        self._read_xsamp_files(file_prefix)  # .mftimes-bsamp, .bsamp and .fsamp with their .out and weight files
        self._read_samp_file(file_prefix)  # .samp
        # Must come after the observation data is read: only features that have observation data are kept.
        self._read_and_store_b2map_file(file_prefix)
        self._sort_obs_data()
        self._obs_results.obs = self._obs
        return self._obs_results

    def _read_xsamp_files(self, file_prefix: str) -> None:
        """Reads the .mftimes-bsamp, .bsamp and .fsamp files (with their .out and weight files) into self._obs.

        Args:
            file_prefix: Path of the PEST files without their extension.
        """
        # Do the .mftimes first because, when steady state, we make all times equal to the first mftime time
        extensions = [
            ('.mftimes-bsamp', '.mftimes-bsamp.out', ''),  # No weight file for the MODFLOW times
            ('.bsamp', '.bsamp.out', '.bwt'),
            ('.fsamp', '.fsamp.out', '.fwt'),
        ]
        for ext_xsamp, ext_out, ext_wt in extensions:
            filename = file_prefix + ext_xsamp
            filename_out = file_prefix + ext_out
            # Without a weight extension, don't build a name equal to the bare prefix (which might be a real file).
            filename_wt = file_prefix + ext_wt if ext_wt else ''
            if os.path.isfile(filename) and os.path.isfile(filename_out):
                flow = ext_xsamp == '.fsamp'
                self._read_xsamp_file(filename, filename_out, filename_wt, flow=flow)

    def _read_xsamp_file(self, filename_xsamp: str, filename_xsamp_out: str, filename_wt: str, flow: bool) -> None:
        """Reads either the .bsamp and .bsamp.out files, or the .fsamp and .fsamp.out files.

        Files are read simultaneously, line by line, stopping at the end of the shorter of the two xsamp files.
        The weight file is optional; when it is missing or shorter, the weight is None.

        Args:
            filename_xsamp: Path to the file with the observed values.
            filename_xsamp_out: Path to the file with the computed values.
            filename_wt: Path to the weight file, or '' if there is none.
            flow: True if these are flow observations.
        """
        mftimes = filename_xsamp.endswith('.mftimes-bsamp')
        has_wt = bool(filename_wt) and os.path.isfile(filename_wt)
        with open(filename_xsamp, 'r') as file_xsamp, open(filename_xsamp_out, 'r') as file_xsamp_out:
            # nullcontext() yields None, and the 'with' guarantees the weight file is closed even if reading fails
            with (open(filename_wt, 'r') if has_wt else nullcontext()) as file_wt:
                for line_xsamp, line_xsamp_out in zip(file_xsamp, file_xsamp_out):
                    line_wt = file_wt.readline() if file_wt else ''
                    self._read_xsamp_line(flow, mftimes, line_xsamp, line_xsamp_out, line_wt)

    def _read_samp_file(self, file_prefix: str) -> None:
        """Reads the .samp file as computed values at the MODFLOW times, skipping FLOW lines.

        Args:
            file_prefix: Path of the PEST files without their extension.
        """
        samp_filepath = Path(file_prefix + '.samp')
        if not samp_filepath.is_file():
            return
        with samp_filepath.open('r') as samp_file:
            for samp_line in samp_file:
                if samp_line.strip().startswith('FLOW'):  # We already read this when we read the .fsamp files
                    continue
                self._read_xsamp_line(flow=True, mftimes=True, line_observed='', line_computed=samp_line, line_wt='')

    def _is_steady_state(self) -> bool:
        """Returns True if there is only one timestep time."""
        return len(self._times) == 1

    def _read_xsamp_line(self, flow: bool, mftimes: bool, line_observed: str, line_computed: str, line_wt: str) -> None:
        """Parses one line from each of the observed, computed and weight files and stores the data in self._obs."""
        observed = _read_observed_value(line_observed, mftimes)
        alias, dt, computed = _read_alias_and_date_and_computed(line_computed)
        dt = self._handle_steady_state(dt, mftimes)
        weight = _read_weight(line_wt)
        self._add_obs_data(alias, dt, observed, computed, flow, weight)

    def _handle_steady_state(self, dt: datetime, mftimes: bool) -> datetime:
        """If the model is steady state, replaces dt with the steady state time (the first mftimes time read).

        Args:
            dt: The date time read from the file.
            mftimes: True if the line is at the MODFLOW times.

        Returns:
            dt, or the steady state time if the model is steady state.
        """
        if mftimes and self._is_steady_state() and self._steady_state_time is None:
            # The first mftimes time should be the steady state time. Save it the first time.
            self._steady_state_time = dt
        if self._is_steady_state():
            # NOTE(review): if no mftimes line has been read yet this is None - confirm that can't happen.
            dt = self._steady_state_time
        return dt

    def _add_obs_data(
        self, alias: str, dt: datetime, observed: float | None, computed: float | None, flow: bool, weight: float
    ) -> None:
        """Adds the values to self._obs under alias and dt, keeping any observed value already stored."""
        observation: Observation = self._obs.setdefault(alias, {})
        obs_vals: ObsVals = observation.setdefault(dt, {})
        # Lines at the MODFLOW times have no observed value (None); don't let them erase a real one.
        if obs_vals.get('observed') is None:
            obs_vals['observed'] = observed
        obs_vals['computed'] = computed
        obs_vals['flow'] = flow
        obs_vals['weight'] = weight

    def _sort_obs_data(self) -> None:
        """Sort the date times because we mixed the observed stuff with the modflow times."""
        self._obs = {alias: dict(sorted(data.items())) for alias, data in self._obs.items()}

    def _read_and_store_b2map_file(self, file_prefix: str) -> None:
        """Read the .b2map file and store it in self._obs_results.

        Args:
            file_prefix: Path of the PEST files without their extension.
        """
        self._obs_results.b2map = self._read_b2map_file(file_prefix)

    def _read_b2map_file(self, file_prefix: str) -> B2Map:
        """Read the .b2map file and return the features of the observations we have data for.

        Args:
            file_prefix: Path of the PEST files without their extension.

        Returns:
            Dict of coverage UUID -> coverage tree path and the features, grouped by feature type and keyed by alias.
        """
        with open(file_prefix + '.b2map', 'rb') as file:
            json_data = orjson.loads(file.read())

        b2map: B2Map = {}
        for alias, data in json_data.items():
            # Aliases in self._obs are upper case, so normalize before checking membership.
            alias = alias.upper()
            if alias not in self._obs:
                continue
            cov_uuid, feature_type, _feature_id = pest_obs_data_generator.map_id_tuple_from_map_id(data['map_id'])
            if cov_uuid not in b2map:
                cov_path = data.get('coverage_tree_path', '')  # Might not exist because we added it in 2025
                features: Features = {'points': {}, 'arcs': {}, 'arc_groups': {}, 'polygons': {}}
                cov_data: ObsCovData = {'coverage_tree_path': cov_path, 'features': features}
                b2map[cov_uuid] = cov_data

            features_key = self._FEATURE_TYPE_KEYS.get(feature_type)
            if features_key is not None:
                obj_geom: ObjGeom = {'geometry': data['geometry'], 'interval': data['interval']}
                b2map[cov_uuid]['features'][features_key][alias] = obj_geom
        return b2map


def _read_weight(line_wt: str) -> float | None:
    """Reads and returns the weight."""
    weight = None
    if line_wt:
        words_wt = shlex.split(line_wt, posix="win" not in sys.platform)  # shlex handles quoted strings
        weight = float(words_wt[3])
    return weight


def _read_alias_and_date_and_computed(line_computed: str) -> tuple[str, datetime, float | None]:
    """Reads and returns the alias, the date time, and the computed value.

    Args:
        line_computed: Line holding (at least) the alias, date, time and value, whitespace separated, with the
            date and time in PEST_STRFTIME format.

    Returns:
        (alias, date time, computed value). The alias is upper case with surrounding quotes removed. The
        computed value is None if it can't be converted to a float.
    """
    # shlex handles quoted strings. Non-POSIX splitting on Windows only ('win' is also a substring of 'darwin').
    words_computed = shlex.split(line_computed, posix=not sys.platform.startswith('win'))
    # Non-POSIX splitting keeps quotes, and typographic quotes are never removed by shlex, so strip them here.
    alias = words_computed[0].strip('"\'‘’“”').upper()
    dt = datetime.strptime(f'{words_computed[1]} {words_computed[2]}', PEST_STRFTIME)
    try:
        computed = float(words_computed[3])
    except ValueError:
        computed = None
    return alias, dt, computed


def _read_observed_value(line_observed: str, mftimes: bool) -> float | None:
    """Reads and returns the observed value if we have one."""
    observed = None
    if line_observed:
        words_observed = shlex.split(line_observed, posix="win" not in sys.platform)  # shlex handles quoted strings
        # The MODFLOW times don't have observed values. Use NaN
        observed = None if mftimes else float(words_observed[3])
    return observed


def pest_files_prefix(pest_obs_dir: str) -> str:
    """Returns the filename prefix of the files in the pest obs directory."""
    b2map_files = glob.glob(f'{pest_obs_dir}/*.b2map')
    if not b2map_files:
        return None  # pragma no cover - error condition
    return os.path.splitext(b2map_files[0])[0]
