"""Runs XMS DMI component-based migration for XMS projects."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import argparse
import os
import sqlite3
import sys

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv

# 4. Local modules

# Filepath of the project database being migrated; assigned by run_migration() before any migration class runs.
project_db_file = ''
# XMS migration data dict ('classes', 'modules', 'sim_id_to_uuid'); assigned by run_migration(). Kept at module
# level, presumably so migration classes can read it through this module — confirm against the migration classes.
xms_migration_data = None


def get_type_string_from_int(type_int):
    """Translate an XMS ENTITY_TYPE enum value into its Python-side name.

    Args:
        type_int (int): XMS ENTITY_TYPE enum value.

    Returns:
        str: Name of the entity type, or None when the value is not one of the known codes (0 through 6).
    """
    # Position in this tuple is the enum value.
    entity_names = ('Simulation', 'Coverage', 'Point', 'Arc', 'Polygon', 'Material', 'Set')
    for code, entity_name in enumerate(entity_names):
        if type_int == code:
            return entity_name
    return None


def get_keyword_string_from_int(keyword_int):
    """Translate an XMS ValueInWidget_enum value into its Python-side keyword name.

    Args:
        keyword_int (int): XMS ValueInWidget_enum value.

    Returns:
        str: Keyword name, or None when the value is not one of the known codes (-1 through 11).
    """
    # The first entry corresponds to enum value -1; each following entry is one greater.
    keyword_names = (
        'none', 'value', 'units', 'geom_guid', 'file_path', 'file_name', 'sms_path', 'xmdf_path', 'curve_id',
        'curve_x', 'curve_y', 'object', 'role'
    )
    for code, keyword_name in enumerate(keyword_names, start=-1):
        if keyword_int == code:
            return keyword_name
    return None


def get_xms_migrate_data(query: Query, db_version: int) -> dict:
    """Get data from XMS required to run the project migration script.

    Args:
        query: Object for communicating with XMS.
        db_version: Version of the DMI dynamic database being migrated.

    Returns:
        (dict): dict containing XMS migration data:
            - 'classes': [str],  # Migration class names, parallel with 'modules'
            - 'modules': [str],  # Module to import each class from, parallel with 'classes'
            - 'sim_id_to_uuid': {sim_id: str},  # Simulation UUIDs keyed by SimId. Only filled when
              db_version <= 4; empty otherwise.
    """
    xms_data = {
        'classes': [],
        'modules': [],
        'sim_id_to_uuid': {},
    }

    result = query._impl._instance.Get('class', 'module')
    if result and result['class'] and result['module']:
        xms_data['classes'] = result['class'][0]
        xms_data['modules'] = result['module'][0]

    if db_version <= 4:
        # For these database versions the simulation UUIDs are not read from the project database, so they are
        # requested from XMS instead, keyed by simulation id.
        sim_uuid_result = query._impl._instance.Get('Simulation')
        if sim_uuid_result and sim_uuid_result['Simulation'] and sim_uuid_result['Simulation'][0]:
            xms_data['sim_id_to_uuid'].update(sim_uuid_result['Simulation'][0].items())

    return xms_data


def _read_db_version(conn):
    """Return the DMI dynamic-database version recorded in the project's DatabaseInfo table.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.

    Returns:
        int: Integer value of the 'dynamicDb' row.
    """
    cursor = conn.execute('SELECT Value FROM DatabaseInfo WHERE Name = "dynamicDb";')
    return int(cursor.fetchone()[0])


def _read_main_files(conn, db_path, db_version):
    """Map each component UUID to the absolute paths of its main files.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.
        db_path (str): Directory containing the project database; MainFile values are joined onto it.
        db_version (int): Dynamic database version. The Components table is only queried for version 6+.

    Returns:
        dict: {component uuid: [absolute main file path, ...]} (empty for versions below 6).
    """
    main_file_dict = {}
    if db_version >= 6:
        for comp_uuid, main_file in conn.execute('SELECT Uuid, MainFile FROM Components;'):
            abs_path = os.path.normpath(os.path.join(db_path, main_file))
            main_file_dict.setdefault(comp_uuid, []).append(abs_path)
    return main_file_dict


def _read_uuid_dict(conn, db_version, migration_data):
    """Map component, coverage, and simulation UUIDs to their model/type description.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.
        db_version (int): Dynamic database version. Components are only included for version 6+.
        migration_data (dict): XMS migration data. Its 'sim_id_to_uuid' entry is only used for version <= 4.

    Returns:
        dict: {uuid: (model name, type name, module, id)} where module is "Component", "Coverage" or "Simulation",
        and id is SimId for simulations and -1 otherwise.
    """
    if db_version >= 6:
        select = (
            'SELECT Uuid, Model AS ModelName, UniqueName AS Type, "Component" AS Module, -1 AS Id FROM '
            '((SELECT Uuid, ComponentDefId AS defId FROM Components) JOIN '
            '(SELECT Model, UniqueName, ComponentDefId FROM ComponentsDef) ON defId = ComponentDefId) UNION '
            'SELECT GUID, ModelName, Type, "Coverage" AS Module, -1 AS Id FROM Cov UNION '
            'SELECT Uuid, ModelName, SimName AS Type, "Simulation" AS Module, SimId AS Id FROM Sim;'
        )
    elif db_version == 5:
        select = (
            'SELECT GUID, ModelName, Type, "Coverage" AS Module, -1 AS Id FROM Cov UNION '
            'SELECT Uuid, ModelName, SimName AS Type, "Simulation" AS Module, SimId AS Id FROM Sim;'
        )
    else:
        # Version <= 4: a blank placeholder is selected for simulation UUIDs; they are resolved below.
        select = (
            'SELECT GUID, ModelName, Type, "Coverage" AS Module, -1 AS Id FROM Cov UNION '
            'SELECT "" AS Uuid, ModelName, SimName AS Type, "Simulation" AS Module, SimId AS Id FROM Sim;'
        )

    uuid_dict = {}
    for row_uuid, model_name, type_name, module, obj_id in conn.execute(select):
        if db_version <= 4 and row_uuid == "":
            # Look up the simulation's UUID, supplied by XMS, by its SimId.
            row_uuid = migration_data['sim_id_to_uuid'][obj_id]
        uuid_dict[row_uuid] = (model_name, type_name, module, obj_id)
    return uuid_dict


def _read_curves(conn, db_version):
    """Group the rows of CurveIdValues by curve id.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.
        db_version (int): Dynamic database version. The CurveRow column is only selected for version 4+.

    Returns:
        dict: {curve id: [[curve id], [time, ...], [x, ...], [y, ...]]}. It always contains an empty sentinel curve
        under id -1.
    """
    if db_version > 3:
        select = 'SELECT CurveId, Time, X, Y, CurveRow FROM CurveIdValues;'
    else:
        select = 'SELECT CurveId, Time, X, Y, -1 as CurveRow FROM CurveIdValues;'
    curve_dict = {-1: [[-1], [], [], []]}
    for curve_id, curve_time, x, y, _curve_row in conn.execute(select):
        if curve_id not in curve_dict:
            curve_dict[curve_id] = [[curve_id], [curve_time], [x], [y]]
        else:
            curve = curve_dict[curve_id]
            curve[1].append(curve_time)
            curve[2].append(x)
            curve[3].append(y)
    return curve_dict


def _read_widget_map(conn, db_version, curve_dict, migration_data):
    """Read every widget value attached to simulation and coverage entities.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.
        db_version (int): Dynamic database version.
        curve_dict (dict): Curves keyed by curve id, used to expand curve-id widget values.
        migration_data (dict): XMS migration data. Its 'sim_id_to_uuid' entry is only used for version <= 4.

    Returns:
        dict: {uuid: {entity type name: {att id: {widget name: [(keyword, row, value), ...]}}}}. For curve-id
        widgets, the value is replaced by the curve_dict entry and the keyword by 'none'.
    """
    if db_version > 4:
        extra_columns = ''
        sim_uuid_column = 'Uuid'
    else:
        # Version <= 4: select a blank simulation UUID plus SimOrCovId, so the UUID can be resolved from the
        # XMS-provided sim_id_to_uuid map.
        extra_columns = ', SimOrCovId'
        sim_uuid_column = '"" AS Uuid'
    select = (
        'SELECT Uuid, Type, AttId, RowCol, KeywordId, WidgetName, Value' + extra_columns + ' FROM ('
        'SELECT * FROM ('
        '(SELECT EntityId AS EId, Type, SimOrCovId, AttId FROM Entity)'
        ' JOIN '
        '(SELECT SimId AS Id, ' + sim_uuid_column + ', 0 AS ObjType FROM Sim'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 1 AS ObjType FROM Cov'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 2 AS ObjType FROM Cov'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 3 AS ObjType FROM Cov'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 4 AS ObjType FROM Cov'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 5 AS ObjType FROM Cov'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 6 AS ObjType FROM Cov)'
        ' ON '
        'SimOrCovId = Id AND ObjType = Type)'
        ' JOIN ('
        'SELECT WEId, RowCol, KeywordId, WidgetName, Value FROM ('
        '(SELECT WidgetInstId AS WIId, WEId, RowCol, KeywordId, WidgetName FROM ('
        '(SELECT WidgetInstId, EntityId AS WEId, WidgetDefId, RowCol, KeywordId FROM WidgetInst)'
        ' JOIN (SELECT WidgetDefId AS WDefId, WidgetName FROM WidgetDef) ON WidgetDefId = WDefId))'
        ' JOIN'
        '(SELECT WidgetInstId, Value FROM TextValues'
        ' UNION '
        'SELECT WidgetInstId, Value FROM IntValues'
        ' UNION '
        'SELECT WidgetInstId, Value FROM DoubleValues'
        ' UNION '
        'SELECT WidgetInstId, Value FROM DateTimeValues)'
        ' ON WIId = WidgetInstId'
        '))ON EId = WEId)'
    )

    widget_map = {}
    for widget_value in conn.execute(select):
        uuid_str = widget_value[0]
        type_str = get_type_string_from_int(widget_value[1])
        att_int = widget_value[2]
        row = widget_value[3]
        keyword_int = widget_value[4]
        keyword = get_keyword_string_from_int(keyword_int)
        widget_name = widget_value[5]
        value = widget_value[6]
        if db_version <= 4 and not uuid_str:
            uuid_str = migration_data['sim_id_to_uuid'][widget_value[7]]
        widget_values = (
            widget_map.setdefault(uuid_str, {})
            .setdefault(type_str, {})
            .setdefault(att_int, {})
            .setdefault(widget_name, [])
        )
        if keyword_int == 7:  # Curve id: replace the id with the whole curve and the keyword with 'none'.
            keyword = 'none'
            # Original note: "looks like an old bug with curves" - a stored 0 is treated as the empty -1 curve.
            if value == 0:
                value = -1
            value = curve_dict[value]
        widget_values.append((keyword, row, value))
    return widget_map


def _read_take_dict(conn, db_version, migration_data):
    """Read the TakeConnection relationships.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.
        db_version (int): Dynamic database version.
        migration_data (dict): XMS migration data. Its 'sim_id_to_uuid' entry is only used for version <= 4.

    Returns:
        dict: {taker uuid: [taken uuid, ...]}
    """
    if db_version >= 6:
        select = (
            'SELECT DISTINCT TakerUuid, TakenUuid FROM(('
            'SELECT DISTINCT TakerEntityId, TakenEntityId, Uuid AS TakerUuid FROM ('
            '(SELECT TakerEntityId, TakenEntityId FROM TakeConnection)'
            ' JOIN '
            '(SELECT EId, Type, ObjType, Uuid FROM ('
            '(SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type IN (0, 10))'
            ' JOIN '
            '(SELECT SimId AS Id, Uuid, 0 AS ObjType FROM Sim'
            ' UNION '
            'SELECT ComponentId AS Id, Uuid, 10 AS ObjType FROM Components) '
            'ON SimOrCovId = Id AND Type = ObjType)) '
            'ON EId = TakerEntityId))'
            ' JOIN '
            '(SELECT DISTINCT TakerEntityId AS TakerId, TakenEntityId AS TakenId, Uuid AS TakenUuid FROM ('
            '(SELECT TakerEntityId, TakenEntityId FROM TakeConnection)'
            ' JOIN '
            '(SELECT EId, Type, ObjType, Uuid FROM '
            '((SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type IN (0, 1, 10))'
            ' JOIN '
            '(SELECT SimId AS Id, Uuid, 0 AS ObjType FROM Sim'
            ' UNION '
            'SELECT CovId AS Id, GUID AS Uuid, 1 AS ObjType FROM Cov'
            ' UNION '
            'SELECT ComponentId AS Id, Uuid, 10 AS ObjType FROM Components) '
            'ON SimOrCovId = Id AND Type = ObjType)) '
            'ON EId = TakenEntityId)) '
            'ON TakerEntityId = TakerId);'
        )
    elif db_version == 5:
        select = (
            'SELECT DISTINCT TakerUuid, TakenUuid FROM(('
            'SELECT DISTINCT TakerEntityId, TakenEntityId, Uuid AS TakerUuid FROM ('
            '(SELECT TakerEntityId, TakenEntityId FROM TakeConnection)'
            ' JOIN '
            '(SELECT EId, Type, ObjType, Uuid FROM ('
            '(SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type = 0)'
            ' JOIN '
            '(SELECT SimId AS Id, Uuid, 0 AS ObjType FROM Sim) '
            'ON SimOrCovId = Id AND Type = ObjType)) '
            'ON EId = TakerEntityId))'
            ' JOIN '
            '(SELECT DISTINCT TakerEntityId AS TakerId, TakenEntityId AS TakenId, Uuid AS TakenUuid FROM ('
            '(SELECT TakerEntityId, TakenEntityId FROM TakeConnection)'
            ' JOIN '
            '(SELECT EId, Type, ObjType, Uuid FROM '
            '((SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type IN (0, 1))'
            ' JOIN '
            '(SELECT SimId AS Id, Uuid, 0 AS ObjType FROM Sim'
            ' UNION '
            'SELECT CovId AS Id, GUID AS Uuid, 1 AS ObjType FROM Cov) '
            'ON SimOrCovId = Id AND Type = ObjType)) '
            'ON EId = TakenEntityId)) '
            'ON TakerEntityId = TakerId);'
        )
    else:
        # Version <= 4: rows are (TakenUuid, SimOrCovId of the taking simulation entity).
        select = (
            'SELECT DISTINCT TakenUuid, SimOrCovId FROM(('
            'SELECT DISTINCT TakerEntityId, TakenEntityId, SimOrCovId FROM ('
            '(SELECT TakerEntityId, TakenEntityId FROM TakeConnection)'
            ' JOIN '
            '(SELECT EId, Type, SimOrCovId FROM ('
            'SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type = 0)) '
            'ON EId = TakerEntityId))'
            ' JOIN '
            '(SELECT DISTINCT TakerEntityId AS TakerId, TakenEntityId AS TakenId, Uuid AS TakenUuid FROM ('
            '(SELECT TakerEntityId, TakenEntityId FROM TakeConnection)'
            ' JOIN '
            '(SELECT EId, Type, ObjType, Uuid, Id FROM '
            '((SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type = 1)'
            ' JOIN '
            '(SELECT CovId AS Id, GUID AS Uuid, 1 AS ObjType FROM Cov) '
            'ON SimOrCovId = Id AND Type = ObjType)) '
            'ON EId = TakenEntityId)) '
            'ON TakerEntityId = TakerId);'
        )

    take_dict = {}
    for first, second in conn.execute(select):
        if db_version >= 5:
            taker_uuid, taken_uuid = first, second
        else:
            taken_uuid = first
            taker_uuid = migration_data['sim_id_to_uuid'][second]
        take_dict.setdefault(taker_uuid, []).append(taken_uuid)
    return take_dict


def _read_hidden_dict(conn, db_version):
    """Read which simulations/coverages own which hidden components (HiddenComponents table, version 6+ only).

    Args:
        conn (sqlite3.Connection): Open connection to the project database.
        db_version (int): Dynamic database version.

    Returns:
        dict: {owner uuid: [component uuid, ...]} (empty for versions below 6).
    """
    hidden_dict = {}
    if db_version < 6:
        return hidden_dict
    select = (
        'SELECT DISTINCT TakerUuid, TakenUuid FROM(('
        'SELECT DISTINCT OwnerEntityId, ComponentEntityId, Uuid AS TakerUuid FROM ('
        '(SELECT OwnerEntityId, ComponentEntityId FROM HiddenComponents)'
        ' JOIN '
        '(SELECT EId, Type, ObjType, Uuid FROM ('
        '(SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type IN (0, 1))'
        ' JOIN '
        '(SELECT SimId AS Id, Uuid, 0 AS ObjType FROM Sim'
        ' UNION '
        'SELECT CovId AS Id, GUID AS Uuid, 1 AS ObjType FROM Cov) '
        'ON SimOrCovId = Id AND Type = ObjType)) '
        'ON EId = OwnerEntityId))'
        ' JOIN '
        '(SELECT DISTINCT OwnerEntityId AS TakerId, ComponentEntityId AS TakenId, Uuid AS TakenUuid FROM ('
        '(SELECT OwnerEntityId, ComponentEntityId FROM HiddenComponents)'
        ' JOIN '
        '(SELECT EId, Type, ObjType, Uuid FROM '
        '((SELECT EntityId AS EId, Type, SimOrCovId FROM Entity WHERE Type = 10)'
        ' JOIN '
        '(SELECT ComponentId AS Id, Uuid, 10 AS ObjType FROM Components) '
        'ON SimOrCovId = Id AND Type = ObjType)) '
        'ON EId = ComponentEntityId)) '
        'ON OwnerEntityId = TakerId);'
    )
    for owner_uuid, comp_uuid in conn.execute(select):
        hidden_dict.setdefault(owner_uuid, []).append(comp_uuid)
    return hidden_dict


def _read_material_dicts(conn):
    """Read coverage material attributes and polygon material assignments.

    Args:
        conn (sqlite3.Connection): Open connection to the project database.

    Returns:
        tuple(dict, dict):
            - {coverage uuid: {material id: (name, red, green, blue, alpha, texture id)}}
            - {coverage uuid: {material id: {polygon id, ...}}}
    """
    select = (
        'SELECT GUID, MaterialId, Name, Red, Green, Blue, Alpha, TextureId FROM ('
        '(SELECT MaterialId, Name, CovId as Id, Red, Green, Blue, Alpha, TextureId FROM MaterialsOnACoverage)'
        ' JOIN '
        '(SELECT GUID, CovId as cid FROM Cov) ON cid = Id);'
    )
    mat_dict = {}
    for cov_uuid, mat_id, name, red, green, blue, alpha, texture_id in conn.execute(select):
        mat_dict.setdefault(cov_uuid, {})[mat_id] = (name, red, green, blue, alpha, texture_id)

    select = (
        'SELECT GUID, PolyId, MatId From ('
        '(SELECT PolyCovId, PolyId, MatId FROM ('
        '(SELECT Poly_EntityId as PolyEntId, Material_EntityId as PolyMatEntId FROM MaterialOnPoly) '
        ' JOIN (SELECT AttId as PolyId, EntityId as pEId, SimOrCovId as PolyCovId FROM Entity)'
        ' ON PolyEntId = pEId'
        ' JOIN '
        '((SELECT Poly_EntityId as MatPolyEntId, Material_EntityId FROM MaterialOnPoly) '
        ' JOIN (SELECT AttId as MatId, EntityId as mEId FROM Entity) ON Material_EntityId = mEId)'
        ' ON mEId = PolyMatEntId AND PolyEntId = MatPolyEntId))'
        ' JOIN '
        '(SELECT GUID, CovId FROM Cov)'
        ' ON PolyCovId = CovId);'
    )
    mat_poly_dict = {}
    for cov_uuid, poly_id, mat_id in conn.execute(select):
        mat_poly_dict.setdefault(cov_uuid, {}).setdefault(mat_id, set()).add(poly_id)
    return mat_dict, mat_poly_dict


def _reset_query_context(query, root_place_mark, collected_marks=None):
    """Point the XMS Query context back at the root instance after a migration class has run.

    Args:
        query (Query): Object for communicating with XMS. Nothing is done when it is falsy (testing).
        root_place_mark: Root instance of the Query context, captured before any migration class ran.
        collected_marks (list): If given, the place marks present before the reset are appended to it.
    """
    if not query:
        return
    ctxt = query._impl._instance.GetContext()
    local_marks = ctxt.GetPlaceMarks() if collected_marks is not None else None
    ctxt.ClearPlaceMarks()
    ctxt.SetPlaceMark(root_place_mark)
    query._impl._instance.SetContext(ctxt)
    if collected_marks is not None:
        collected_marks.extend(local_marks)


def run_migration(db_file, query=None, xms_data=None):
    """Perform migration of an XMS project.

    Reads everything the migration classes need from the project database, then closes it. It then runs each
    migration class listed in the XMS data in three passes: map_to_typename(), send_replace_map(), and
    get_messages_and_actions(). An exception from one migration class is reported through XmEnv.report_error()
    and does not stop the others.

    Args:
        db_file (str): Filepath of the project database to migrate
        query (Query): Object for communicating with XMS. Mutually exclusive with xms_data
        xms_data (dict): XMS data dictionary (same layout as get_xms_migrate_data() returns). Avoids use of Query.
            Useful for testing.

    Returns:
        list, list, list: List of all Query Context build vertices, list of all ActionRequests queued up by migration
            scripts, list of all messages queued up by migration scripts

    Raises:
        TypeError: If the DatabaseInfo table has no 'dynamicDb' row.
    """
    global project_db_file
    global xms_migration_data
    project_db_file = db_file
    xms_migration_data = xms_data

    conn = sqlite3.connect(db_file)
    try:
        db_version = _read_db_version(conn)
        if not xms_migration_data:  # Retrieve XMS data if not testing.
            xms_migration_data = get_xms_migrate_data(query, db_version)

        main_file_dict = _read_main_files(conn, os.path.dirname(db_file), db_version)
        uuid_dict = _read_uuid_dict(conn, db_version, xms_migration_data)
        curve_dict = _read_curves(conn, db_version)
        widget_map = _read_widget_map(conn, db_version, curve_dict, xms_migration_data)
        take_dict = _read_take_dict(conn, db_version, xms_migration_data)
        hidden_dict = _read_hidden_dict(conn, db_version)
        mat_dict, mat_poly_dict = _read_material_dicts(conn)
    finally:
        # Close the project database explicitly, before the migration classes run, instead of leaving the handle
        # open until garbage collection (which is also what happened when a query raised).
        conn.close()

    if query:
        root_place_mark = query._impl._instance.GetContext().GetRootInstance()
    else:
        root_place_mark = -1  # No Query, testing
    delete_replace_map = {}
    all_classes = []
    all_place_marks = []
    all_messages = []
    all_action_requests = []

    # Pass 1: instantiate each migration class and let it map the old project data.
    for class_name, module_name in zip(xms_migration_data['classes'], xms_migration_data['modules']):
        try:
            mod = __import__(module_name, fromlist=[class_name])
            class_instance = getattr(mod, class_name)()
            all_classes.append(class_instance)
            delete_replace_map_local = class_instance.map_to_typename(
                uuid_dict, widget_map, take_dict, main_file_dict, hidden_dict, mat_dict, mat_poly_dict, query
            )
            _reset_query_context(query, root_place_mark, all_place_marks)
            for del_uuid, replace_uuids in delete_replace_map_local.items():
                delete_replace_map.setdefault(del_uuid, []).extend(replace_uuids)
        except Exception as ex:
            XmEnv.report_error(ex)
            _reset_query_context(query, root_place_mark)

    # Pass 2: give every migration class the combined delete/replace map from all classes.
    for inst in all_classes:
        try:
            inst.send_replace_map(delete_replace_map, query)
            _reset_query_context(query, root_place_mark, all_place_marks)
        except Exception as ex:
            XmEnv.report_error(ex)
            _reset_query_context(query, root_place_mark)

    # Pass 3: collect queued messages and ActionRequests.
    for inst in all_classes:
        try:
            messages, action_requests = inst.get_messages_and_actions()
            # Unwrap the pure Python ActionRequests
            all_action_requests.extend([action._instance for action in action_requests])
            all_messages.extend(messages)
        except Exception as ex:
            XmEnv.report_error(ex)

    return all_place_marks, all_action_requests, all_messages


def main():
    """Run the DMI model/component migration scripts for the project database named on the command line.

    XMS calls this when it must migrate an old project that uses an old DMI model/component definition, and a
    currently loaded model/component definition claims it can migrate it. The command line provides the script
    and the project database file. XMS supplies, through Query, the class/module names to run. The Query context
    is reset between migration classes. Everything the migrators produce is sent back to XMS in one final
    Query.send().
    """
    parser = argparse.ArgumentParser(description="Migration runner.")
    parser.add_argument(dest='script', type=str, help='script to run')
    parser.add_argument(dest='file', type=str, help='the project database file')
    args = parser.parse_args()

    query = Query(migrate_script=True)
    # The migration scripts may not call Query.send() themselves. SetAllowSend is only available on the
    # C++-exposed interface object.
    query._impl._instance.SetAllowSend(False)

    place_marks, actions, messages = run_migration(args.file, query=query)

    context = query._impl._instance.GetContext()
    root = context.GetRootInstance()
    for place_mark in place_marks:
        context.SetPlaceMark(place_mark)
    query._impl._instance.SetContext(context)

    # The ActionRequests were already unwrapped (action._instance) inside run_migration().
    payload = [{key: items} for key, items in (("actions", actions), ("messages", messages)) if items]
    if payload:
        query._impl._instance.Set(payload, root)

    # Allow sending again, only for the single send of everything collected above.
    query._impl._instance.SetAllowSend(True)
    query.send(True)
    sys.exit()


if __name__ == "__main__":
    # Script entry point: main() parses the command line, runs the migration, and ends with sys.exit().
    main()
