"""ExporUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from typing import List

# 2. Third party modules
import numpy as np
from stl import mesh, Mode

# 3. Aquaveo modules
from xms.grid.ugrid import UGrid as XmUGrid
from xms.tool_core import Argument, IoDirection, Tool

# 4. Local modules


class ExportUGridTool(Tool):
    """Tool to export a grid to various file formats.

    Supported formats are ASCII/binary XMC, ASCII/binary STL, and OBJ. The STL and OBJ writers
    only emit triangles, so those formats require a grid whose cells are all triangles.
    """

    ARG_INPUT_GRID = 0
    ARG_OUTPUT_FILETYPE = 1
    ARG_OUTPUT_FILENAME = 2

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Export UGrid')
        # Grid loaded by validate_arguments() and exported by run().
        self._input_cogrid = None
        # File type choice shown to the user -> extension of the file written for it.
        self._ftype_suffix = {
            'Ascii XMC': 'xmc',
            'Binary XMC': 'xmc',
            'Ascii STL': 'stl',
            'Binary STL': 'stl',
            'OBJ': 'obj'
        }

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_ugrid', description='UGrid', io_direction=IoDirection.INPUT),
            self.string_argument(name='file_type', description='File type', value='Ascii XMC',
                                 choices=list(self._ftype_suffix.keys())),
            self.file_argument(name='output_file', description='Output file', file_filter='XMC file (*.xmc)',
                               default_suffix='xmc', io_direction=IoDirection.OUTPUT, value='grid.xmc'),
        ]
        # Sync the output file's filter/suffix with the initially selected file type.
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also loads the input grid into ``self._input_cogrid`` for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        self._input_cogrid = self.get_input_grid(arguments[self.ARG_INPUT_GRID].text_value)
        if self._input_cogrid is None:
            errors[arguments[self.ARG_INPUT_GRID].name] = 'Unable to load UGrid. Aborting'
            return errors
        ftype = arguments[self.ARG_OUTPUT_FILETYPE].value
        # The STL/OBJ writers read exactly three point indices per cell, so only triangles are allowed.
        if any(s in ftype for s in ['STL', 'OBJ']):
            if not self._input_cogrid.check_all_cells_are_of_types([XmUGrid.cell_type_enum.TRIANGLE]):
                msg = 'The grid must only have triangle cells if exporting STL or OBJ.'
                errors[arguments[self.ARG_INPUT_GRID].name] = msg
        return errors

    def enable_arguments(self, arguments: List[Argument]):
        """Called to show/hide arguments, change argument values and add new arguments.

        Updates the output file argument's filter, default suffix, and current extension to
        match the selected file type.

        Args:
            arguments(list): The tool arguments.
        """
        ftype = arguments[self.ARG_OUTPUT_FILETYPE].value
        file_arg = arguments[self.ARG_OUTPUT_FILENAME]
        suffix = self._ftype_suffix[ftype]
        if file_arg.default_suffix == suffix:
            return  # Already configured for this file type.
        file_arg.file_filter = f'{suffix.upper()} file (*.{suffix})'
        file_arg.default_suffix = suffix
        if file_arg.value:
            # Swap the current extension for the one matching the new file type.
            stem, _ = os.path.splitext(file_arg.value)
            file_arg.value = f'{stem}.{suffix}'

    def run(self, arguments):
        """Override to run the tool.

        Writes the grid loaded by validate_arguments() to the output file in the selected format.

        Args:
            arguments (list): The tool arguments.
        """
        ftype = arguments[self.ARG_OUTPUT_FILETYPE].value
        out_file = arguments[self.ARG_OUTPUT_FILENAME].value
        if ftype == 'Ascii XMC':
            self._input_cogrid.write_to_file(out_file, binary_arrays=False)
        elif ftype == 'Binary XMC':
            self._input_cogrid.write_to_file(out_file, binary_arrays=True)
        elif 'STL' in ftype:
            points, cells = self._triangle_cells()
            self._write_stl(out_file, points, cells, ascii_mode=ftype == 'Ascii STL')
        elif ftype == 'OBJ':
            points, cells = self._triangle_cells()
            self._write_obj(out_file, points, cells)

    def _triangle_cells(self):
        """Get the grid's point locations and the point indices of each cell.

        Assumes every cell is a triangle, which validate_arguments() enforces for STL/OBJ output.

        Returns:
            (tuple): The grid's point locations and a list of ``[i0, i1, i2]`` point-index lists,
            one per cell, in cell order.
        """
        ugrid = self._input_cogrid.ugrid
        stream = ugrid.cellstream
        # Each triangle occupies 5 cellstream entries: 2 leading values (presumably cell type and
        # point count) followed by its 3 point indices.
        cells = [list(stream[i + 2:i + 5]) for i in range(0, len(stream), 5)]
        return ugrid.locations, cells

    @staticmethod
    def _write_stl(out_file, points, cells, ascii_mode):
        """Write triangles to an STL file.

        Args:
            out_file (str): Path of the file to write.
            points: Point locations as a 2D array-like, one ``(x, y, z)`` row per point.
            cells (list): ``[i0, i1, i2]`` point-index lists, one per triangle.
            ascii_mode (bool): True to write ASCII STL, False to write binary STL.
        """
        faces = np.asarray(cells, dtype=np.int64).reshape(-1, 3)
        # Use numpy-stl's own record layout. Binary STL facets hold 32-bit floats and numpy-stl
        # writes the record array's raw bytes, so a float64 'vectors' field produced malformed
        # binary files. Note that STL therefore stores coordinates in single precision.
        stl_mesh = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
        if faces.size:
            # Gather every triangle's three corner locations at once: (n, 3) indices -> (n, 3, 3).
            stl_mesh.vectors[:] = np.asarray(points)[faces]
        stl_mesh.save(out_file, mode=Mode.ASCII if ascii_mode else Mode.BINARY)

    @staticmethod
    def _write_obj(out_file, points, cells):
        """Write triangles to a Wavefront OBJ file.

        Args:
            out_file (str): Path of the file to write.
            points: Point locations, one ``(x, y, z)`` entry per point.
            cells (list): ``[i0, i1, i2]`` point-index lists, one per triangle.
        """
        with open(out_file, 'w') as obj_file:
            obj_file.write('# OBJ file exported from XMS\n#\n')
            obj_file.writelines(f'v {p[0]} {p[1]} {p[2]}\n' for p in points)
            # OBJ face indices are 1-based; grid point indices are 0-based.
            obj_file.writelines(f'f {c[0]+1} {c[1]+1} {c[2]+1}\n' for c in cells)
