"""DisuReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from typing_extensions import override

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.file_io.dis_reader_base import DisReaderBase


class DisuReader(DisReaderBase):
    """Reads a DISU package file."""

    # Maps the (upper-cased) DIMENSIONS keywords to the grid-info attribute that stores their value.
    _DIMENSION_KEYWORDS = {
        'NODES': 'nodes',
        'NJA': 'nja',
        'NVERT': 'nvert',
    }

    def __init__(self):
        """Initializes the reader for the 'DISU6' file type."""
        super().__init__(ftype='DISU6')

    def _read_dimensions(self, line):
        """Parses one line of the DIMENSIONS block.

        Every call marks the grid as DISU. When the line begins with NODES, NJA or NVERT
        (case-insensitive) followed by at least one more token, that token is converted with int()
        and stored on the grid info. Any other line is ignored.

        Args:
            line (str): A line from the file.
        """
        grid_info = self._data.grid_info()
        grid_info.dis_enum = DisEnum.DISU

        tokens = line.split()
        if len(tokens) < 2:
            return
        attribute = self._DIMENSION_KEYWORDS.get(tokens[0].upper())
        if attribute is not None:
            setattr(grid_info, attribute, int(tokens[1]))

    @override
    def _on_end_block(self, block_name):
        """Post-processes a block once its END line has been reached.

        - DIMENSIONS: when NVERT is not positive, VERTICES and CELL2D are added to the blocks to skip.
        - CONNECTIONDATA (only while importing): the HWVA values are passed through
          dis_builder.make_symmetric together with IAC and JA, then written back using the shape
          of the first HWVA layer.

        Args:
            block_name (str): Name of the block that just ended.
        """
        upper_name = block_name.upper()

        if upper_name == 'DIMENSIONS':
            if self._data.grid_info().nvert <= 0:
                # "If NVERT is not specified or is specified as zero, then the VERTICES and CELL2D
                # blocks below are not read."
                for skipped_block in ('VERTICES', 'CELL2D'):
                    self._blocks_to_skip.add(skipped_block)
            return

        if upper_name != 'CONNECTIONDATA' or not self.importing:
            return

        # Make sure the HWVA array is symmetric.
        connection_block = self._data.block('CONNECTIONDATA')
        iac_values = list(map(int, connection_block.array('IAC').get_values()))
        ja_values = list(map(int, connection_block.array('JA').get_values()))
        hwva_array = connection_block.array('HWVA')
        hwva_values = list(map(float, hwva_array.get_values()))
        symmetric_hwva = dis_builder.make_symmetric(iac_values, ja_values, hwva_values)
        hwva_array.set_values(symmetric_hwva, hwva_array.layer(0).shape, False)
