"""TransientDataExporter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import csv
from dataclasses import dataclass
from datetime import datetime
import os
from pathlib import Path
from typing import Any

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
from xms.coverage.xy.xy_series import XySeries

# 4. Local modules
from xms.mf6.file_io import io_util
from xms.mf6.mapping import map_util
from xms.mf6.mapping import xy_series as xy
from xms.mf6.misc import log_util


@dataclass
class TimeInfo:
    """Time information used when exporting transient data.

    TransientDataExporter copies these fields in its constructor, whose docstring describes them as
    information about time from TDIS.
    """
    # Table of stress period times. TransientDataExporter reads its 'Time' column and uses each pair of
    # consecutive values as the start and end of a stress period.
    period_times: pd.DataFrame | None = None
    # Passed to io_util.read_xy_series_file() when the XY series are read.
    # Presumably the reference date/time of the simulation - TODO confirm.
    start_date_time: datetime | None = None
    # Time units string, passed to io_util.read_xy_series_file() together with start_date_time.
    time_units: str = ''


class TransientDataExporter:
    """Exports transient data to a csv file."""
    def __init__(self, coverage_att_file, map_info, time_info: TimeInfo):
        """Initializes the class.

        Args:
            coverage_att_file (str): Filepath to coverage attribute table file.
            map_info (dict): Dict describing how to get the MODFLOW variable from the shapefile or att table fields.
            time_info: Information about time from TDIS.
        """
        self._coverage_att_file = coverage_att_file
        self._map_info = map_info
        self.period_times = time_info.period_times
        self._start_date_time = time_info.start_date_time
        self._time_units = time_info.time_units

        self._xy_series: dict[int, XySeries] = {}  # Dict of xy series IDs -> XySeries
        self._mixed_arcs = []  # List of arcs where one node used a constant value and the other an XY series
        self._extrap = False  # If any temporal extrapolation occurred, this will be set to True
        self._tin_files: dict[str, str] = {}  # Tin UUID -> .xmc filepath

    def export(self):
        """Exports the transient data to a CSV file."""
        logger = log_util.get_logger()
        logger.info('Exporting transient data')

        self._xy_series = io_util.read_xy_series_file(
            self._coverage_att_file, self._start_date_time, self._time_units, date_times_to_floats=True
        )
        att_table_columns = self._get_att_file_columns_fixed(self._coverage_att_file)
        trans_data_columns = self._get_trans_data_columns(att_table_columns)
        if not trans_data_columns:  # This package doesn't have transient data so no file is created
            return ''

        # Open att file as CSV for reading
        with open(self._coverage_att_file, 'r') as att_csv_file:
            reader = csv.DictReader(att_csv_file, fieldnames=att_table_columns)
            # skip the header row
            header = next(reader)  # noqa F841 local variable 'header' is assigned to but never used

            # Open CSV file we're creating for writing
            csv_out_filename = os.path.splitext(self._coverage_att_file)[0] + '.csv'
            with open(csv_out_filename, 'w', newline='') as csv_out_file:
                writer = csv.writer(csv_out_file)
                headings = ['Period', 'ID'] + trans_data_columns
                writer.writerow(headings)

                csv_rows = {}  # period -> list of rows
                for row in reader:
                    self._get_csv_rows(row, trans_data_columns, csv_rows)
                for _, rows in csv_rows.items():
                    writer.writerows(rows)

        self._log_issues()
        return csv_out_filename

    def _log_issues(self):
        """If there were any issues, report them to the log now."""
        logger = log_util.get_logger()
        if self._mixed_arcs:
            arc_ids = ', '.join(self._mixed_arcs)
            logger.info(
                f'The following arc(s) have constant values at one end and transient data on the other.'
                f' The constant value was used for all time steps in the temporal interpolation.\n'
                f' Arc id(s): {arc_ids}'
            )

        if self._extrap:
            logger.info(
                'Some xy series do not cover the range of times specified by the MODFLOW stress periods.'
                ' Values were extrapolated from the information that is available.'
            )

    def _get_csv_rows(self, row, trans_data_columns, csv_rows):
        """Given an att table row, adds rows to csv_rows.

        Args:
            row (dict): A row from the att table returned by DictReader.
            trans_data_columns (list): Column names in the transient data file.
            csv_rows (dict): period -> list of rows
        """
        period_times = self.period_times['Time'].tolist()
        for i in range(1, len(period_times)):  # Start on 2nd time
            if i not in csv_rows:
                csv_rows[i] = []
            csv_columns = [str(i), row['ID']]  # Stress period, Feature ID
            trans_data_columns_iter = iter(trans_data_columns)
            include_row = True
            start = period_times[i - 1]
            end = period_times[i]
            for column in trans_data_columns_iter:
                if self._map_info[column] and self._map_info[column].node2:
                    # Values at the nodes. Do both together so we can warn if ss and trans
                    # Node 1
                    node1_column = column
                    value1, source1 = self._att_value(node1_column, row, start, end)

                    # Node 2
                    node2_column = self._map_info[column].node2
                    value2, source2 = self._att_value(node2_column, row, start, end)

                    if value1 is None or value2 is None:
                        # XY series null value. Don't include this period for this feature
                        include_row = False
                        break

                    if (source1 == 'xy' and source2 == '') or (source1 == '' and source2 == 'xy'):
                        self._mixed_arcs.append(row['ID'])

                    if source1 == 'tin':
                        value1 = self._get_tin_filepath(value1)
                    if source2 == 'tin':
                        value2 = self._get_tin_filepath(value2)

                    csv_columns.append(value1)
                    csv_columns.append(value2)
                    next(trans_data_columns_iter)  # Skip the next one since we already handled it
                else:
                    value, source = self._att_value(column, row, start, end)
                    if source == 'xy' and value is None:
                        # XY series null value. Don't include this period for this feature
                        include_row = False
                        break
                    elif source == 'tin':
                        value = self._get_tin_filepath(value)

                    csv_columns.append(value)
            if include_row:
                csv_rows[i].append(csv_columns)

    def _get_tin_filepath(self, tin_uuid: str) -> str:
        """Given a tin uuid, return the filepath to the .xmc file of the tin.

        Returns:
            See description.
        """
        if not self._tin_files:
            self._init_tin_files()
        return self._tin_files.get(tin_uuid, '')

    def _init_tin_files(self) -> None:
        """Initializes self._tin_files by reading the .ugrids.csv file."""
        ugrid_file_list = self._ugrid_list_file()
        if not ugrid_file_list:
            return

        # Read list of UGrids / datasets. File should be a .csv with 2 columns: tin uuid, .xmc filepath
        with Path(ugrid_file_list).open('r') as file:
            reader = csv.reader(file)
            for row in reader:
                self._tin_files[row[0]] = row[1]

    def _ugrid_list_file(self) -> str:
        """Return the name of the file containing a list of UGrids used by the coverage, or '' if it doesn't exist.

        Returns:
            See description.
        """
        ugrid_list_file = f'{self._coverage_att_file}.ugrids.csv'
        if Path(ugrid_list_file).exists:
            return ugrid_list_file
        return ''

    def _att_value(self, column_name: str, row: dict, start: float, end: float) -> tuple[Any, str]:
        """Return the value of the attribute for the time period, and the source of the value.

        Source is '', 'xy' if from an xy series, or 'tin' if value is a TIN uuid.

        Args:
            column_name: Name of the column.
            row: A row from the att table returned by DictReader.
            start: Beginning of time range.
            end: End of time range.

        Returns:
            See description.
        """
        value = _get_cell_data(column_name, row)
        curveid = _get_cell_data_xy(column_name, row)
        tin_uuid = _get_cell_data_tin(column_name, row)
        source = ''
        if curveid is not None and curveid > -1:
            value = self._value_at_time_range(curveid, start, end)
            source = 'xy'
        elif tin_uuid:
            value = tin_uuid
            source = 'tin'
        return value, source

    def _value_at_time_range(self, curve_id: int, start: float, end: float):
        """Returns the value of the attribute for the time period.

        Args:
            curve_id: XY series ID.
            start: Beginning of time range.
            end: End of time range.

        Returns:
            (tuple): tuple containing:
                - (float): The value.
                - (bool): True if value came from an XY series.
        """
        if not curve_id:
            return
        curve = self._xy_series[curve_id]
        value, extrap, reason = xy.average_value_from_time_range(curve.x, curve.y, start, end)
        if extrap and reason not in ['partially before', 'partially after']:
            self._extrap = True
            value = None
        return value

    # def _att_value_at_time_range(self, column, row, begtime, endtime):
    #     """Returns the value of the attribute for the time period.
    #
    #     Args:
    #         column (str): Name of the column.
    #         row (dict): A row from the att table returned by DictReader.
    #         begtime (float): Beginning of time range.
    #         endtime (float): End of time range.
    #
    #     Returns:
    #         (tuple): tuple containing:
    #             - (float): The value.
    #             - (bool): True if value came from an XY series.
    #     """
    #     curve_id = -1
    #     from_xy = False
    #     xy_column = f'{column}_TS'
    #     if xy_column in row:
    #         curve_id = int(row[xy_column])
    #     if curve_id >= 0:
    #         curve = self._xy_series[curve_id]
    #         value, extrap, reason = xy.average_value_from_time_range(curve['x'], curve['y'], begtime, endtime)
    #         from_xy = True
    #         if extrap and reason not in ['partially before', 'partially after']:
    #             self._extrap = True
    #             value = None
    #     elif column in row:
    #         value = row[column]
    #     else:
    #         value = None
    #     return value, from_xy

    def _get_att_file_columns_fixed(self, att_file):
        """Reads table def file and returns list of all columns.

        Args:
            att_file (str): Att table filename.

        Returns:
            (list): list of all column names.
        """
        table_def = map_util.read_table_definition_file(att_file)
        table_def = map_util.fix_field_names(table_def, map_util.standard_fields())
        column_names = []
        for column in table_def['columns']:
            column_names.append(column['name'])
        return column_names

    def _get_trans_data_columns(self, att_table_columns):
        """Returns a list of the column names that should be in the transient data file.

        Args:
            att_table_columns (list(str)): List of the columns in the attribute table (fixed).

        Returns:
            (list(str)): See description.
        """
        trans_data_columns = []
        for key in list(self._map_info.keys()):
            if key in att_table_columns:
                trans_data_columns.append(key)
        return trans_data_columns


def _get_cell_data(column: str, row: dict) -> str | None:
    """Return the data in the cell of the row that belongs to the column.

    Args:
        column: Name of the column.
        row: A row from the att table returned by DictReader.

    Returns:
        The contents of the cell, or None if the row has no such column.
    """
    return row[column] if column in row else None


def _get_cell_data_xy(column_name: str, row: dict) -> int | None:
    """Return the ID of the XY series associated with column_name, or None.

    The XY series ID will be stored in a column that starts with column_name and ends with "_TS".
    Example: "Head-Stage" -> "Head-Stage_TS"

    Args:
        column_name: Column name.
        row: A row from the att table returned by DictReader.

    Returns:
        The XY series ID as an int, or None if the row has no such column or the cell is empty.

    Raises:
        ValueError: If the cell contains text that is not an integer.
    """
    cell = row.get(f'{column_name}_TS')
    # DictReader fills in None for cells missing from a short row, and a blank cell is ''. Neither holds an ID.
    if cell is None or (isinstance(cell, str) and not cell.strip()):
        return None
    return int(cell)


def _get_cell_data_tin(column_name: str, row: dict) -> str | None:
    """Return the uuid of the TIN associated with column_name, or None.

    The TIN uuid will be stored in a column that starts with column_name and ends with "_TIN".
    Example: "Head-Stage" -> "Head-Stage_TIN"

    Args:
        column_name: Column name.
        row: A row from the att table returned by DictReader.

    Returns:
        The contents of the "_TIN" cell, or None if the row has no such column.
    """
    # The suffix is "_TIN", as described above
    return row.get(f'{column_name}_TIN')
