"""Runs XMS DMI component batch save events."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import argparse
import os
import traceback

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query

# 4. Local modules


def save_all_main(save_type):
    """Entry point for component batch save events.

    Requests the main files of all components through a ``Query``, instantiates each component class and calls its
    ``save_to_location`` method, then sends the collected main files, messages and action requests back through
    the same query. A failure in one component is logged and does not prevent the remaining components from
    being processed.

    Args:
        save_type (str): The save event type

    Raises:
        Exception: If the query returns nothing for the 'all_main_files' request.
    """
    query = Query(timeout=300000)

    all_main_result = query._impl._instance.Get('all_main_files')
    if not all_main_result:
        raise Exception('Query::Get failed on main files')

    # The payload is the first element of the 'all_main_files' entry; fall back to an empty mapping when absent.
    all_main = {}
    if all_main_result['all_main_files'] and all_main_result['all_main_files'][0]:
        all_main = all_main_result['all_main_files'][0]

    error_uuids = []
    # Captured before any component code runs so the log files below always use the same directory
    # (presumably because component code may change the working directory - confirm).
    start_dir = os.getcwd()
    all_action_requests = []
    all_messages = []
    all_main_files = {}
    # go through all of the components and call the save event on each one
    for idx, save_main in all_main.items():
        # Reset on every iteration. If this entry fails before its path has been unpacked, the error handler
        # must not report the path left over from the previous component.
        new_path = ''
        try:
            # Each entry holds: module name, class name, existing main file, and the location to save to.
            comp_module, comp_class, comp_old_file, new_path = save_main
            mod = __import__(comp_module, fromlist=[comp_class])
            klass = getattr(mod, comp_class)
            class_instance = klass(comp_old_file)

            new_main_file, messages, action_requests = class_instance.save_to_location(new_path, save_type)
            all_main_files[idx] = new_main_file
            all_action_requests.extend(action_requests)
            all_messages.extend(messages)
        except Exception as err:  # Print exception traceback to error log for debugging
            with open(os.path.join(start_dir, f'project_save_error_{os.getpid()}.log'), 'a') as file:
                traceback.print_exception(type(err), err, err.__traceback__, file=file)
            # The basename of the save location is what gets reported as the component's UUID.
            error_uuids.append(os.path.basename(new_path))

    # Write a file with the UUIDs of components that threw exceptions.
    if error_uuids:
        with open(os.path.join(start_dir, f'save_all_error_uuids_{os.getpid()}.log'), 'w') as f:
            for error_uuid in error_uuids:
                f.write(f'{error_uuid}\n')

    # Only send data back when at least one component produced something.
    arg_list = []
    if all_action_requests:
        # Unwrap the pure Python ActionRequests
        arg_list.append({"actions": [action._instance for action in all_action_requests]})
    if all_messages:
        arg_list.append({"messages": all_messages})
    if all_main_files:
        arg_list.append({"all_main_files": all_main_files})
    if arg_list:
        query._impl._instance.Set(arg_list)
        query.send(True)


if __name__ == "__main__":
    # Two positional arguments are required on the command line; only save_type is passed on to the save routine.
    parser = argparse.ArgumentParser(description="Component method runner.")
    parser.add_argument('script', type=str, help='script to run')
    parser.add_argument('save_type', type=str, help='one of: PACKAGE, SAVE, SAVE_AS')
    save_all_main(parser.parse_args().save_type)
