"""ImportUGridPoints class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import uuid

# 2. Third party modules
import pandas
import pandas as pd

# 3. Aquaveo modules
from xms.constraint.ugrid_builder import UGridBuilder
from xms.grid.ugrid import UGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of the tool arguments in the list returned by ImportUGridPointsTool.initial_arguments().
# The order here must match the order in which that list is built.
ARG_INPUT_File = 0  # 'csv_file': CSV file with point coordinates and datasets
ARG_INPUT_Null_Value = 1  # 'null_value': no data value passed to the dataset writers
ARG_INPUT_Time_Unit = 2  # 'time_unit': time unit passed to the dataset writers
ARG_INPUT_Prj = 3  # 'prj_file': optional projection file
ARG_UGRID = 4  # 'ugrid': hidden output grid argument


class ImportUGridPointsTool(Tool):
    """Tool to import UGrid points and datasets from a csv file.

    The csv file must contain ``x``/``X`` and ``y``/``Y`` columns; a ``z``/``Z`` column is optional. Every other
    named column becomes a dataset on the imported points. Columns named ``<name>.<suffix>`` are grouped with the
    column ``<name>`` as additional time steps of the same dataset (this is the naming pandas gives to repeated
    header names). If the first data row does not hold valid coordinates it is treated as a row of time values.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Import UGrid Points')
        self._df = None  # data frame read from the csv file
        self._ug_name = ''  # name given to the output grid (csv file base name)
        self._args = None  # tool arguments passed to run()
        self._cogrid = None  # constrained grid built from the points
        self._location_cols = None  # names of the x, y (and z) columns found in the file
        self._time_exists = False  # True if the first data row holds times instead of a point
        self._ref_time = None  # Julian reference time of the dataset being read (date-times only)
        self._time_offsets = None  # time step values of the dataset being read
        self._data_sets = []  # dataset writers created from the file

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.file_argument(name='csv_file', description='CSV file with point coordinates and datasets'),
            self.float_argument(name='null_value', description='No data value', value=-999.0),
            self.string_argument(name='time_unit',
                                 description='Time unit (used for transient data not using dates-times)', value='Days',
                                 choices=['Seconds', 'Minutes', 'Hours', 'Days']),
            self.file_argument(name='prj_file', description='Coordinate system projection file (*.prj)',
                               optional=True),
            self.grid_argument(name='ugrid', description='The ugrid', hide=True, optional=True,
                               io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def run(self, arguments):
        """Override to run the tool.

        Reads the csv file, builds the UGrid from its coordinates, builds the datasets from the remaining columns
        and registers everything as tool output.

        Args:
            arguments (list): The tool arguments.
        """
        in_file = arguments[ARG_INPUT_File].text_value
        try:
            self._df = pandas.read_csv(in_file)
        except ValueError:
            self.fail(f'Error reading file: {in_file}. Check file format.')
        self._args = arguments
        self._ug_name = os.path.basename(in_file)
        self._create_ugrid()
        self._create_data_sets()
        self._set_outputs()

    def _set_outputs(self):
        """Set outputs from the tool."""
        prj = None
        prj_file = self._args[ARG_INPUT_Prj].value
        if prj_file:
            # Only the first line of the projection file is used.
            with open(prj_file, 'r') as f:
                prj = f.readline()
        self.set_output_grid(self._cogrid, self._args[ARG_UGRID], projection=prj)
        if not self._data_sets:
            self.logger.warning('No datasets created from file.')
        for ds in self._data_sets:
            self.set_output_dataset(ds)

    @staticmethod
    def _first_present(candidates, available):
        """Return the first candidate column name that exists in the file.

        Args:
            candidates (tuple(str)): Column names to look for, in order of preference.
            available (set(str)): Column names present in the data frame.

        Returns:
            (str): The first candidate found, or None if none of them is present.
        """
        for name in candidates:
            if name in available:
                return name
        return None

    def _create_ugrid(self):
        """Creates a ugrid from the x, y (and optional z) coordinates in a data frame.

        Also detects whether the first data row is a row of times (its coordinates are missing or not numeric), in
        which case that row is excluded from the points and ``self._time_exists`` is set.
        """
        self.logger.info('Reading point coordinates from file.')
        available = set(self._df.columns)
        # The lowercase heading wins if both cases are present.
        x_col = self._first_present(('x', 'X'), available)
        y_col = self._first_present(('y', 'Y'), available)
        if x_col is None or y_col is None:
            self.fail('File must have X and Y headings to import. Aborting.')
        z_col = self._first_present(('z', 'Z'), available)
        self._location_cols = [c for c in (x_col, y_col, z_col) if c is not None]

        if len(self._df.index) == 0:
            # Without this check the first-row inspection below would raise an IndexError.
            self.fail('File contains no point data. Aborting.')

        xs = self._df[x_col].tolist()
        ys = self._df[y_col].tolist()
        zs = self._df[z_col].tolist() if z_col is not None else [0.0] * len(xs)

        if not self._values_are_float([xs[0], ys[0], zs[0]]):
            # The first row does not describe a point, so it holds the dataset times.
            self.logger.info('Detected time(s) associated with datasets.')
            xs, ys, zs = xs[1:], ys[1:], zs[1:]
            self._time_exists = True

        self.logger.info('Building UGrid.')
        ug_pts = [(float(x), float(y), float(z)) for x, y, z in zip(xs, ys, zs)]
        ug = UGrid(ug_pts, [])
        co_builder = UGridBuilder()
        co_builder.set_ugrid(ug)
        self._cogrid = co_builder.build_grid()
        self._cogrid.uuid = str(uuid.uuid4())
        self._args[ARG_UGRID].value = self._ug_name
        self.logger.info('UGrid successfully created.')

    def _values_are_float(self, vals):
        """Determine if every value can be used as a float.

        Args:
            vals (list(unknown)): values from a csv file that should be floats

        Returns:
            (bool): True if no value is missing (NaN/None) and every value converts to float
        """
        for val in vals:
            if pd.isna(val):
                return False
            try:
                float(val)
            except (TypeError, ValueError):
                return False
        return True

    def _dataset_column_groups(self):
        """Group the non-coordinate columns of the data frame by dataset.

        Returns:
            (list(tuple)): One ``(dataset name, [column names])`` entry per dataset. The column list starts with
            the dataset column itself followed by every column whose name starts with ``<dataset name>.``.
        """
        cols = [c for c in self._df.columns if c not in self._location_cols]
        ds_cols = []
        found_cols = set()
        for col in cols:
            if 'Unnamed' in col or col in found_cols:
                continue
            # Match on the prefix only: a substring match would let dataset 'h' pick up a column like 'depth.1'.
            prefix = f'{col}.'
            step_cols = [c for c in cols if c.startswith(prefix)]
            found_cols.update(step_cols)
            ds_cols.append((col, [col] + step_cols))
        return ds_cols

    def _create_data_sets(self):
        """Creates datasets from columns in a data frame."""
        self.logger.info('Creating datasets.')
        ds_cols = self._dataset_column_groups()
        self.logger.info(f'Found the following dataset columns in the file: {ds_cols}')

        time_units = self._args[ARG_INPUT_Time_Unit].value
        null_value = self._args[ARG_INPUT_Null_Value].value
        # When a time row exists it is the first row of every column and must not be read as data.
        start = 1 if self._time_exists else 0
        for ds_name, ds_col_names in ds_cols:
            self._ref_time = None
            self._time_offsets = None
            self.logger.info(f'Reading data for dataset {ds_name}')
            df1 = self._df[ds_col_names]
            if self._time_exists:
                raw_times = df1.iloc[[0]].values.flatten().tolist()
                if self._values_are_time(raw_times):
                    # Always use the parsed offsets so that date-times are never passed on as raw strings.
                    times = self._time_offsets
                elif len(raw_times) > 1:
                    self.logger.warning(f'Dataset "{ds_name}" skipped. Invalid times specified: '
                                        f'{raw_times}.')
                    continue
                else:
                    # A single column without a usable time is treated as one time step at time 0.
                    times = [0.0]
                self.logger.info(f'Dataset: {ds_name} has the following times: {times}')
                steps = sorted(zip(times, ds_col_names))
            else:
                steps = [(0.0, ds_col_names[0])]
            dsw = self.get_output_dataset_writer(
                name=ds_name,
                geom_uuid=self._cogrid.uuid,
                num_components=1,
                ref_time=self._ref_time,
                # Offsets computed from date-times are differences of Julian dates, i.e. days.
                time_units='Days' if self._ref_time is not None else time_units,
                null_value=null_value,
            )
            for step_time, step_col in steps:
                vals = [float(v) for v in df1[step_col].tolist()[start:]]
                dsw.append_timestep(time=step_time, data=vals)
            dsw.appending_finished()
            self._data_sets.append(dsw)

    def _values_are_time(self, vals):
        """Determine if the values are times and store them as time step values.

        All values must either be numeric (used directly as time step values) or all be date-times (converted to
        offsets in days from the earliest date, which is stored as a Julian date in ``self._ref_time``). On success
        the time step values are stored in ``self._time_offsets``.

        Args:
            vals (list(unknown)): values from a csv file that should be floats or date-times

        Returns:
            (bool): True if the values are times
        """
        if self._values_are_float(vals):
            self._time_offsets = [float(v) for v in vals]
            return True

        # A mix of numeric (or missing) and non-numeric values is not a valid set of times.
        for v in vals:
            if pd.isna(v):
                return False
            try:
                float(v)
            except (TypeError, ValueError):
                continue
            return False

        # convert all to julian
        try:
            stamps = pd.to_datetime(vals)
            julian = [float(j) for j in stamps.to_julian_date()]
        except (TypeError, ValueError):
            return False
        if stamps.isna().any():
            # Values that parse to "not a time" cannot be used as time steps.
            return False
        self._ref_time = min(julian)
        self._time_offsets = [j - self._ref_time for j in julian]
        return True