"""ParametersManager class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import json  # noqa: I100, I202
import os

# 2. Third party modules
import numpy as np
from PySide2.QtWidgets import QDialog

# 3. Aquaveo modules
from xms.api.tree.tree_util import trim_project_explorer
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.srh.components.parameters.bc_parameter import BcParameter
from xms.srh.components.parameters.initial_condition_parameter import InitialConditionParameter
from xms.srh.components.parameters.inlet_q_parameter import InletQParameter
from xms.srh.components.parameters.mannings_n_parameter import ManningsNParameter
from xms.srh.components.parameters.parameter_data import ParameterData
from xms.srh.components.parameters.time_step_parameter import TimeStepParameter
from xms.srh.components.sim_query_helper import SimQueryHelper
from xms.srh.data.material_data import MaterialData
from xms.srh.gui.parameters_dialog import ParametersDialog


class ParametersManager:
    """Handles the parameters.

    Gathers candidate parameters from the simulation (material Manning's n values, inlet Q / exit H /
    internal sink boundary condition arcs, the hydro time step, and the hydro initial condition), runs the
    Parameters dialog, and reads/writes the chosen settings in a ``params.json`` file stored in the same
    folder as the SimComponent main file.
    """
    def __init__(self, main_file, data):
        """Initializes the class.

        Args:
            main_file: The main file associated with the SimComponent.
            data (ModelControl): The model control class

        """
        self.main_file = main_file
        self.data = data
        # Filled in by _get_available_parameters() from the SimQueryHelper.
        self.grid_uuid = None

    def run_parameters_dialog(self, query, win_cont):
        """Opens the Parameters dialog and writes the edited parameters to disk if the user accepts.

        Args:
            query (:obj:`xms.api.dmi.Query`): Object for communicating with GMS
            win_cont (:obj:`PySide2.QtWidgets.QWidget`): The window container.
        """
        param_data, helper = self._get_available_parameters(query)
        # Merge in previously saved settings; disk_data is None when no params.json exists yet.
        disk_data = self.read_parameter_file(self.main_file)
        param_data.apply_disk_data(disk_data)
        param_data.populate_runs_table(disk_data)

        # Trim the project explorer tree using the simulation's grid UUID before handing it to the dialog.
        pe_tree_full = query.copy_project_tree()
        pe_tree = trim_project_explorer(pe_tree_full, self.grid_uuid)

        dlg = ParametersDialog(
            param_data=param_data, pe_tree_trimmed=pe_tree, parent=win_cont, query=query, sim_helper=helper
        )
        if dlg.exec() == QDialog.Accepted:
            self.write_parameter_file(self.main_file, dlg.param_data)

    def _get_available_parameters(self, query):
        """Looks through the model data and finds items that can be parameters.

        Also records the simulation's grid UUID in ``self.grid_uuid``.

        Args:
            query (:obj:`xms.api.dmi.Query`): Object for communicating with GMS

        Returns:
            tuple(ParameterData, SimQueryHelper): The parameter data and the SimQueryHelper used to retrieve it.
        """
        # Get the sim data
        helper = SimQueryHelper(query)
        helper.get_sim_data()

        param_data = ParameterData()
        self.grid_uuid = helper.grid_uuid
        # The order of these calls determines the order of the entries in param_data.params.
        self.get_material_data(helper, param_data)
        self.get_bcs(helper, param_data)
        self.get_time_step(param_data)
        self.get_initial_condition(param_data)

        return param_data, helper

    def get_time_step(self, param_data):
        """Adds the time step as a parameter.

        The current hydro time step is used as the default, value, min, and max.

        Args:
            param_data (ParameterData): Parameter data.
        """
        time_step = self.data.hydro.time_step
        param_data.params.append(
            TimeStepParameter(
                use=0,
                description='',
                default=time_step,
                string_value=str(time_step),
                value=time_step,
                optimized_value=0.0,
                min=time_step,
                max=time_step
            )
        )

    def get_initial_condition(self, param_data):
        """Adds the initial condition as a parameter.

        The parameter's string value depends on the initial condition type. It is the initial water surface
        elevation (converted to a string), the restart file, or the water surface elevation dataset. Any
        other type gets an empty string.

        Args:
            param_data (ParameterData): Parameter data.
        """
        hydro = self.data.hydro
        condition = hydro.initial_condition
        value = ''
        if condition == 'Initial Water Surface Elevation':
            value = str(hydro.initial_water_surface_elevation)
        elif condition == 'Restart File':
            value = hydro.restart_file
        elif condition == 'Water Surface Elevation Dataset':
            value = hydro.water_surface_elevation_dataset
        param_data.params.append(
            InitialConditionParameter(
                use=0,
                description='',
                default=condition,
                string_value=value,
                value=0.0,
                optimized_value=0.0,
                min=0.0,
                max=0.0
            )
        )

    @staticmethod
    def get_bcs(helper, param_data):
        """Adds inlet q, exit h, and internal sinks as parameters.

        Arcs without a valid BC component id are treated as walls and skipped. Arcs whose BC type is not
        inlet_q, exit_h, or internal_sink are also skipped. Each added parameter uses the arc's constant
        value as its default, value, min, and max.

        Args:
            helper (SimQueryHelper): Query helper class.
            param_data (ParameterData): Parameter data.

        """
        if helper.bc_component is None:
            return
        # Get the arc bcs (gleaned from BcMapper._get_grid_points_from_arcs())
        bc_component = helper.bc_component
        bc_coverage = helper.coverages['Boundary Conditions'][0]
        df = bc_component.data.comp_ids.to_dataframe()
        # component id -> (bc id, bc display name), taken from the first three dataframe columns.
        comp_id_to_bc_id_name = {row[0]: (row[1], row[2]) for row in df.to_numpy()}
        for arc in bc_coverage.arcs:
            pid = arc.id
            comp_id = bc_component.get_comp_id(TargetType.arc, pid)
            if comp_id is None or comp_id < 0:
                continue  # Unassigned arcs are walls, which are never parameters.
            bc_id, display_name = comp_id_to_bc_id_name[comp_id]
            if display_name not in ('inlet_q', 'exit_h', 'internal_sink'):
                continue
            # Only look up the BC parameters for arcs that actually become parameters.
            bc_par = bc_component.data.bc_data_param_from_id(bc_id)
            if display_name == 'inlet_q':
                value = bc_par.inlet_q.constant_q
                parameter = InletQParameter(
                    use=0,
                    id_number=pid,
                    description=f'Arc {pid}',
                    default=value,
                    string_value=str(value),
                    value=value,
                    optimized_value=0.0,
                    min=value,
                    max=value
                )
            else:
                if display_name == 'exit_h':
                    bc_type, value = 'Exit H', bc_par.exit_h.constant_wse
                else:  # internal_sink
                    bc_type, value = 'Internal sink', bc_par.internal_sink.constant_q
                parameter = BcParameter(
                    use=0,
                    id_number=pid,
                    type=bc_type,
                    description=f'Arc {pid}',
                    default=value,
                    string_value=str(value),
                    value=value,
                    optimized_value=0.0,
                    min=value,
                    max=value
                )
            param_data.params.append(parameter)

    @staticmethod
    def get_material_data(helper, param_data):
        """Adds the material manning's n as parameters.

        The unassigned material is skipped. Each material's Manning's n is used as the default, value, min,
        and max.

        Args:
            helper (SimQueryHelper): Query helper class.
            param_data (ParameterData): Parameter data.

        """
        if helper.material_component is None:
            return
        # Get the material data (gleaned from MaterialMapper.do_map())
        df = helper.material_component.data.materials.to_dataframe()
        rows = zip(df['id'].tolist(), df['Name'].tolist(), df["Manning's N"].tolist())
        for mat_id, name, mannings_n in rows:
            if mat_id == MaterialData.UNASSIGNED_MAT:
                continue
            param_data.params.append(
                ManningsNParameter(
                    use=0,
                    id_number=mat_id,
                    description=name,
                    default=mannings_n,
                    string_value=str(mannings_n),
                    value=mannings_n,
                    optimized_value=0.0,
                    min=mannings_n,
                    max=mannings_n
                )
            )

    @staticmethod
    def write_parameter_file(main_file, param_data):
        """Writes the parameter data to disk as a json file.

        The file is named ``params.json`` and is placed in the same folder as ``main_file``. Leading and
        trailing whitespace is stripped in place from the scenario run names (the 'Run' entry of the runs
        data) before writing.

        Read/write the params.json file using json instead of orjson because params.json could have lists
        with more than one type (both floats and strings) and orjson does not seem to support this.  orjson
        raises the following exception when we try to dump the param_data to json:
        'Type is not JSON serializable: numpy.float64'

        Args:
            main_file (str): file name
            param_data (ParameterData or dict): Parameter data.
        """
        is_dict = isinstance(param_data, dict)

        # Trim all whitespace from scenario run names (before to_dict() so the trimmed names are saved).
        runs = param_data.get('runs') if is_dict else param_data.runs
        if runs is not None and 'Run' in runs:
            runs['Run'] = [x.strip() for x in runs['Run']]

        # Convert list of Parameters to dict for serialization
        param_dict = param_data if is_dict else param_data.to_dict()
        # Serialize before opening the file so an encoding error does not truncate an existing params.json.
        text = json.dumps(param_dict, cls=NpEncoder)
        filename = os.path.join(os.path.dirname(main_file), 'params.json')
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(text)

    @staticmethod
    def read_parameter_file(main_file):
        """Reads the parameter data from disk and returns it as a dict.

        Read/write the params.json file using json instead of orjson because params.json could have lists
        with more than one type (both floats and strings) and orjson does not seem to support this.

        Args:
            main_file (str): Path to SimComponent main file. ``params.json`` is read from the same folder.

        Returns:
            dict | None: The decoded params.json contents with whitespace stripped from the scenario run names,
            or None if params.json does not exist.
        """
        filename = os.path.join(os.path.dirname(main_file), 'params.json')
        if not os.path.isfile(filename):
            return None
        with open(filename, 'r', encoding='utf-8') as file:
            ret = json.load(file)
        # trim all whitespace from scenario run names
        runs = ret.get('runs', None)
        if runs is not None and 'Run' in runs:
            runs['Run'] = [x.strip() for x in runs['Run']]
        return ret


class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy types to native Python types.

    The standard json module already handles ``numpy.float64`` (a ``float`` subclass), but not numpy
    integers, other numpy floating types, numpy booleans, or numpy arrays.

    See: https://stackoverflow.com/questions/50916422
    """
    def default(self, obj):
        """Converts numpy objects to JSON-serializable Python objects.

        Args:
            obj: An object the base encoder could not serialize.

        Returns:
            An int, float, bool, or list equivalent of ``obj``.

        Raises:
            TypeError: If ``obj`` is not a supported numpy type (raised by ``json.JSONEncoder.default``).
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Let the base class raise TypeError rather than silently writing null for unknown objects.
        return super().default(obj)
