"""Testing utility functions."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import filecmp
import glob
import json
import os
from pathlib import Path
import shutil
from typing import List, Optional, Union

# 2. Third party modules
import numpy
import pytest

# 3. Aquaveo modules
from xms.constraint import read_grid_from_file
from xms.datasets.dataset_reader import DatasetReader
from xms.gdal.utilities import compare_tif_files
from xms.tool_core import Argument, Tool, ToolError, ValidateError

# 4. Local modules
from xms.tool.utilities.file_utils import get_test_files_path
from xms.tool.utilities.h5_file_to_text import write_h5_as_text_file


def arguments_log_string(arguments):
    """Build a one-line summary of the names and text values of a tool's arguments.

    Args:
        arguments (list): Tool arguments; each must expose ``name`` and ``text_value``.

    Returns:
        (string): Text of the form ``Input parameters: {'name': value, ...}``.

    """
    pairs = ', '.join(f"'{arg.name}': {arg.text_value}" for arg in arguments)
    return f'Input parameters: {{{pairs}}}'


def run_dataset_test(tool, arguments, outputs, test_folder, expected_output,
                     expected_validate_error=None):
    """Run a dataset tool and check its output text and output dataset files.

    When an expected validate error is supplied, only that error is checked. Otherwise the tool is run,
    its captured output must equal ``expected_output``, and every output dataset file is compared with
    its ``<prefix>_base.h5`` baseline in the test folder.

    Args:
        tool (Tool): The tool under test.
        arguments (list): The tool arguments.
        outputs (list(int)): Indices into ``arguments`` of each output dataset argument. Each of those
            argument values must contain ``'_out'``.
        test_folder (str): Folder, relative to the test files path, holding the baselines and outputs.
        expected_output (str): Output the tool is expected to produce on success.
        expected_validate_error (str): Validate error the tool is expected to raise, if any.
    """
    tool.echo_output = False

    if expected_validate_error is None:
        tool.run_tool(arguments)
        actual_output = tool.get_testing_output()
        assert expected_output == actual_output

        # Each output dataset 'X_out' is compared with the baseline file 'X_base.h5'.
        for arg_index in outputs:
            dataset_name = arguments[arg_index].value
            baseline_prefix = dataset_name.replace('_out', '')
            assert baseline_prefix != dataset_name
            folder = os.path.join(get_test_files_path(), test_folder)
            assert_dataset_files_equal(os.path.join(folder, f'{baseline_prefix}_base.h5'),
                                       os.path.join(folder, f'{dataset_name}.h5'))
    else:
        assert_tool_validate_error(tool, arguments, expected_validate_error)


def run_raster_test(tool, arguments, outputs, expected_output=None,
                    expected_validate_error=None, expected_tool_error=None, expected_output_end=None,
                    test_exists_only=False, raster_extension: str = '.tif'):
    """Test raster tool.

    Args:
        tool (Tool): Instance of the raster Tool being tested.
        arguments (list): The tool arguments.
        outputs (list(int)): List of indices in to arguments for each output raster.
        expected_output (str): Expected output to compare with test display output.
        expected_validate_error (str): Expected error string to compare with test error output.
        expected_tool_error (str): Expected error string to compare with test error output.
        expected_output_end (str): Expected ending output to compare with test display output.
        test_exists_only (bool): Only check whether the raster exists, don't actually check the output.
        raster_extension (str): The raster extensions
    """
    tool.echo_output = False
    if expected_validate_error is not None:
        assert_tool_validate_error(tool, arguments, expected_validate_error)
        return
    if expected_tool_error is not None:
        assert_tool_error(tool, arguments, expected_tool_error)
        return

    tool.run_tool(arguments)

    if expected_output is not None:
        output = tool.get_testing_output()
        expected_output = _update_ignored_text(expected_output, output)
        if expected_output_end is None:
            assert expected_output == output
        else:
            expected_output_end = _update_ignored_text(expected_output_end, output)
            assert expected_output in output
            assert expected_output_end in output

    if test_exists_only:
        check_rasters_exist(tool, arguments, outputs)
    else:
        compare_raster_outputs(tool, arguments, outputs, raster_extension)


def _get_output_rasters(tool, arguments, outputs, raster_extension: str = '.tif'):
    """Gets a list of output rasters and their baseline rasters.

    Args:
        tool (Tool): Instance of the raster Tool being tested.
        arguments (list): The tool arguments.
        outputs (list(int)): List of indices in to arguments for each output raster.
        raster_extension (str): The raster file extension

    Returns:
        (list): List of tuples containing the output rasters and their baselines.
    """
    rasters = []
    for idx in outputs:
        out_path = arguments[idx].value
        if not os.path.isfile(out_path):
            out_path = tool.get_output_raster(arguments[idx].value, raster_extension)
        base_path = out_path.replace(f'_out{raster_extension}', f'_base{raster_extension}')
        out_basename = os.path.basename(out_path)
        assert out_basename != os.path.basename(base_path)
        rasters.append((base_path, out_path))
    return rasters


def compare_raster_outputs(tool, arguments, outputs, raster_extension: str = '.tif'):
    """Compare the raster outputs.

    Args:
        tool (Tool): Instance of the raster Tool being tested.
        arguments (list): The tool arguments.
        outputs (list(int)): List of indices in to arguments for each output raster.
        raster_extension (str): The raster file extensions
    """
    rasters = _get_output_rasters(tool, arguments, outputs, raster_extension)
    for raster in rasters:
        return_code, output = compare_tif_files(raster[0], raster[1])
        assert 'Differences Found: 0' == output
        assert 0 == return_code


def check_rasters_exist(tool, arguments, outputs):
    """Compare the raster outputs.

    Args:
        tool (Tool): Instance of the raster Tool being tested.
        arguments (list): The tool arguments.
        outputs (list(int)): List of indices in to arguments for each output raster.
    """
    rasters = _get_output_rasters(tool, arguments, outputs)
    for raster in rasters:
        assert os.path.isfile(raster[1])


def get_grid_uuids(folder_path):
    """Map the uuid of every readable grid file under a folder to that file's path.

    Every ``.xmc`` file found recursively under ``folder_path`` is read. Files for which the reader
    returns None are skipped.

    Args:
        folder_path (str): The folder to search.

    Returns:
        (dict): Grid uuid -> path of the ``.xmc`` file it was read from.

    """
    pattern = folder_path + '/**/*.xmc'
    found = {}
    for xmc_path in glob.glob(pattern, recursive=True):
        ugrid = read_grid_from_file(xmc_path)
        if ugrid is None:
            continue
        found[ugrid.uuid] = xmc_path
    return found


def build_dataset_meta_dict(reader: DatasetReader, ignored_keys: Union[str, list[str]]) -> dict[str, object]:
    """Collect a dataset's metadata attributes into a plain dictionary.

    Args:
        reader: Reader for the dataset whose metadata is collected.
        ignored_keys: A single key, or a list of keys, to leave out of the result.

    Returns:
        (dict): JSON compatible metadata; ``ref_time`` is stored as its string form.
    """
    metadata = {}
    metadata['uuid'] = reader.uuid
    metadata['geom_uuid'] = reader.geom_uuid
    metadata['ref_time'] = str(reader.ref_time)
    metadata['null_value'] = reader.null_value
    metadata['time_units'] = reader.time_units
    metadata['num_activity_values'] = reader.num_activity_values
    metadata['num_components'] = reader.num_components
    metadata['num_values'] = reader.num_values

    # A lone string is one key, not a sequence of characters.
    excluded = {ignored_keys} if isinstance(ignored_keys, str) else set(ignored_keys)
    return {key: value for key, value in metadata.items() if key not in excluded}


def build_dataset_value_dict(reader: DatasetReader, ignored_keys: Union[str, list[str]]):
    """Read a dataset's value arrays into a dictionary.

    Args:
        reader: Reader for the dataset whose values are read.
        ignored_keys: A single key, or a list of keys, to leave out of the result.

    Returns:
        (dict): The 'values', 'times', 'mins', 'maxs' and 'active' data. 'active' is None when
        indexing the reader's activity raises TypeError (e.g. when there is no activity data).
    """
    try:
        activity = reader.activity[()]
    except TypeError:
        activity = None

    values = {}
    values['values'] = reader.values[()]
    values['times'] = reader.times[()]
    values['mins'] = reader.mins[()]
    values['maxs'] = reader.maxs[()]
    values['active'] = activity

    excluded = {ignored_keys} if isinstance(ignored_keys, str) else set(ignored_keys)
    return {key: value for key, value in values.items() if key not in excluded}


def dataset_to_dict(file: str, ignored_keys: Union[str, list[str]],
                    dataset_name: Optional[str] = None):
    """Read a dataset file and merge its values and metadata into one dictionary.

    Args:
        file: The dataset file. When ``dataset_name`` is omitted, the file's base name (without
            extension) is used as the dataset name.
        ignored_keys: A single key, or a list of keys, to leave out of the result.
        dataset_name: The dataset name.

    Returns:
        (dict): Value data updated with the dataset's metadata.
    """
    reader = get_dataset_reader(file, dataset_name)
    combined = build_dataset_value_dict(reader, ignored_keys)
    combined.update(build_dataset_meta_dict(reader, ignored_keys))
    return combined


def get_dataset_reader(file: str, dataset_name: Optional[str] = None,
                       group_path: Optional[str] = None) -> DatasetReader:
    """Open a reader for one dataset inside a dataset file.

    Args:
        file: The dataset file.
        dataset_name: The dataset name. When neither this nor ``group_path`` is given, the file's
            base name without its extension is used.
        group_path: The group path to the dataset.

    Returns:
        A dataset reader.
    """
    if dataset_name is None and group_path is None:
        dataset_name = os.path.splitext(os.path.basename(file))[0]
    return DatasetReader(file, dset_name=dataset_name, group_path=group_path)


def dump_dataset_json(dataset_file: str, ignored_keys: Union[str, List[str]], dataset_name: Optional[str] = None,
                      output_file: Optional[str] = None) -> None:
    """Export a dataset in an H5 file to JSON.

    Args:
        dataset_file: The path to the H5 dataset file.
        ignored_keys: The dictionary keys to remove from the JSON dictionary.
        dataset_name: Optional dataset name. If not specified uses the base name of the file.
        output_file: Optional path to the output file. Defaults to ``dataset_file`` with a ``.json`` extension.
    """
    if output_file is None:
        output_file = f'{os.path.splitext(dataset_file)[0]}.json'
    dataset_json = dataset_to_dict(dataset_file, ignored_keys, dataset_name=dataset_name)
    # json cannot serialize numpy arrays or most numpy scalars (e.g. float32, int64); tolist() turns
    # both into plain Python lists / scalars.
    serializable = {
        key: value.tolist() if isinstance(value, (numpy.ndarray, numpy.generic)) else value
        for key, value in dataset_json.items()
    }
    with open(output_file, 'w') as json_file:
        json_file.write(json.dumps(serializable, indent=4, sort_keys=True))
        json_file.write('\n')


def _update_ignored_text(expected, output):
    """Changes expected ignored lines to match output.

    Args:
        expected (str): Base output that contains "$IGNORE: ...$" to mark ignored lines.
        output (str): Output text to remove matching lines from.

    Returns (str):
        Updated expected string with ignored sections replaced.
    """
    expected_lines = expected.splitlines(keepends=True)
    output_lines = output.splitlines(keepends=True)
    if len(expected_lines) != len(output_lines):
        return expected

    new_expected = ''
    for expected_line, output_line in zip(expected_lines, output_lines):
        if expected_line.startswith('$IGNORE_LINE$'):
            new_expected += output_line
        else:
            new_expected += expected_line
    return new_expected


def compare_ugrids(base, out, tol):
    """Compares the ugrids. Uses a tolerance for the point locations.

    The cellstreams must match exactly. The point count must match. Each point's x, y and z must
    differ by no more than ``tol``. A message is printed describing the first mismatch found.

    Args:
        base (UGrid): base grid
        out (UGrid): output grid
        tol (float): tolerance for point locations

    Returns:
        (bool): True if the grids match
    """
    # array_equal works for tuples, lists and numpy arrays alike; a bare '!=' on numpy arrays would
    # raise "truth value of an array is ambiguous".
    if not numpy.array_equal(base.cellstream, out.cellstream):
        print('Cells do not match.\n')
        return False
    base_loc = base.locations
    out_loc = out.locations
    if len(base_loc) != len(out_loc):
        print('Number of UGrid points does not match.\n')
        return False
    for base_point, out_point in zip(base_loc, out_loc):
        if any(abs(base_point[axis] - out_point[axis]) > tol for axis in range(3)):
            print('Grid locations do not match.\n')
            return False
    return True


def clean_directory(folder, prefix):
    """
    Delete the files directly inside a folder whose names begin with a prefix.

    Args:
        folder(str): Absolute path of the folder to clean.
        prefix(str): Name prefix selecting which files are deleted; all other entries are kept.
    """
    doomed = [name for name in os.listdir(folder) if name.startswith(prefix)]
    for name in doomed:
        os.remove(os.path.join(folder, name))


def directories_equal(left_dir, right_dir):
    """Check whether two directories hold the same file names with byte-identical contents.

    Args:
        left_dir(str): First directory to compare.
        right_dir(str): Second directory to compare.

    Returns:
        True when both directories list the same names and each pair of files compares equal with
        ``files_equal``; False otherwise.
    """
    content_check = files_equal
    return _directories_equal(left_dir, right_dir, compare=content_check)


def recreate_directory(directory):
    """Leave a directory existing and empty by deleting it (if present) and creating it again.

    Args:
        directory(str): Directory to recreate; missing parent directories are created too.
    """
    remove_directory(directory)  # silently tolerates a directory that does not exist yet
    os.makedirs(directory)


def remove_directory(directory: str):
    """Delete a directory tree, treating an already-missing directory as success.

    Args:
        directory: The path to the directory to delete.
    """
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        # Nothing to delete: the directory is already gone, which is what the caller wants.
        return


def remove_file(file):
    """Delete a file, treating an already-missing file as success.

    Args:
        file(str): Path of the file to delete.
    """
    try:
        os.remove(file)
    except FileNotFoundError:
        return  # already absent; nothing more to do


def _directories_equal(left_dir, right_dir, compare):
    """Compare two directories.

    Args:
        left_dir(str): First directory to compare.
        right_dir(str): Second directory to compare.
        compare(function): Function used to compare files. Should take two file names and return whether to
                           consider the files equal.

    Returns:
        True if directories' contents are same, False otherwise.
    """
    left_files = os.listdir(left_dir)
    left_files.sort()
    right_files = os.listdir(right_dir)
    right_files.sort()

    if left_files != right_files:
        return False

    for file in left_files:
        left = os.path.join(left_dir, file)
        right = os.path.join(right_dir, file)
        if not compare(left, right):
            return False

    return True


def files_equal(first: os.PathLike | str, second: os.PathLike | str) -> bool:
    """
    Compare two files and see if they have the same content.

    Similar to filecmp.cmp(), except it defaults to a deep comparison rather than trusting file metadata.

    Args:
        first: First file to compare.
        second: Second file to compare.

    Returns:
        Whether the two files have the same content.
    """
    return filecmp.cmp(first, second, shallow=False)


def text_files_equal(first: os.PathLike | str, second: os.PathLike | str) -> bool:
    """
    Compare two text files and see if they have the same content.

    Checks each line to see if lines are the same, ignoring newline characters.

    Args:
        first: First file to compare.
        second: Second file to compare.

    Returns:
        Whether the two files have the same content.
    """
    l1 = l2 = True
    with open(first, 'r') as f1, open(second, 'r') as f2:
        while l1 and l2:
            l1 = f1.readline()
            l2 = f2.readline()
            if l1 != l2:
                return False
    return True


def files_equal_ignore_newlines(first, second):
    """
    Compare two text files as lists of lines, disregarding line-ending characters.

    Args:
        first: First file to compare.
        second: Second file to compare.

    Returns:
        True if both files split into the same sequence of lines.
    """
    with open(first, 'r', newline=None) as left, open(second, 'r', newline=None) as right:
        left_lines = left.read().splitlines()
        right_lines = right.read().splitlines()
    return left_lines == right_lines


def assert_dataset_files_equal(base_file: str, out_file: str, allow_close: bool = False,
                               dataset_name: Optional[str] = None, group_path: Optional[str] = None,
                               base_dataset_name: Optional[str] = None) -> None:
    """Assert two dataset files are equal for testing.

    Metadata (minus uuid) must match exactly. Values, times, mins, maxs and activity must be equal, or
    close when ``allow_close`` is True. On failure, the compared data is written to JSON files next to
    each input file, and their paths are reported in the assertion message.

    Args:
        base_file: The path to the base dataset file.
        out_file: The path to the output dataset file.
        allow_close: Should dataset values be close or all equal?
        dataset_name: The name of the dataset. Defaults to the base name of ``out_file`` when neither it
            nor ``group_path`` is given.
        group_path: The group path to the dataset.
        base_dataset_name: An alternate base dataset name to compare with the target dataset.
    """
    if group_path is None and dataset_name is None:
        dataset_name = os.path.splitext(os.path.basename(out_file))[0]
    if base_dataset_name is None:
        base_dataset_name = dataset_name
    base_reader = get_dataset_reader(base_file, dataset_name=base_dataset_name, group_path=group_path)
    out_reader = get_dataset_reader(out_file, dataset_name=dataset_name, group_path=group_path)

    base_metadata = build_dataset_meta_dict(base_reader, 'uuid')
    out_metadata = build_dataset_meta_dict(out_reader, 'uuid')
    same_metadata = base_metadata == out_metadata

    base_values = build_dataset_value_dict(base_reader, 'uuid')
    out_values = build_dataset_value_dict(out_reader, 'uuid')
    same_values = dataset_values_equal(base_values, out_values, allow_close)

    base_json_file = os.path.splitext(base_file)[0] + '_out.json'
    out_json_file = os.path.splitext(out_file)[0] + '.json'
    if not (same_metadata and same_values):  # pragma no cover - only hit on failure
        # Dump exactly what was compared. Re-reading the files would lose base_dataset_name/group_path
        # and could raise, hiding the real assertion below.
        _write_dataset_json_file(base_json_file, {**base_values, **base_metadata})
        _write_dataset_json_file(out_json_file, {**out_values, **out_metadata})
    assert same_metadata and same_values, f'\nfiles differ:\n  {base_json_file}\n  {out_json_file}'


def _write_dataset_json_file(json_file: str, dataset_dict: dict) -> None:  # pragma no cover - only hit on failure
    """Write a dataset dictionary to a JSON file, converting numpy arrays and scalars to Python values.

    Args:
        json_file: Path of the JSON file to write.
        dataset_dict: The dataset values and metadata to write.
    """
    serializable = {
        key: value.tolist() if isinstance(value, (numpy.ndarray, numpy.generic)) else value
        for key, value in dataset_dict.items()
    }
    with open(json_file, 'w') as file:
        file.write(json.dumps(serializable, indent=4, sort_keys=True))
        file.write('\n')


def dataset_values_equal(base_json: dict, out_json: dict, allow_close: bool) -> bool:
    """Determine if dataset values in JSON are equal.

    For each of 'values', 'times', 'mins', 'maxs' and 'active': both None counts as equal, exactly one None
    as different. Otherwise the arrays must have the same shape and be equal (or close, when
    ``allow_close``). NaNs compare equal to NaNs.

    Args:
        base_json: Dictionary of base values (containing 'values', 'times', 'mins', 'maxs', 'active' numpy arrays).
        out_json: Dictionary of out values.
        allow_close: Should dataset values be close or all equal?

    Returns:
        True if values are equal.
    """
    for value_type in ('values', 'times', 'mins', 'maxs', 'active'):
        out = out_json[value_type]
        base = base_json[value_type]
        if out is None and base is None:
            continue
        if out is None or base is None:
            # Only one side has data (e.g. only one dataset has activity), so they differ.
            return False
        # Check shapes explicitly: allclose would otherwise broadcast mismatched shapes.
        if numpy.shape(out) != numpy.shape(base):
            return False
        if allow_close:
            if not numpy.allclose(out, base, equal_nan=True):
                return False
        elif not numpy.array_equal(out, base, equal_nan=True):
            return False
    return True


def compare_h5_files(base_file: str, out_file: str, to_exclude: Optional[list[str]] = None,
                     allow_close: bool = False) -> str:
    """Compare two H5 files for testing.

    Both files are written out as text (``<name>.h5diff.txt``), and the text files are compared ignoring
    newline differences. If they differ and ``allow_close`` is True, the items recorded by the text
    writers are compared individually. Floating point data then only needs to be close, with NaN equal
    to NaN.

    Args:
        base_file: The path to the base H5 file.
        out_file: The path to the output H5 file.
        to_exclude: Optional list of dataset or attribute names to exclude from printing.
        allow_close: Allow floating point values to be close but not equal.

    Returns:
        An empty string if the files match, otherwise a message describing the difference.
    """
    if to_exclude is None:
        to_exclude = []
    base_txt_file = os.path.splitext(base_file)[0] + '.h5diff.txt'
    out_txt_file = os.path.splitext(out_file)[0] + '.h5diff.txt'
    base_writer = write_h5_as_text_file(base_file, base_txt_file, to_exclude=to_exclude)
    out_writer = write_h5_as_text_file(out_file, out_txt_file, to_exclude=to_exclude)
    files_header = f'files differ:\n  {base_txt_file}\n  {out_txt_file}'
    if files_equal_ignore_newlines(base_txt_file, out_txt_file):
        return ''
    if not allow_close:
        return files_header
    return _compare_h5_items_close(base_writer.compare_items, out_writer.compare_items, files_header)


def _compare_h5_items_close(base_items: dict, out_items: dict, files_header: str) -> str:
    """Compare recorded H5 items, allowing floating point data to be close instead of equal.

    Args:
        base_items: Item name -> data recorded for the base file.
        out_items: Item name -> data recorded for the output file.
        files_header: Message prefix naming the two text files.

    Returns:
        An empty string if all items match, otherwise a message describing the first mismatch.
    """
    base_item_names = sorted(base_items)
    if base_item_names != sorted(out_items):
        return f'{files_header}\ngroups, datasets, and/or attributes differ'
    for name in base_item_names:
        base_data = base_items[name]
        out_data = out_items[name]
        if _is_float_data(base_data) and _is_float_data(out_data):
            # allclose broadcasts, so mismatched shapes must be rejected explicitly.
            same_shape = numpy.shape(base_data) == numpy.shape(out_data)
            if not (same_shape and numpy.allclose(base_data, out_data, equal_nan=True)):
                return f'{files_header}\n{name} not close'
        elif not _h5_items_equal(base_data, out_data):  # pragma no cover - only hit for differing groups which should fail above
            return f'{files_header}\n{name} not equal'
    return ''


def _is_float_data(value) -> bool:
    """Return whether value is floating point data (a float scalar or a float numpy array)."""
    return isinstance(value, (float, numpy.floating, numpy.ndarray)) and numpy.asarray(value).dtype.kind == 'f'


def _h5_items_equal(base_data, out_data) -> bool:
    """Return whether two recorded items are exactly equal, handling numpy arrays of any dtype."""
    if isinstance(base_data, numpy.ndarray) or isinstance(out_data, numpy.ndarray):
        # '!=' on arrays yields an array whose truth value is ambiguous; compare element-wise instead.
        return bool(numpy.array_equal(base_data, out_data))
    return bool(base_data == out_data)


def assert_coverage_files_equal(base_file: str, out_file: str, allow_close: bool = False) -> None:
    """Fail the current test unless two coverage H5 files match, ignoring their GUID entries.

    Args:
        base_file: Path of the baseline coverage file.
        out_file: Path of the coverage file produced by the test.
        allow_close: When True, floating point data only needs to be close rather than identical.
    """
    difference = compare_h5_files(base_file, out_file, to_exclude=['GUID'], allow_close=allow_close)
    if difference:  # pragma no cover - only hit when test fails
        pytest.fail(difference)


def assert_tool_validate_error(tool: Tool, arguments: list[Argument], expected_error: str):
    """Run a tool and require it to raise a ValidateError with a specific message.

    Args:
        tool: The tool to run.
        arguments: The arguments to run it with.
        expected_error: The expected message; trailing whitespace is ignored.
    """
    with pytest.raises(ValidateError) as excinfo:
        tool.run_tool(arguments)
    expected_message = expected_error.rstrip()
    assert str(excinfo.value) == expected_message


def assert_tool_error(tool: Tool, arguments: list[Argument], expected_error: str):
    """Run a tool, require a ToolError, and optionally check the captured output for a message.

    Args:
        tool: The tool to run.
        arguments: The arguments to run it with.
        expected_error: Text that must appear in the tool's testing output; an empty value skips the check.
    """
    with pytest.raises(ToolError):
        tool.run_tool(arguments)
    # Always fetch the output, even when no message check is requested.
    captured = tool.get_testing_output()
    if not expected_error:
        return
    assert expected_error in captured


def get_folder_files(folder, extension) -> list[str]:
    """List, in sorted order, the files with a given extension anywhere under a folder.

    Args:
        folder: The root folder to search recursively.
        extension: File extension to match, without the leading dot.

    Returns:
        Paths relative to ``folder``, using forward slashes.

    Raises:
        RuntimeError: If ``folder`` is not an existing directory.
    """
    root = Path(folder)
    if not root.is_dir():
        raise RuntimeError(f'Attempting to get folder files from non-existent path: {root}')
    return sorted(match.relative_to(folder).as_posix() for match in root.glob(f'**/*.{extension}'))


def get_project_files(folder) -> list[str]:
    """List the grid, dataset, raster and uuid files of a project folder, in sorted order.

    Args:
        folder: The project's root folder.

    Returns:
        Sorted relative paths of every ``.xmc``, ``.h5``, ``.tif`` and ``.uuid`` file under ``folder``.
    """
    collected = []
    for extension in ('xmc', 'h5', 'tif', 'uuid'):
        collected.extend(get_folder_files(folder, extension))
    return sorted(collected)


def assert_project_folders_equal(base_folder, out_folder):
    """Assert that two project folders hold the same files with equivalent contents.

    Each file is compared according to its location and type:
    - ``.h5`` files under ``grids`` are compared as datasets.
    - ``.h5`` files under ``coverages`` are compared as coverages.
    - ``.tif`` files under ``rasters`` are compared as rasters.
    - All other files are compared byte for byte.

    Args:
        base_folder: The baseline project folder.
        out_folder: The project folder produced by the test.
    """
    base_root = Path(base_folder)
    out_root = Path(out_folder)
    base_files = get_project_files(base_folder)
    out_files = get_project_files(out_folder)
    assert out_files == base_files  # both folders must list the same relative paths

    for relative in base_files:
        rel_path = Path(relative)
        base_file = str(base_root / relative)
        out_file = str(out_root / relative)
        is_h5 = rel_path.suffix == '.h5'
        if relative.startswith('grids') and is_h5:
            # The dataset group path is the file path minus its first two parts and its extension.
            dataset_path = rel_path.parent / rel_path.stem
            assert_dataset_files_equal(base_file, out_file, dataset_name='/'.join(dataset_path.parts[2:]))
        elif relative.startswith('coverages') and is_h5:
            assert_coverage_files_equal(base_file, out_file)
        elif relative.startswith('rasters') and rel_path.suffix == '.tif':
            assert compare_tif_files(base_file, out_file) == (0, 'Differences Found: 0')
        else:
            assert files_equal(base_file, out_file)


def files_equal_within_tolerance(file1: str, file2: str, tolerance: float = 0.95) -> tuple[bool, float]:
    """Compare two text files line by line and check if they are equal within a given tolerance.

    Lines are stripped of surrounding whitespace and compared position by position. Extra lines in the
    longer file count as mismatches. Two empty files are considered identical.

    Args:
        file1 (str): Path to the first file.
        file2 (str): Path to the second file.
        tolerance (float): Fraction (0–1) of lines that must be equal to be considered "equal".

    Returns:
        (bool, float): (are_equal_within_tolerance, fraction_equal)
    """
    with open(file1, 'r', encoding='utf-8') as f1, open(file2, 'r', encoding='utf-8') as f2:
        lines1 = [line.strip() for line in f1]
        lines2 = [line.strip() for line in f2]

    total = max(len(lines1), len(lines2))
    if total == 0:
        # Both files are empty: identical, and dividing by the line count would raise ZeroDivisionError.
        return True, 1.0
    equal_count = sum(1 for line1, line2 in zip(lines1, lines2) if line1 == line2)
    fraction_equal = equal_count / total
    return fraction_equal >= tolerance, fraction_equal
