"""This module contains functions for loading an AdH solution."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import glob
import os
import re
import shutil
import sys
import uuid

# 2. Third party modules
from adhparam import dat_reader, file_io, mesh
import pandas as pd

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules


class SolutionReader:
    """Class that handles reading an AdH solution."""
    def __init__(self, _dummy_main_file=''):
        """Initializes the class.

        Args:
            _dummy_main_file (str): Unused, but here to keep constructor consistent with component classes.
        """
        super().__init__()
        self.sim_name = None  # Simulation name; used as the base name of the solution files
        self.file_location = ''  # Directory that holds the solution files
        self.geom_uuid = ''  # Passed to DatasetWriter as the geometry UUID of every dataset
        self.xms_temp_dir = ''  # Directory where the generated H5 file is written (when not testing)

    def read_solution(self, query, params, _win_cont):
        """Reads the AdH Solution.

        Args:
            query (:obj:`xms.data_objects.parameters.Query`): a Query object to communicate with GMS.
            params (:obj:`dict`): Generic map of parameters. Contains the structures for various components that
             are required for adding vertices to the Query Context with Add().
            _win_cont (QWidget): The parent window

        Returns:
            (:obj:`tuple`): tuple containing:
                - messages (:obj:`list` of :obj:`tuple` of :obj:`str`): List of tuples with the first element of the
                  tuple being the message level (DEBUG, ERROR, WARNING, INFO) and the second element being the message
                  text.
                - action_requests (:obj:`list` of :obj:`xms.api.dmi.ActionRequest`): List of actions for XMS to perform.
        """
        # Only the first entry of params is read.
        sim_params = params[0]
        self.sim_name = sim_params['simulation_name']
        self.file_location = sim_params['file_location']
        self.geom_uuid = sim_params['geom_uuid']
        self.xms_temp_dir = sim_params['temp_dir']
        return self.read_solution_file(query)

    def _get_h5_filename(self):
        """Get the H5 filename to write the datasets to.

        Returns:
            (str): A hardcoded 'test_output.h5' next to the solution files when tests are running, otherwise a
            uniquely named file in the XMS temp directory.
        """
        if XmEnv.xms_environ_running_tests() == 'TRUE':
            return os.path.join(self.file_location, 'test_output.h5')
        return os.path.join(self.xms_temp_dir, f'{uuid.uuid4()}.h5')

    @staticmethod
    def _get_time_units(data):
        """Get the time units of a parsed dat file.

        Args:
            data: The object returned by dat_reader.parse_dat_file().

        Returns:
            The value of the 'TIMEUNITS' attribute, or 'SECONDS' if the attribute is not present.
        """
        return data.attrs.get('TIMEUNITS', 'SECONDS')

    def read_solution_file(self, query):
        """
        Reads the AdH Solution.

        Args:
            query (:obj:`xms.data_objects.parameters.Query`): a Query object to communicate with GMS.

        Returns:
            (:obj:`tuple`): tuple containing:
                - messages (:obj:`list` of :obj:`tuple` of :obj:`str`): List of tuples with the first element of the
                  tuple being the message level (DEBUG, ERROR, WARNING, INFO) and the second element being the message
                  text.
                - action_requests (:obj:`list` of :obj:`xms.api.dmi.ActionRequest`): List of actions for XMS to perform.
        """
        messages = []
        path_and_file = os.path.join(self.file_location, self.sim_name)

        # Step 1: Retrieve dataset filenames
        scalar_files, vector_files = self._get_dat_filenames(path_and_file)

        # Step 2: Read mesh file
        all_elems = _read_mesh_file(path_and_file)

        # Step 3: Create an HDF5 file from the dat file
        ds_file_name, data, dset_name, values, node_activity = self._process_dat_file(scalar_files['depth'])

        # Step 4: Create element activity
        elem_activity = self._create_element_activity(all_elems, node_activity)

        # Step 5: Determine time units
        time_units = self._get_time_units(data)

        # Step 6: Add datasets
        dset_list = []
        self._add_datasets(
            dset_name, values, ds_file_name, dset_list, elem_activity, time_units, scalar_files, vector_files
        )

        # Step 7: Send the datasets
        self._send_datasets(query, dset_list)

        # Step 8: Process nodal output
        prepare_and_process_nodal_output(self.file_location, self.sim_name)

        return messages, []

    def _process_dat_file(self, dat_file):
        """
        Process the depth dat file, which also determines the node activity.

        Args:
            dat_file: The .dat file to be processed.

        Returns:
            Tuple containing:
                - The HDF5 file name.
                - The parsed data from the .dat file.
                - The dataset name extracted from the data.
                - A dictionary mapping times to data values.
                - A list (one entry per timestep) of lists of booleans; a node is active when its value is >= 0.0.
        """
        ds_file_name = self._get_h5_filename()
        data = dat_reader.parse_dat_file(dat_file, self.sim_name)
        dset_name = str(data.name)

        if 'times' in data.coords:
            do_data = data.data[:, :].tolist()
            # NOTE(review): the times are converted to strings here and used as dictionary keys, while the
            # no-times branch uses the float 0.0 - confirm that DatasetWriter accepts both.
            do_times = data.times.astype(str).data.tolist()
            node_activity = [[val >= 0.0 for val in data_values] for data_values in do_data]
            values = {time: do_data[i] for i, time in enumerate(do_times)}
        else:
            node_activity = [[val >= 0.0 for val in data.values[0]]]
            values = {0.0: data.values.tolist()}

        return ds_file_name, data, dset_name, values, node_activity

    def _add_datasets(
        self, dset_name, values, ds_file_name, dset_list, elem_activity, time_units, scalar_files, vector_files
    ):
        """
        Add solution datasets.

        Args:
            dset_name: The name of the dataset to be added.
            values: The values associated with the dataset.
            ds_file_name: The name of the file where the dataset is stored.
            dset_list: A list of datasets to which the new dataset will be added.
            elem_activity: Element activity tracking used in the dataset.
            time_units: Units of time used within the dataset.
            scalar_files: Dictionary containing files with scalar data.
            vector_files: Files containing vector data.
        """
        self._add_dataset(dset_name, values, ds_file_name, dset_list, elem_activity, time_units)
        self._add_vectors(vector_files, ds_file_name, dset_list, elem_activity)
        # Each scalar group is paired with the text that is inserted into multi-column dataset names.
        scalar_groups = (('', ''), ('grain', 'grain'), ('cohesive', 'property'), ('constituent', 'con'))
        for group_key, dset_text in scalar_groups:
            self._add_multi_column_scalars(
                scalar_files.get(group_key), dset_text, ds_file_name, dset_list, elem_activity
            )

    def _send_datasets(self, query, dset_list):
        """
        Send the datasets to XMS.

        Args:
            query: The query object.
            dset_list: A list of dataset objects to be added.
        """
        for dset in dset_list:
            query.add_dataset(dset)

    def _get_dat_filenames(self, path_and_file):
        """Gets the *_.dat files to read.

        Every file matching '<path_and_file>_*.dat' is classified by its file name. The depth file is reported
        separately (whether or not it exists on disk) and never appears in any of the lists.

        Args:
            path_and_file (str): The file path and the simulation name.

        Returns:
            A tuple with the first object being a dictionary of scalar files with keys of 'depth', 'grain', 'cohesive',
            'constituent', and '' for all other scalar datasets. 'depth' maps to a single filename; the other keys
            map to lists of filenames. The second item in the tuple is a list of vector dataset files.
        """
        base_name = os.path.basename(path_and_file)
        re_base = re.escape(base_name)
        # Regular expressions for the cases that have numbers before the .dat.
        re_bld = re.compile(f'{re_base}_bld[0-9]+\\.dat')
        re_cpb = re.compile(f'{re_base}_cpb[0-9]+\\.dat')
        re_con = re.compile(f'{re_base}_con[0-9]+\\.dat')

        depth_name = f'{base_name}_dep.dat'
        vector_names = {
            f'{base_name}_ovl.dat',
            f'{base_name}_waveForces.dat',
            f'{base_name}_bedload.dat',
            f'{base_name}_susload.dat',
        }
        grain_names = {f'{base_name}_ald.dat', f'{base_name}_smr.dat'}

        # Could potentially also read in solution files that are not *.dat formatted.
        # flux_file_name = f'{path_and_file}_tlfx'
        scalar_files = {
            'depth': f'{path_and_file}_dep.dat',
            '': [],
            'grain': [],
            'cohesive': [],
            'constituent': [],
        }
        vector_files = []

        # Escape the fixed part of the pattern so that glob characters in the directory or the simulation name
        # (e.g. '[' or ']') are matched literally instead of being interpreted as wildcards.
        for dset_file in glob.glob(f'{glob.escape(path_and_file)}_*.dat'):
            file_name = os.path.basename(dset_file)
            if file_name == depth_name:
                continue  # The depth dataset is read on its own
            if file_name in vector_names:
                vector_files.append(dset_file)
            elif file_name in grain_names or re_bld.match(file_name):
                scalar_files['grain'].append(dset_file)
            elif re_cpb.match(file_name):
                scalar_files['cohesive'].append(dset_file)
            elif re_con.match(file_name):
                scalar_files['constituent'].append(dset_file)
            else:
                scalar_files[''].append(dset_file)
        return scalar_files, vector_files

    def _add_dataset(self, dset_name, values, ds_file_name, dset_list, elem_activity, time_units, is_vector=False):
        """Adds a dataset to the dataset list.

        Args:
            dset_name (str): The name of the dataset.
            values (dict): A dictionary where the key is the timestep,
                           and the value is a list of values at that timestep.
            ds_file_name (str): The path and filename where this dataset will be stored.
            dset_list (list): The list to add a dataset to.
            elem_activity (list): A list of boolean values for each element in the mesh.
            time_units (str): The time units of the dataset.
            is_vector (bool): True if a vector dataset
        """
        num_components = 2 if is_vector else 1
        dataset = DatasetWriter(
            h5_filename=ds_file_name,
            name=dset_name,
            geom_uuid=self.geom_uuid,
            time_units=time_units.capitalize(),
            num_components=num_components,
            overwrite=False
        )
        dataset.write_xmdf_dataset(times=list(values.keys()), data=list(values.values()), activity=elem_activity)
        dset_list.append(dataset)

    @staticmethod
    def _create_element_activity(all_elems, node_activity):
        """Creates an activity list based on the activity on the nodes.

        This method assumes a 2D mesh with only triangles.

        Args:
            all_elems (Dataframe): Dataframe of the element definitions.
            node_activity (list): A list containing a list of boolean values for each node for each timestep.

        Returns:
            A list (one entry per timestep) of lists of boolean values for each element in the mesh. The element is
            inactive if and only if all nodes that are a part of that element are inactive.
        """
        # The node ids are offset by one to index into the per-node activity lists. The offsets do not depend on
        # the timestep, so compute them once.
        node_indices = [
            (n0 - 1, n1 - 1, n2 - 1)
            for n0, n1, n2 in zip(
                all_elems['NODE_0'].tolist(), all_elems['NODE_1'].tolist(), all_elems['NODE_2'].tolist()
            )
        ]
        return [
            [any((active[i0], active[i1], active[i2])) for i0, i1, i2 in node_indices]
            for active in node_activity
        ]

    def _add_vectors(self, dat_files, ds_file_name, dset_list, elem_activity):
        """Adds vector datasets to the list.

        Args:
            dat_files (iterable): A list of the dat files to read. None is treated as no files.
            ds_file_name (str): The filename of the h5 datasets.
            dset_list (list): A list of xms.data_objects::Datasets to add to.
            elem_activity (list): The activity at each element of the mesh.
        """
        if dat_files is None:
            return
        for dat_file_name in dat_files:
            # create an H5 file from the dat file
            data = dat_reader.parse_dat_file(dat_file_name, self.sim_name)
            dset_name = str(data.name)

            if 'times' in data.coords:
                # Only the first two components of each vector are kept.
                do_data = data.data[:, :, :2].tolist()
                do_times = data.times.astype(str).data.tolist()
                values = {time: do_data[i] for i, time in enumerate(do_times)}
            else:
                values = {0.0: data.values.tolist()}

            time_units = self._get_time_units(data)
            self._add_dataset(dset_name, values, ds_file_name, dset_list, elem_activity, time_units, is_vector=True)

    def _add_multi_column_scalars(self, dat_files, dset_text, ds_file_name, dset_list, elem_activity):
        """Reads the dat files as if each column were an independent scalar dataset.

        Args:
            dat_files (iterable): A list of the dat files to read. None is treated as no files.
            dset_text (str): The string to add to the dataset name.
            ds_file_name (str): The filename of the h5 datasets.
            dset_list (list): A list of xms.data_objects::Datasets to add to.
            elem_activity (list): The activity at each element of the mesh.
        """
        if dat_files is None:
            # dict.get() returns None for a scalar group that was never collected.
            return
        for dat_file_name in dat_files:
            # create an h5 file from the dat files
            data = dat_reader.parse_dat_file(dat_file_name, self.sim_name)
            base_dset_name = str(data.name)

            do_times = None
            if 'times' in data.coords:
                do_times = data.times.astype(str).data.tolist()
            # use the 3rd dimension index (0-based) to figure out how many datasets there are
            num_dsets = data.shape[2] if len(data.shape) == 3 else 1
            time_units = self._get_time_units(data)

            # go through each column, which is its own dataset
            for idx in range(num_dsets):
                if num_dsets > 1:
                    # if there is more than one dataset in the file, then split off the one column
                    do_data = data.data[:, :, idx:idx + 1].tolist()
                    if do_times:
                        values = {time: [val[0] for val in do_data[i]] for i, time in enumerate(do_times)}
                    else:
                        values = {0.0: do_data}
                    dset_name = f'{base_dset_name}_{dset_text}_{idx + 1}'
                else:
                    # just use all the read in values since there is only one dataset
                    if do_times:
                        do_data = data.data[:, :].tolist()
                        values = {time: do_data[i] for i, time in enumerate(do_times)}
                    else:
                        values = {0.0: data.values.tolist()}
                    dset_name = base_dset_name

                # create the Dataset object to send to SMS
                self._add_dataset(dset_name, values, ds_file_name, dset_list, elem_activity, time_units)


def _read_mesh_file(path_and_file):
    """
    Load the simulation's '.3dm' mesh file and return its elements ordered by 'ID'.

    Args:
        path_and_file: Path and base name of the mesh file; the '.3dm' extension is appended here.

    Returns:
        The elements of the mesh that was read, sorted by their 'ID' column.
    """
    adh_mesh = mesh.Mesh()
    # The reader fills in the mesh object that is passed to it.
    file_io.read_mesh_file(f'{path_and_file}.3dm', adh_mesh)
    elements = adh_mesh.elements
    return elements.sort_values('ID')


def prepare_and_process_nodal_output(file_location: str, sim_name: str) -> None:
    """Split the simulation's nodal output file into one file per mapped point.

    The node-to-point mapping is read from the simulation's mapping file, and the per-point files are written to
    the 'Output_MISC' directory inside ``file_location``. The 'TIME' column is renamed to 'Time(minutes)'.

    Args:
        file_location: Location of the simulation files.
        sim_name: The simulation name.
    """
    generate_node_output_files(
        node_to_point_id=read_nodal_mapping(file_location, sim_name),
        input_file_path=os.path.join(file_location, f'{sim_name}_nodal_output'),
        output_dir_path=os.path.join(file_location, "Output_MISC"),
        column_name_map={'TIME': 'Time(minutes)'},
        simulation_name=sim_name,
    )


def read_nodal_mapping(file_location: str, sim_name: str) -> dict:
    """
    Load the '<sim_name>.prn_map' file and build a lookup from node to point ID.

    The file is read as CSV with its first column used as the row index; the 'NODE' and 'POINT_ID' columns
    supply the keys and values of the result.

    Args:
        file_location: The directory where the nodal mapping file is located.
        sim_name: The name of the simulation, used to construct the filename.

    Returns:
        A dictionary keyed by the 'NODE' column with the matching 'POINT_ID' values. An empty dictionary is
        returned when the mapping file does not exist.
    """
    mapping_path = os.path.join(file_location, f'{sim_name}.prn_map')
    if not os.path.isfile(mapping_path):
        return {}

    mapping_table = pd.read_csv(mapping_path, index_col=0)
    point_ids_by_node = mapping_table.set_index('NODE')['POINT_ID']
    return point_ids_by_node.to_dict()


def generate_node_output_files(
    node_to_point_id: dict[str, int],
    input_file_path: str,
    output_dir_path: str,
    column_name_map: dict[str, str] | None = None,
    simulation_name: str | None = None
) -> None:
    """
    Creates output files for each node, with data starting from the 'TIME' column onward.

    Args:
        node_to_point_id: A dictionary mapping nodes to point IDs.
        input_file_path: Path to the input text file.
        output_dir_path: Path to the directory where the output files should be saved.
        column_name_map: A dictionary for renaming the columns.
        simulation_name: A prefix to be added to the output file names.
    """
    # Ensure that the input file exists
    if not os.path.isfile(input_file_path):
        return

    # Create an empty output directory
    if os.path.exists(output_dir_path):
        shutil.rmtree(output_dir_path)
    os.makedirs(output_dir_path, exist_ok=True)

    # Read the input file, including the column names
    with open(input_file_path, 'r') as file:
        column_headers = file.readline().strip().split()

    df = pd.read_csv(input_file_path, sep='\\s+', names=column_headers, skiprows=1)

    # Optionally rename columns using the provided dictionary
    if column_name_map:
        df.rename(columns=column_name_map, inplace=True)

    # Iterate over each node to point ID mapping
    for node_suffix, point_id in node_to_point_id.items():
        node_data = df[df['NODE'].str.endswith(f'-{node_suffix}')]

        if not node_data.empty:
            node_data_to_save = node_data.drop(columns=['NODE'])
            output_file_name = f"PT{point_id}.dat"

            if simulation_name:
                output_file_name = f"{simulation_name}_{output_file_name}"

            output_file_path = os.path.join(output_dir_path, output_file_name)
            node_data_to_save.to_csv(output_file_path, sep=' ', index=False, header=True)


def get_sms_data(xms_query):
    """Get the path and filename of the file that SMS asked to be read.

    A fictitious *.adh_dataset filename comes from SMS query as the primary read file. We can figure out *.dat
    file names from it.

    If the file does not exist, an error is written to stderr and the process exits with a non-zero status so
    that the failure is not reported as a success.

    Args:
        xms_query (:obj:`xms_api.dmi.Query`): The query used to communicate with SMS.

    Returns:
        str: path and filename of the query's read file (the *.adh_dataset file)

    Raises:
        SystemExit: If the read file does not exist.
    """
    # Read file is the fake adh_dataset file in the same location, and with the same basename, as the solution.
    dat_filename = xms_query.read_file

    if not os.path.isfile(dat_filename):
        sys.stderr.write(f"Solution file not found - {dat_filename}")
        sys.exit(1)

    return dat_filename


def get_and_send_data():
    """Gets data about the simulation from XMS, loads its solution, and sends it back to XMS.

    The simulation name and the location of the solution files are derived from the read file of the query:
    the directory of the read file is the file location, and its base name without the '.adh_dataset'
    extension is the simulation name.
    """
    query = Query()
    file_name_and_path = get_sms_data(query)
    file_location, file_name = os.path.split(file_name_and_path)
    # Only strip the trailing extension; str.replace() would also remove '.adh_dataset' from the middle of the
    # simulation name.
    simulation_name = file_name.removesuffix('.adh_dataset')
    load_sim = SolutionReader()
    load_sim.sim_name = simulation_name
    load_sim.file_location = file_location
    # NOTE(review): geom_uuid and xms_temp_dir keep their constructor defaults ('') on this path, so outside of
    # tests the H5 file name is relative to the working directory - confirm this is intended.
    load_sim.read_solution_file(query)
    query.send()
