"""ChdFromPolysBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.mapping.cell_adder import append_period_row
from xms.mf6.mapping.cell_adder_polys_chd import CellAdderPolysChd
from xms.mf6.mapping.cell_polygon_calculator import CellPolygonCalculator
from xms.mf6.mapping.package_builder_base import add_list_package_period_data, att_type_from_ftype, cell_active


class ChdFromPolysBuilder:
    """Builds a CHD package from polygons.

    Unlike other list packages, a cell touched by several polygons gets a single CHD row built from all of the
    overlapping polygons (weighted by overlap area), rather than one row per polygon.
    """
    def __init__(self, builder, ix_recs, records, reader):
        """Initializer.

        Args:
            builder: The package builder. Its ``period`` and ``_shape`` attributes are set while building.
            ix_recs (list): The intersection info.
            records: The records from the shapefile, indexed by polygon index.
            reader: Shapefile reader, used to fetch each polygon's shape by polygon index.
        """
        self._builder = builder
        self._ix_recs = ix_recs
        self._records = records
        self._reader = reader

    def build(self):
        """Handles building the CHD list package from polygons, which is different from other list packages.

        Rows are generated and written to the package one stress period at a time. When the builder has no
        transient data, only period 1 is processed. Periods that produce no rows are not written.
        """
        att_type = att_type_from_ftype('CHD6')
        calculator = CellPolygonCalculator(self._builder, self._ix_recs, self._records, att_type)
        cell_areas, cell_polys = calculator.get_cell_polys_and_areas()

        # Do it one period at a time so as not to accumulate too much in RAM before we dump it
        trans_data = self._builder.trans_data
        period_list = list(trans_data.keys()) if trans_data else [1]
        for period in period_list:
            # Store the current period on the builder before adding cells; the cell adders receive the builder
            # and presumably read the period from it.
            self._builder.period = period
            period_dict = {}
            self._add_chd_cells_to_period(cell_areas, cell_polys, period_dict)

            # Convert period bc dicts to lists
            period_rows = []
            for cell_idx, column_values in period_dict.items():
                cellid = self._builder.grid_info.modflow_cellid_from_cell_index(cell_idx)
                append_period_row(cellid, None, column_values, period_rows)
            if period_rows:  # Only write periods that produced at least one row
                add_list_package_period_data(self._builder._package, period, period_rows)

    def _add_chd_cells_to_period(self, cell_areas, cell_polys, period_dict):
        """For this period, add the CHD values of every active cell touched by a polygon to period_dict.

        Args:
            cell_areas (dict): Dict of cell_idx -> list of overlapping areas, parallel to the lists in cell_polys.
            cell_polys (dict): Dict of cell_idx -> list of polygons overlapping the cell.
            period_dict (dict): Dict of cell_idx -> list of values (a row in the list package). Updated in place.
        """
        for cell_idx, poly_list in cell_polys.items():
            if not cell_active(self._builder._idomain, cell_idx):
                continue
            if len(poly_list) == 1:  # Only one poly touches cell. It gets full poly value.
                self._add_poly_to_cell(cell_idx, poly_list[0], period_dict, None, None)
            else:  # Multiple polys touch cell. Add 1 chd bc to cell by adding all values, weighting by area
                areas = cell_areas[cell_idx]
                total_overlap = sum(areas)
                # areas[i] is the overlap area of polygon poly_list[i] with this cell
                for poly_idx, area in zip(poly_list, areas):
                    self._add_poly_to_cell(cell_idx, poly_idx, period_dict, area, total_overlap)

    def _add_poly_to_cell(self, cell_idx, poly_idx, period_dict, area, total_overlap):
        """Add the CHD values of a single polygon to a single cell.

        Args:
            cell_idx: Index of the cell receiving the values.
            poly_idx: Index of the polygon in the shapefile records and reader.
            period_dict (dict): Dict of cell_idx -> list of values (a row in the list package). Updated in place.
            area: Overlap area of the polygon with the cell, or None when it is the only polygon touching the cell.
            total_overlap: Sum of the overlap areas of all polygons touching the cell, or None when the polygon is
                the only one touching the cell.
        """
        # The current polygon's shape is stored on the builder, which is handed to the cell adder.
        self._builder._shape = self._reader.shape(poly_idx)
        cell_adder = CellAdderPolysChd(
            self._builder, cell_idx, self._records[poly_idx].as_dict(), period_dict, area, total_overlap
        )
        cell_adder.add_bc_to_cell()
