"""Export a UGrid and its coverage arcs (as breaklines) to a LandXML file."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
from importlib.metadata import version
import logging
from xml.etree import ElementTree

# 2. Third party modules
from rtree import index

# 3. Aquaveo modules

# 4. Local modules

# Point locations for the rtree helpers; only index 0 (x) and index 1 (y) of each location are read.
# NOTE(review): the UGrid locations actually passed in are unpacked as (x, y, z) elsewhere in this
# module, so this 2-tuple annotation is narrower than what callers supply -- confirm before relying on it.
Locations = list[tuple[float, float]]  # list of xy locations


class LandXmlExporter:
    """Algorithm to maintain breaklines and output results to LandXML.

    The UGrid points and cells become a TIN ``Surface`` definition, and every arc in the coverage is
    snapped to its nearest UGrid points to form a ``Breakline`` in the surface's ``SourceData``.
    """
    def __init__(self, ugrid, ugrid_name, coverage, output_path, horizontal_units, logger=None):
        """Initializes the class.

        Args:
            ugrid (UGrid): unstructured grid/mesh; every cell must be 2D
            ugrid_name (str): name of ugrid, used for the LandXML Surface and DataPoints names
            coverage (Coverage): coverage whose rows with ``geometry_types == 'Arc'`` become breaklines;
                may be None, in which case no breaklines are written
            output_path (str): location of output file
            horizontal_units (str): horizontal units of projection used. 'FEET (U.S. SURVEY)',
                'FEET (INTERNATIONAL)' and 'METERS' are written to the Units element; any other value
                omits the Units element
            logger (logging.Logger): logger; defaults to the 'xms.tool' logger

        Raises:
            ValueError: if any UGrid cell is not 2D.
        """
        if not ugrid.check_all_cells_2d():
            raise ValueError('All UGrid cells must be 2D.')
        self._input_ugrid = ugrid
        self._ugrid = ugrid.ugrid
        self._ugrid_locations = ugrid.ugrid.locations
        self._ugrid_name = ugrid_name
        self._coverage = coverage
        self._units = horizontal_units  # using coverage projection for units
        self.output_path = output_path
        self._rtree = None
        self._breaklines = []

        self.logger = logger
        if self.logger is None:
            self.logger = logging.getLogger('xms.tool')

    def export(self):
        """Compute each breakline from the coverage arcs and write the LandXML file to ``output_path``."""
        # Breaklines are only computed when a coverage exists and the UGrid has cells
        self._process_breaklines()

        self.logger.info('Writing UGrid and coverage data.')
        xml_output = self._create_landxml()

        # The XML declaration produced by _create_landxml advertises UTF-8, so write UTF-8 explicitly.
        # Relying on the locale's default encoding would mis-encode (or fail on) non-ASCII names.
        with open(self.output_path, 'w', encoding='utf-8') as f:
            f.write(xml_output)

        self.logger.info(f'XML file created and saved at {self.output_path}.')

    def _create_landxml(self):
        """Create an XML document for LandXML including breakline points.

        Returns:
            str: the indented LandXML document, including a UTF-8 XML declaration.
        """
        current_date, current_time = _get_date_and_time()

        # Tags in this document are un-namespaced; the default namespace is declared through the
        # explicit xmlns attribute on the root element below.
        ElementTree.register_namespace('', 'http://www.landxml.org/schema/LandXML-1.2')
        root = ElementTree.Element('LandXML',
                                   xmlns='http://www.landxml.org/schema/LandXML-1.2',
                                   attrib={'{http://www.w3.org/2001/XMLSchema-instance}schemaLocation':
                                           'http://www.landxml.org/schema/LandXML-1.2 '
                                           'http://www.landxml.org/schema/landxml-1.2/LandXML-1.2.xsd'},
                                   version='1.2',
                                   date=current_date,
                                   time=current_time)

        # Setting up units (Inches, Centimeter, and Degrees not supported)
        if self._units in ('FEET (U.S. SURVEY)', 'FEET (INTERNATIONAL)'):
            units = ElementTree.SubElement(root, "Units")
            ElementTree.SubElement(units, 'Imperial',
                                   areaUnit='squareFoot',
                                   linearUnit='foot',
                                   volumeUnit='cubicFeet',
                                   temperatureUnit='fahrenheit',
                                   pressureUnit='inHG',
                                   diameterUnit='foot')
        elif self._units == 'METERS':
            units = ElementTree.SubElement(root, "Units")
            ElementTree.SubElement(units, 'Metric',
                                   areaUnit='squareMeter',
                                   linearUnit='meter',
                                   volumeUnit='cubicMeter',
                                   temperatureUnit='celsius',
                                   pressureUnit='milliBars',
                                   diameterUnit='meter')
        else:
            self.logger.info('The current display projection\'s unit is not supported. File will not include units.')

        xmstool_version = _get_xmstool_version()
        ElementTree.SubElement(root, 'Application',
                               name='xmstool', manufacturer='Aquaveo', version=xmstool_version)

        # Surfaces (UGrid) for landXML
        surfaces = ElementTree.SubElement(root, 'Surfaces')
        surface = ElementTree.SubElement(surfaces, 'Surface', name=self._ugrid_name, state='existing')

        # SourceData elements in order (Chain, PointFiles, Boundaries, Breaklines*, Contour, DataPoints*)
        # All coordinates are written "y x z" (LandXML points are northing easting elevation).
        sourcedata = ElementTree.SubElement(surface, 'SourceData')
        if self._breaklines:
            breaklines = ElementTree.SubElement(sourcedata, 'Breaklines')
            for i, breakline in enumerate(self._breaklines, start=1):
                breakline_header = ElementTree.SubElement(breaklines, 'Breakline', name=f'Breakline{i}')
                pnt_list_3d = ' '.join(f'{y} {x} {z}' for x, y, z in breakline)
                ElementTree.SubElement(breakline_header, 'PntList3D').text = pnt_list_3d

        datapoints = ElementTree.SubElement(sourcedata, 'DataPoints', name=self._ugrid_name)
        # NOTE(review): the leading space is preserved from the existing output format.
        pnt_list = ' ' + ' '.join(f'{y} {x} {z}' for x, y, z in self._ugrid_locations)
        ElementTree.SubElement(datapoints, 'PntList3D').text = pnt_list

        # Definition (UGrid data)
        definition = ElementTree.SubElement(surface, 'Definition', surfType='TIN')  # TIN and grid
        def_points = ElementTree.SubElement(definition, 'Pnts')
        for point_id, (x, y, z) in enumerate(self._ugrid_locations, start=1):
            ElementTree.SubElement(def_points, 'P', id=str(point_id)).text = f'{y} {x} {z}'

        faces = ElementTree.SubElement(definition, 'Faces')
        for cell_index in range(self._ugrid.cell_count):
            # Cell point indices index into the locations list, while <P> ids above start at 1
            face_point_ids = ' '.join(f'{point_index + 1}' for point_index in self._ugrid.get_cell_points(cell_index))
            ElementTree.SubElement(faces, 'F').text = face_point_ids

        # Convert the tree to a string
        ElementTree.indent(root, space='    ')
        xml_string = ElementTree.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8')

        return xml_string

    def _process_breaklines(self):
        """Snap each coverage arc to its nearest UGrid points and store the results in ``self._breaklines``.

        Does nothing when there is no coverage or the UGrid has no cells. Consecutive arc vertices that
        snap to the same UGrid point are collapsed, so a breakline never repeats a point back to back.
        """
        if self._coverage is None or self._ugrid.cell_count < 1:
            return

        self.logger.info('Computing breakline.')
        self._rtree = _build_2d_rtree(self._ugrid_locations)
        # Reset so repeated export() calls (e.g. when rewriting test baselines) do not accumulate breaklines
        self._breaklines = []
        arcs = self._coverage[self._coverage['geometry_types'] == 'Arc']
        for arc in arcs.itertuples():
            breakline_points = []
            last_idx = None
            # ``*_`` accepts both 2D (x, y) and 3D (x, y, z) vertices; only x and y are used for snapping
            for x, y, *_ in arc.geometry.coords:
                # Take the first nearest id lazily instead of materializing every tied result
                idx = next(iter(self._rtree.nearest((x, y))))
                if idx != last_idx:
                    last_idx = idx
                    breakline_points.append(self._ugrid_locations[idx])
            self._breaklines.append(breakline_points)


def _rtree_2d_insert_generator(grid_locations: Locations):
    """Yield ``(id, bounds, obj)`` entries for stream-loading a 2D rtree.

    Every location becomes a degenerate box ``(x, y, x, y)``. Its position in ``grid_locations`` is used
    as both the rtree id and the stored object. Only the first two components of each location are read.

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    for point_id, point in enumerate(grid_locations):
        x_coord, y_coord = point[0], point[1]
        bounds = (x_coord, y_coord, x_coord, y_coord)
        yield point_id, bounds, point_id


def _build_2d_rtree(grid_locations: Locations):
    """Create a 2D rtree over ``grid_locations`` using stream loading.

    Each point is stored under its position in ``grid_locations``, so query results index directly
    back into that sequence.

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    entries = _rtree_2d_insert_generator(grid_locations)
    properties = index.Property()
    properties.dimension = 2
    return index.Index(entries, properties=properties)


def _get_date_and_time() -> tuple[str, str]:
    """Return the current local date (``YYYY-MM-DD``) and time (``HH:MM:SS``) as two strings.

    Kept as a separate function so tests can mock it for reproducible output.
    """
    now = datetime.now()
    return now.strftime('%Y-%m-%d'), now.strftime('%H:%M:%S')


def _get_xmstool_version() -> str:
    """Look up the installed ``xmstool`` distribution version (mocked in tests for stable output).

    Raises:
        importlib.metadata.PackageNotFoundError: if the xmstool distribution is not installed.
    """
    package_name = 'xmstool'
    return version(package_name)
