"""Writes the PEST obs statistics file."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from dataclasses import dataclass, field
import os

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.file_io.pest import pest_obs_util
from xms.mf6.file_io.pest.pest_obs_results_reader import Obs, ObsResults
from xms.mf6.misc import util

# Constants
INDENT = '  '  # One indentation level (two spaces); repeated per level by _indent()


@dataclass
class Residuals:
    """All three types of residuals (observed - computed) that get reported.

    A sample goes into either dep_var or flow, and additionally into weighted when it has a weight.
    """
    # Unweighted residuals of samples not flagged as flow (head, concentration, or temperature)
    dep_var: list[float] = field(default_factory=list)
    # Unweighted residuals of samples flagged as flow
    flow: list[float] = field(default_factory=list)
    # Weighted residuals (observed * weight - computed * weight) of every sample that has a weight
    weighted: list[float] = field(default_factory=list)


def write(model_ftype: str, pest_obs_dir: str, obs_results: ObsResults) -> str:
    """Create the PEST obs statistics file and return its filepath.

    Convenience entry point: builds a PestObsStatsWriter from the arguments and runs it.

    Args:
        model_ftype: 'GWF6', 'GWT6', 'GWE6'.
        pest_obs_dir: Path to directory containing pest files.
        obs_results: The results from pest_obs_results_reader.

    Returns:
        Path of the statistics file that was written.
    """
    return PestObsStatsWriter(model_ftype, pest_obs_dir, obs_results).write()


class PestObsStatsWriter:
    """Writes a text file containing the observation statistics.

    Residuals (observed - computed) are gathered from the observation results, sorted into dependent variable
    (head, concentration, or temperature), flow, and weighted groups, and summarized (mean, mean absolute,
    root mean squared) both overall and per coverage.
    """
    # Heading used for the dependent variable, keyed by model ftype
    _DEP_VAR_NAMES = {'GWF6': 'Head', 'GWT6': 'Concentration', 'GWE6': 'Temperature'}

    def __init__(self, model_ftype: str, pest_obs_dir: str, obs_results: ObsResults):
        """Initializer.

        Args:
            model_ftype: 'GWF6', 'GWT6', 'GWE6'.
            pest_obs_dir: Path to directory containing pest files.
            obs_results: The results from pest_obs_results_reader.
        """
        self._model_ftype = model_ftype
        self._pest_obs_dir = pest_obs_dir
        self._obs_results: ObsResults = obs_results
        self._residuals = Residuals()
        self._alias_cov_tree_path: dict[str, str] = {}  # alias: coverage tree path
        self._coverage_residuals: dict[str, Residuals] = {}  # coverage tree path: all three types of residuals
        # mean, mae, and rms for dependent variable (head, concentration, or temperature), flow, and weighted.
        # Keys are 'dep_var', 'flow', 'weighted' for the overall stats and '<coverage tree path> <type>' per coverage.
        self._stats: dict[str, dict[str, float | None]] = {}

    def write(self) -> str:
        """Writes the PEST obs statistics file.

        Returns:
            Path of the file that was written.
        """
        self._make_alias_to_cov_tree_path_dict()
        self._extract_residuals()
        self._compute_stats()
        return self._write_file()

    def _compute_stats(self) -> None:
        """Calculate and store the mean, mean absolute error, and root mean squared error."""
        # Start from a clean slate so stats from an earlier call cannot linger
        self._stats = {}

        # Overall stats
        self._compute_stats_for('', self._residuals)

        # Stats by coverage
        for cov_tree_path, residuals in self._coverage_residuals.items():
            self._compute_stats_for(f'{cov_tree_path} ', residuals)

    def _compute_stats_for(self, key_prefix: str, residuals: Residuals) -> None:
        """Compute and store the dep_var, flow, and weighted stats for one set of residuals.

        Args:
            key_prefix: Prepended to 'dep_var', 'flow', and 'weighted' to form the self._stats keys.
            residuals: The residuals to summarize.
        """
        self._set_stats(f'{key_prefix}dep_var', _compute_mean_mea_rms(residuals.dep_var))
        self._set_stats(f'{key_prefix}flow', _compute_mean_mea_rms(residuals.flow))
        sswr = _compute_sum_of_squared_weighted_residual(residuals.weighted)
        self._set_stats(f'{key_prefix}weighted', _compute_mean_mea_rms(residuals.weighted), sswr)

    def _write_file(self) -> str:
        """Writes the stats to a text file and returns the filename."""
        output_dir = _model_output_dir(self._pest_obs_dir)
        filename = os.path.join(output_dir, 'pest_obs_stats.txt')
        dep_var_name = self._dep_var_name()  # Raises for an unsupported ftype before the file is created
        with open(filename, 'w') as file:
            file.write('Overall statistics\n\n')
            self._write_stats_group(dep_var_name, '', 0, file)

            # By coverage
            file.write('\nStatistics by coverage\n')
            for cov_tree_path in self._coverage_residuals:
                file.write('\n')
                cov_name = pest_obs_util.cov_name_from_path(cov_tree_path)
                file.write(f'{_indent(1, 0)}{cov_name}:\n')
                file.write('\n')
                self._write_stats_group(dep_var_name, f'{cov_tree_path} ', 1, file)
        return filename

    def _write_stats_group(self, dep_var_name: str, key_prefix: str, extra_indent: int, file) -> None:
        """Write the dependent variable, flow, and weighted stats, separated by blank lines.

        Args:
            dep_var_name: Name of the dependent variable ('Head', 'Concentration' etc)
            key_prefix: Prepended to 'dep_var', 'flow', and 'weighted' to form the self._stats keys.
            extra_indent: Extra indentation.
            file: The file being written to.
        """
        self._write_stats(dep_var_name, f'{key_prefix}dep_var', extra_indent, file)
        file.write('\n')
        self._write_stats('Flow', f'{key_prefix}flow', extra_indent, file)
        file.write('\n')
        self._write_weighted_stats(dep_var_name, f'{key_prefix}weighted', extra_indent, file)

    def _make_alias_to_cov_tree_path_dict(self) -> None:
        """Make a dict of alias -> coverage tree path from the b2map data."""
        self._alias_cov_tree_path = {}
        for cov_data in self._obs_results.b2map.values():
            cov_tree_path = cov_data['coverage_tree_path']
            # 'features' maps feature type -> {alias: geometry}; only the aliases are needed here
            for alias_geom in cov_data['features'].values():
                for alias in alias_geom:
                    self._alias_cov_tree_path[alias] = cov_tree_path

    def _dep_var_name(self) -> str:
        """Return the name of the dependent variable (head, concentration, or temperature) based on the model ftype.

        Raises:
            ValueError: If the model ftype is not 'GWF6', 'GWT6', or 'GWE6'.
        """
        try:
            return self._DEP_VAR_NAMES[self._model_ftype]
        except KeyError:
            raise ValueError(f'Model ftype "{self._model_ftype}" not supported for PEST observations.') from None

    def _write_stats(self, heading: str, stats_type: str, extra_indent: int, file) -> None:
        """Write the stats for the dependent variable or flow.

        Args:
            heading: First line heading.
            stats_type: Key to self._stats dict.
            extra_indent: Extra indentation.
            file: The file being written to.
        """
        stats = self._stats[stats_type]
        file.write(f'{_indent(1, extra_indent)}{heading}:\n')
        file.write(f'{_indent(2, extra_indent)}Mean residual: {stats["mean"]}\n')
        file.write(f'{_indent(2, extra_indent)}Mean absolute residual: {stats["mae"]}\n')
        file.write(f'{_indent(2, extra_indent)}Root mean squared residual: {stats["rms"]}\n')

    def _write_weighted_stats(self, dep_var_name: str, stats_type: str, extra_indent: int, file) -> None:
        """Write the weighted stats, followed by the sum of squared weighted residuals.

        Args:
            dep_var_name: Name of the dependent variable ('Head', 'Concentration' etc)
            stats_type: Key to self._stats dict.
            extra_indent: Extra indentation.
            file: The file being written to.
        """
        self._write_stats(f'Weighted ({dep_var_name.lower()} + flow)', stats_type, extra_indent, file)
        file.write(f'{_indent(2, extra_indent)}Sum of squared residual: {self._stats[stats_type]["sswr"]}\n')

    def _get_coverage_tree_path(self, alias: str) -> str:
        """Returns the coverage tree path associated with the alias.

        This is a one line function that can be mocked when testing.

        Args:
            alias: Observation alias.

        Returns:
            See description.

        Raises:
            KeyError: If the alias is not in the alias -> coverage tree path dict.
        """
        return self._alias_cov_tree_path[alias]

    def _extract_residuals(self) -> None:
        """Get residuals and sort them into dep_var, flow, and weighted, overall and per coverage.

        Samples without both an observed and a computed value are skipped (the mftimes pseudo samples, per the
        original author's note). A coverage only gets an entry once it has at least one usable sample.
        """
        # Reset the accumulators so calling this (or write()) more than once does not double count
        self._residuals = Residuals()
        self._coverage_residuals = {}

        obs: Obs = self._obs_results.obs
        for alias, observation in obs.items():
            cov_tree_path: str = self._get_coverage_tree_path(alias)
            for obs_vals in observation.values():
                if obs_vals['observed'] is None or obs_vals['computed'] is None:
                    continue
                observed = obs_vals['observed']
                computed = obs_vals['computed']
                cov_residuals = self._coverage_residuals.setdefault(cov_tree_path, Residuals())

                residual = observed - computed
                if obs_vals['flow']:
                    self._residuals.flow.append(residual)
                    cov_residuals.flow.append(residual)
                else:
                    self._residuals.dep_var.append(residual)
                    cov_residuals.dep_var.append(residual)

                weight = obs_vals['weight']
                if weight is not None:
                    # Weight each term before subtracting (not residual * weight) to keep the values unchanged
                    wt_res = observed * weight - computed * weight
                    self._residuals.weighted.append(wt_res)
                    cov_residuals.weighted.append(wt_res)

    def _set_stats(self, stats_type: str, mean_mea_rms: tuple[float, float, float], sswr: float | None = None) -> None:
        """Add stats to self._stats dict.

        Args:
            stats_type: dep_var_name, 'flow', 'weighted' etc.
            mean_mea_rms: Tuple of the three statistics.
            sswr: Sum of squared weighted residual, used with weighted.
        """
        mean, mae, rms = mean_mea_rms[0], mean_mea_rms[1], mean_mea_rms[2]
        self._stats[stats_type] = {'mean': mean, 'mae': mae, 'rms': rms, 'sswr': sswr}


def _model_output_dir(pest_obs_dir: str) -> str:
    """Return the model output directory.

    Args:
        pest_obs_dir: Path to directory containing pest files.

    Returns:
        See description.
    """
    base, _, _ = pest_obs_dir.rpartition('_pest')
    return base + '_output'


def _compute_mean_mea_rms(residuals: list[float]) -> tuple[float, float, float]:
    """Summarize residuals as (mean, mean absolute error, root mean squared error).

    The three values come from util.mean, util.mae, and util.rms respectively.

    Args:
        residuals: List of residuals.

    Returns:
        tuple[float, float, float] in the order mean, mae, rms.
    """
    mean_residual = util.mean(residuals)
    mean_abs_residual = util.mae(residuals)
    rms_residual = util.rms(residuals)
    return mean_residual, mean_abs_residual, rms_residual


def _compute_sum_of_squared_weighted_residual(wt_residuals: list[float]) -> float:
    """Compute and return the sum of squared weighted residual.

    Args:
        wt_residuals: List of the weighted residuals.

    Returns:
        See description.
    """
    ssr = 0.0
    for res in wt_residuals:
        ssr += res * res
    return ssr


def _indent(indent_level: int, count: int) -> str:
    """Return the whitespace for the given indentation.

    Args:
        indent_level: Base number of indentation levels.
        count: Extra indentation levels added to the base.

    Returns:
        INDENT repeated (indent_level + count) times.
    """
    total_levels = indent_level + count
    return total_levels * INDENT
