diff --git a/data/tests/images/IC86lower_deepcore_test.npy b/data/tests/images/IC86lower_deepcore_test.npy new file mode 100644 index 000000000..178a09047 Binary files /dev/null and b/data/tests/images/IC86lower_deepcore_test.npy differ diff --git a/data/tests/images/IC86main_array_test.npy b/data/tests/images/IC86main_array_test.npy new file mode 100644 index 000000000..628cbfd71 Binary files /dev/null and b/data/tests/images/IC86main_array_test.npy differ diff --git a/data/tests/images/IC86upper_deepcore_test.npy b/data/tests/images/IC86upper_deepcore_test.npy new file mode 100644 index 000000000..24a3cc697 Binary files /dev/null and b/data/tests/images/IC86upper_deepcore_test.npy differ diff --git a/src/graphnet/constants.py b/src/graphnet/constants.py index 5e5717f78..a069fbfd8 100644 --- a/src/graphnet/constants.py +++ b/src/graphnet/constants.py @@ -21,6 +21,14 @@ TEST_PARQUET_DATA = os.path.join( TEST_DATA_DIR, "parquet", _test_dataset_name, "merged" ) +TEST_IMAGE_DIR = os.path.join(TEST_DATA_DIR, "images") +TEST_IC86MAIN_IMAGE = os.path.join(TEST_IMAGE_DIR, "IC86main_array_test.npy") +TEST_IC86LOWERDC_IMAGE = os.path.join( + TEST_IMAGE_DIR, "IC86lower_deepcore_test.npy" +) +TEST_IC86UPPERDC_IMAGE = os.path.join( + TEST_IMAGE_DIR, "IC86upper_deepcore_test.npy" +) # Example data EXAMPLE_DATA_DIR = os.path.join(DATA_DIR, "examples") diff --git a/src/graphnet/models/data_representation/__init__.py b/src/graphnet/models/data_representation/__init__.py index 84dd64331..e72c2c712 100644 --- a/src/graphnet/models/data_representation/__init__.py +++ b/src/graphnet/models/data_representation/__init__.py @@ -18,3 +18,11 @@ NodeAsDOMTimeSeries, IceMixNodes, ) +from .images import ( + ExamplePrometheusGridDefinition, + ExamplePrometheusImage, + GridDefinition, + IC86GridDefinition, + IC86Image, + ImageRepresentation, +) diff --git a/src/graphnet/models/data_representation/images/__init__.py b/src/graphnet/models/data_representation/images/__init__.py new file mode 100644 index 000000000..14e1d2bac --- /dev/null +++ b/src/graphnet/models/data_representation/images/__init__.py @@ -0,0 +1,9 @@ +"""CNN images: ``ImageRepresentation`` + ``GridDefinition``.""" + +from .image_representation import ImageRepresentation +from .images import IC86Image, ExamplePrometheusImage +from .mappings import ( + ExamplePrometheusGridDefinition, + GridDefinition, + IC86GridDefinition, +) diff --git a/src/graphnet/models/data_representation/images/image_representation.py b/src/graphnet/models/data_representation/images/image_representation.py new file mode 100644 index 000000000..38e0c64e1 --- /dev/null +++ b/src/graphnet/models/data_representation/images/image_representation.py @@ -0,0 +1,145 @@ +"""Image CNN input pipeline: pulses → pixels → grids / tensors.""" + +from typing import List, Optional, Dict, Tuple, Union, Any, Callable +import torch +import numpy as np +from numpy.random import Generator + +from graphnet.models.data_representation import DataRepresentation +from graphnet.models.data_representation.graphs import NodeDefinition +from torch_geometric.data import Data +from .mappings import GridDefinition + + +class ImageRepresentation(DataRepresentation): + """Compose a pixel definition with a detector grid for CNN inputs. + + A :class:`~graphnet.models.data_representation.graphs.nodes.NodeDefinition` + acts as **pixel definition**: pulses ``X`` are aggregated into unordered + pixel rows ``P`` (the same abstraction as graph nodes, without requiring + graph terminology for CNN users). + + A :class:`GridDefinition` defines detector-bound orthonormal grid shape(s) + and lookup table(s); its :meth:`~GridDefinition.forward` scatters ``P`` + into image tensor(s). + + The :class:`~graphnet.models.detector.detector.Detector` is taken from + ``grid_definition.detector`` so the grid matches the preprocessing geometry. + """ + + def __init__( + self, + pixel_definition: NodeDefinition, + grid_definition: GridDefinition, + input_feature_names: Optional[List[str]] = None, + dtype: Optional[torch.dtype] = torch.float, + perturbation_dict: Optional[Dict[str, float]] = None, + seed: Optional[Union[int, Generator]] = None, + add_inactive_sensors: bool = False, + sensor_mask: Optional[List[int]] = None, + string_mask: Optional[List[int]] = None, + ): + """Construct `ImageRepresentation`. + + Args: + pixel_definition: Pulse-level features → one row per pixel/DOM. + grid_definition: Pixel keys + voxel indices + scatter into images. + input_feature_names: Column names in raw pulse tables. If omitted, + the detector's feature list is used. + dtype: Feature dtype (e.g. ``torch.float``). + perturbation_dict: Optional feature noise (see ``DataRepresentation``). + seed: RNG for perturbations. + add_inactive_sensors: Pad inactive sensors when True. + sensor_mask: Drop these sensor IDs. + string_mask: Drop these string IDs. + + Note: + ``pixel_definition`` output columns must match what + ``grid_definition`` expects (including key fields in + :attr:`GridDefinition.map_pixels_by`). + """ + super().__init__( + detector=grid_definition.detector, + input_feature_names=input_feature_names, + dtype=dtype, + perturbation_dict=perturbation_dict, + seed=seed, + add_inactive_sensors=add_inactive_sensors, + sensor_mask=sensor_mask, + string_mask=string_mask, + repeat_labels=False, + ) + self._pixel_definition = pixel_definition + self._grid_definition = grid_definition + + @property + def shape(self) -> List[List[int]]: + """Channel-spatial layout per image tensor (see ``GridDefinition``).""" + return self._grid_definition.shape + + def single_image_spatial_shape(self) -> Tuple[int, int, int]: + """Return spatial size as ``(height, width, depth)`` for one 3D image. + + Raises: + ValueError: If ``shape`` does not describe exactly one four-axis + layout (channels plus three spatial axes). + """ + layouts = self.shape + if len(layouts) != 1: + raise ValueError( + "Expected a single-image data representation (one shape " + f"entry), got {len(layouts)}. For multi-image inputs, build " + "the backbone explicitly for each tensor." + ) + layout = layouts[0] + if len(layout) != 4: + raise ValueError( + "Expected each image layout as " + "[num_channels, height, width, depth]; " + f"got {layout!r}." + ) + return (layout[1], layout[2], layout[3]) + + def _set_output_feature_names( + self, input_feature_names: List[str] + ) -> List[str]: + """Sync pixel columns and grid output names.""" + self._pixel_definition.set_output_feature_names(input_feature_names) + self._grid_definition._set_image_feature_names( + self._pixel_definition._output_feature_names + ) + return self._grid_definition.image_feature_names + + def forward( # type: ignore + self, + input_features: np.ndarray, + input_feature_names: List[str], + truth_dicts: Optional[List[Dict[str, Any]]] = None, + custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None, + loss_weight_column: Optional[str] = None, + loss_weight: Optional[float] = None, + loss_weight_default_value: Optional[float] = None, + data_path: Optional[str] = None, + ) -> Data: + """Build a ``Data`` object with image tensor(s) on ``x``.""" + data = super().forward( + input_features=input_features, + input_feature_names=input_feature_names, + truth_dicts=truth_dicts, + custom_label_functions=custom_label_functions, + loss_weight_column=loss_weight_column, + loss_weight=loss_weight, + loss_weight_default_value=loss_weight_default_value, + data_path=data_path, + ) + data.x = self._pixel_definition(data.x) + data.x = data.x.type(self.dtype) + data = self._grid_definition(data, self.output_feature_names) + if not isinstance(data.x, list): + data.x = [data.x] + nb_nodes = [] + for i, x in enumerate(data.x): + data.x[i] = x.type(self.dtype) + nb_nodes.append(np.prod(list(data.x[i].size()[2:]))) + data.num_nodes = torch.tensor(np.sum(nb_nodes)) + return data diff --git a/src/graphnet/models/data_representation/images/images.py b/src/graphnet/models/data_representation/images/images.py new file mode 100644 index 000000000..e4a7491ee --- /dev/null +++ b/src/graphnet/models/data_representation/images/images.py @@ -0,0 +1,123 @@ +"""Concrete :class:`ImageRepresentation` subclasses for common detectors.""" + +from typing import List, Optional, Any +import torch + +from graphnet.models.data_representation.graphs import NodeDefinition +from graphnet.models.detector import Detector, IceCube86, ORCA150 + +from .image_representation import ImageRepresentation +from .mappings import IC86GridDefinition, ExamplePrometheusGridDefinition + + +class IC86Image(ImageRepresentation): + """IceCube-86 images (main array + optional DeepCore tensors).""" + + def __init__( + self, + pixel_definition: NodeDefinition, + input_feature_names: List[str], + include_lower_dc: bool = True, + include_upper_dc: bool = True, + string_label: str = "string", + dom_number_label: str = "dom_number", + dtype: Optional[torch.dtype] = torch.float, + detector: Optional[Detector] = None, + **kwargs: Any, + ) -> None: + """Construct `IC86Image`. + + Args: + pixel_definition: Pulse → DOM row features (:class:`NodeDefinition`). + input_feature_names: Raw input column names. + include_lower_dc: Include lower DeepCore grid. + include_upper_dc: Include upper DeepCore grid. + string_label: DOM string column in pixel rows. + dom_number_label: DOM index column in pixel rows. + dtype: Tensor dtype for images. + detector: ``IceCube86``; default standardizes all but coordinates. + """ + if detector is None: + detector = IceCube86( + replace_with_identity=input_feature_names, + ) + else: + assert isinstance(detector, IceCube86) + pixel_definition.set_output_feature_names(input_feature_names) + assert ( + string_label in input_feature_names + ), f"String label '{string_label}' not in input feature names" + assert ( + dom_number_label in input_feature_names + ), f"DOM number label '{dom_number_label}' not in input feature names" + + grid_definition = IC86GridDefinition( + detector=detector, + dtype=dtype, + string_label=string_label, + dom_number_label=dom_number_label, + pixel_feature_names=pixel_definition._output_feature_names, + include_lower_dc=include_lower_dc, + include_upper_dc=include_upper_dc, + ) + + super().__init__( + pixel_definition=pixel_definition, + grid_definition=grid_definition, + input_feature_names=input_feature_names, + add_inactive_sensors=False, + **kwargs, + ) + + +class ExamplePrometheusImage(ImageRepresentation): + """Example Prometheus-style single-image layout (tutorial scripts only).""" + + def __init__( + self, + pixel_definition: NodeDefinition, + input_feature_names: List[str], + string_label: str = "sensor_string_id", + dom_number_label: str = "sensor_id", + dtype: Optional[torch.dtype] = torch.float, + detector: Optional[Detector] = None, + **kwargs: Any, + ) -> None: + """Construct `ExamplePrometheusImage`. + + Args: + pixel_definition: Pulse → sensor row features (:class:`NodeDefinition`). + input_feature_names: Raw input column names. + string_label: String id column in pixel rows. + dom_number_label: Sensor id column (internal grid key name). + dtype: Tensor dtype for images. + detector: ``ORCA150`` by default. + """ + if detector is None: + detector = ORCA150( + replace_with_identity=input_feature_names, + ) + + pixel_definition.set_output_feature_names(input_feature_names) + assert ( + string_label in input_feature_names + ), f"String label '{string_label}' not in input feature names" + assert ( + dom_number_label in input_feature_names + ), f"DOM number label '{dom_number_label}' not in input feature names" + + grid_definition = ExamplePrometheusGridDefinition( + detector=detector, + dtype=dtype, + string_label=string_label, + sensor_number_label=dom_number_label, + pixel_feature_names=pixel_definition._output_feature_names, + ) + + super().__init__( + pixel_definition=pixel_definition, + grid_definition=grid_definition, + input_feature_names=input_feature_names, + add_inactive_sensors=False, + **kwargs, + ) diff --git a/src/graphnet/models/data_representation/images/mappings/__init__.py b/src/graphnet/models/data_representation/images/mappings/__init__.py new file mode 100644 index 000000000..a62f8ccd9 --- /dev/null +++ b/src/graphnet/models/data_representation/images/mappings/__init__.py @@ -0,0 +1,7 @@ +"""Detector grids: shapes, lookup tables, scatter into image tensors.""" + +from .grid_definition import ( + ExamplePrometheusGridDefinition, + GridDefinition, + IC86GridDefinition, +) diff --git a/src/graphnet/models/data_representation/images/mappings/cnn_mapping_tables.py b/src/graphnet/models/data_representation/images/mappings/cnn_mapping_tables.py new file mode 100644 index 000000000..aa9035822 --- /dev/null +++ b/src/graphnet/models/data_representation/images/mappings/cnn_mapping_tables.py @@ -0,0 +1,667 @@ +"""CNN image-grid lookup tables for IceCube-86 and the Prometheus example. + +Tensor columns use :data:`MAT_AX0_COL`, :data:`MAT_AX1_COL`, and +:data:`MAT_AX2_COL` for voxel indices. IceCube-86 is built by +:func:`build_ic86_cnn_mapping`; the Prometheus example uses a fixed row list +from the former ``prometheus_CNN_mapping.parquet`` layout. +""" + +from __future__ import annotations + +from typing import Dict, Final, List, Tuple + +import pandas as pd + +MAT_AX0_COL = "mat_ax0" +MAT_AX1_COL = "mat_ax1" +MAT_AX2_COL = "mat_ax2" + +# Fixed (mat_ax0, mat_ax1) placement for IceCube86 main-array strings +# (1..78) on a 10x10 grid. Matches the legacy IC86 CNN mapping table. +_IC86_STRING_TO_AX01: Dict[int, Tuple[int, int]] = { + 1: (9, 4), + 2: (9, 5), + 3: (9, 6), + 4: (9, 7), + 5: (9, 8), + 6: (9, 9), + 7: (8, 3), + 8: (8, 4), + 9: (8, 5), + 10: (8, 6), + 11: (8, 7), + 12: (8, 8), + 13: (8, 9), + 14: (7, 2), + 15: (7, 3), + 16: (7, 4), + 17: (7, 5), + 18: (7, 6), + 19: (7, 7), + 20: (7, 8), + 21: (7, 9), + 22: (6, 1), + 23: (6, 2), + 24: (6, 3), + 25: (6, 4), + 26: (6, 5), + 27: (6, 6), + 28: (6, 7), + 29: (6, 8), + 30: (6, 9), + 31: (5, 0), + 32: (5, 1), + 33: (5, 2), + 34: (5, 3), + 35: (5, 4), + 36: (5, 5), + 37: (5, 6), + 38: (5, 7), + 39: (5, 8), + 40: (5, 9), + 41: (4, 0), + 42: (4, 1), + 43: (4, 2), + 44: (4, 3), + 45: (4, 4), + 46: (4, 5), + 47: (4, 6), + 48: (4, 7), + 49: (4, 8), + 50: (4, 9), + 51: (3, 0), + 52: (3, 1), + 53: (3, 2), + 54: (3, 3), + 55: (3, 4), + 56: (3, 5), + 57: (3, 6), + 58: (3, 7), + 59: (3, 8), + 60: (2, 0), + 61: (2, 1), + 62: (2, 2), + 63: (2, 3), + 64: (2, 4), + 65: (2, 5), + 66: (2, 6), + 67: (2, 7), + 68: (1, 0), + 69: (1, 1), + 70: (1, 2), + 71: (1, 3), + 72: (1, 4), + 73: (1, 5), + 74: (1, 6), + 75: (0, 0), + 76: (0, 1), + 77: (0, 2), + 78: (0, 3), +} + +# Sentinel for DeepCore rows in ``mat_ax2`` (two axes only in sub-images). +_DC_AX2_SENTINEL = -500 + + +def build_ic86_cnn_mapping( + string_label: str, + dom_number_label: str, +) -> pd.DataFrame: + """Build the IceCube-86 CNN pixel mapping DataFrame. + + Places the 78 main-array strings on a fixed 10x10 grid and folds the + eight DeepCore strings (79..86) into upper/lower DeepCore views using + the usual DOM numbering convention. + + Args: + string_label: Column name for IceCube string number (1..86). + dom_number_label: Column name for DOM number (1..60). + + Returns: + DataFrame indexed by (``string_label``, ``dom_number_label``) with + voxel index columns ``mat_ax0``, ``mat_ax1``, and ``mat_ax2``. + """ + rows: List[Tuple[int, int, int, int, float]] = [] + + for string in range(1, 79): + ax0, ax1 = _IC86_STRING_TO_AX01[string] + for dom_number in range(1, 61): + rows.append((string, dom_number, ax0, ax1, float(dom_number - 1))) + + for string in range(79, 87): + ax0 = string - 79 + for dom_number in range(1, 11): + rows.append( + (string, dom_number, ax0, dom_number - 1, _DC_AX2_SENTINEL) + ) + for dom_number in range(11, 61): + rows.append( + (string, dom_number, ax0, dom_number - 11, _DC_AX2_SENTINEL) + ) + + df = pd.DataFrame( + rows, + columns=[ + string_label, + dom_number_label, + MAT_AX0_COL, + MAT_AX1_COL, + MAT_AX2_COL, + ], + ) + df.sort_values( + by=[string_label, dom_number_label], + ascending=[True, True], + inplace=True, + ) + df.set_index([string_label, dom_number_label], inplace=True, drop=False) + return df + + +_Row = Tuple[int, int, int, int, int] + +PROMETHEUS_CNN_MAPPING_ROWS: Final[tuple[_Row, ...]] = ( + (0, 0, 3, 4, 0), + (0, 1, 3, 4, 1), + (0, 2, 3, 4, 2), + (0, 3, 3, 4, 3), + (0, 4, 3, 4, 4), + (0, 5, 3, 4, 5), + (0, 6, 3, 4, 6), + (0, 7, 3, 4, 7), + (0, 8, 3, 4, 8), + (0, 9, 3, 4, 9), + (0, 10, 3, 4, 10), + (0, 11, 3, 4, 11), + (0, 12, 3, 4, 12), + (0, 13, 3, 4, 13), + (0, 14, 3, 4, 14), + (0, 15, 3, 4, 15), + (0, 16, 3, 4, 16), + (0, 17, 3, 4, 17), + (0, 18, 3, 4, 18), + (0, 19, 3, 4, 19), + (0, 20, 3, 4, 20), + (0, 21, 3, 4, 21), + (1, 22, 4, 3, 0), + (1, 23, 4, 3, 1), + (1, 24, 4, 3, 2), + (1, 25, 4, 3, 3), + (1, 26, 4, 3, 4), + (1, 27, 4, 3, 5), + (1, 28, 4, 3, 6), + (1, 29, 4, 3, 7), + (1, 30, 4, 3, 8), + (1, 31, 4, 3, 9), + (1, 32, 4, 3, 10), + (1, 33, 4, 3, 11), + (1, 34, 4, 3, 12), + (1, 35, 4, 3, 13), + (1, 36, 4, 3, 14), + (1, 37, 4, 3, 15), + (1, 38, 4, 3, 16), + (1, 39, 4, 3, 17), + (1, 40, 4, 3, 18), + (1, 41, 4, 3, 19), + (1, 42, 4, 3, 20), + (1, 43, 4, 3, 21), + (2, 44, 3, 5, 0), + (2, 45, 3, 5, 1), + (2, 46, 3, 5, 2), + (2, 47, 3, 5, 3), + (2, 48, 3, 5, 4), + (2, 49, 3, 5, 5), + (2, 50, 3, 5, 6), + (2, 51, 3, 5, 7), + (2, 52, 3, 5, 8), + (2, 53, 3, 5, 9), + (2, 54, 3, 5, 10), + (2, 55, 3, 5, 11), + (2, 56, 3, 5, 12), + (2, 57, 3, 5, 13), + (2, 58, 3, 5, 14), + (2, 59, 3, 5, 15), + (2, 60, 3, 5, 16), + (2, 61, 3, 5, 17), + (2, 62, 3, 5, 18), + (2, 63, 3, 5, 19), + (2, 64, 3, 5, 20), + (2, 65, 3, 5, 21), + (3, 66, 3, 3, 0), + (3, 67, 3, 3, 1), + (3, 68, 3, 3, 2), + (3, 69, 3, 3, 3), + (3, 70, 3, 3, 4), + (3, 71, 3, 3, 5), + (3, 72, 3, 3, 6), + (3, 73, 3, 3, 7), + (3, 74, 3, 3, 8), + (3, 75, 3, 3, 9), + (3, 76, 3, 3, 10), + (3, 77, 3, 3, 11), + (3, 78, 3, 3, 12), + (3, 79, 3, 3, 13), + (3, 80, 3, 3, 14), + (3, 81, 3, 3, 15), + (3, 82, 3, 3, 16), + (3, 83, 3, 3, 17), + (3, 84, 3, 3, 18), + (3, 85, 3, 3, 19), + (3, 86, 3, 3, 20), + (3, 87, 3, 3, 21), + (4, 88, 4, 4, 0), + (4, 89, 4, 4, 1), + (4, 90, 4, 4, 2), + (4, 91, 4, 4, 3), + (4, 92, 4, 4, 4), + (4, 93, 4, 4, 5), + (4, 94, 4, 4, 6), + (4, 95, 4, 4, 7), + (4, 96, 4, 4, 8), + (4, 97, 4, 4, 9), + (4, 98, 4, 4, 10), + (4, 99, 4, 4, 11), + (4, 100, 4, 4, 12), + (4, 101, 4, 4, 13), + (4, 102, 4, 4, 14), + (4, 103, 4, 4, 15), + (4, 104, 4, 4, 16), + (4, 105, 4, 4, 17), + (4, 106, 4, 4, 18), + (4, 107, 4, 4, 19), + (4, 108, 4, 4, 20), + (4, 109, 4, 4, 21), + (5, 110, 2, 4, 0), + (5, 111, 2, 4, 1), + (5, 112, 2, 4, 2), + (5, 113, 2, 4, 3), + (5, 114, 2, 4, 4), + (5, 115, 2, 4, 5), + (5, 116, 2, 4, 6), + (5, 117, 2, 4, 7), + (5, 118, 2, 4, 8), + (5, 119, 2, 4, 9), + (5, 120, 2, 4, 10), + (5, 121, 2, 4, 11), + (5, 122, 2, 4, 12), + (5, 123, 2, 4, 13), + (5, 124, 2, 4, 14), + (5, 125, 2, 4, 15), + (5, 126, 2, 4, 16), + (5, 127, 2, 4, 17), + (5, 128, 2, 4, 18), + (5, 129, 2, 4, 19), + (5, 130, 2, 4, 20), + (5, 131, 2, 4, 21), + (6, 132, 4, 2, 0), + (6, 133, 4, 2, 1), + (6, 134, 4, 2, 2), + (6, 135, 4, 2, 3), + (6, 136, 4, 2, 4), + (6, 137, 4, 2, 5), + (6, 138, 4, 2, 6), + (6, 139, 4, 2, 7), + (6, 140, 4, 2, 8), + (6, 141, 4, 2, 9), + (6, 142, 4, 2, 10), + (6, 143, 4, 2, 11), + (6, 144, 4, 2, 12), + (6, 145, 4, 2, 13), + (6, 146, 4, 2, 14), + (6, 147, 4, 2, 15), + (6, 148, 4, 2, 16), + (6, 149, 4, 2, 17), + (6, 150, 4, 2, 18), + (6, 151, 4, 2, 19), + (6, 152, 4, 2, 20), + (6, 153, 4, 2, 21), + (7, 154, 4, 5, 0), + (7, 155, 4, 5, 1), + (7, 156, 4, 5, 2), + (7, 157, 4, 5, 3), + (7, 158, 4, 5, 4), + (7, 159, 4, 5, 5), + (7, 160, 4, 5, 6), + (7, 161, 4, 5, 7), + (7, 162, 4, 5, 8), + (7, 163, 4, 5, 9), + (7, 164, 4, 5, 10), + (7, 165, 4, 5, 11), + (7, 166, 4, 5, 12), + (7, 167, 4, 5, 13), + (7, 168, 4, 5, 14), + (7, 169, 4, 5, 15), + (7, 170, 4, 5, 16), + (7, 171, 4, 5, 17), + (7, 172, 4, 5, 18), + (7, 173, 4, 5, 19), + (7, 174, 4, 5, 20), + (7, 175, 4, 5, 21), + (8, 176, 2, 3, 0), + (8, 177, 2, 3, 1), + (8, 178, 2, 3, 2), + (8, 179, 2, 3, 3), + (8, 180, 2, 3, 4), + (8, 181, 2, 3, 5), + (8, 182, 2, 3, 6), + (8, 183, 2, 3, 7), + (8, 184, 2, 3, 8), + (8, 185, 2, 3, 9), + (8, 186, 2, 3, 10), + (8, 187, 2, 3, 11), + (8, 188, 2, 3, 12), + (8, 189, 2, 3, 13), + (8, 190, 2, 3, 14), + (8, 191, 2, 3, 15), + (8, 192, 2, 3, 16), + (8, 195, 2, 3, 19), + (8, 196, 2, 3, 20), + (8, 197, 2, 3, 21), + (9, 198, 5, 3, 0), + (9, 199, 5, 3, 1), + (9, 200, 5, 3, 2), + (9, 202, 5, 3, 4), + (9, 203, 5, 3, 5), + (9, 204, 5, 3, 6), + (9, 205, 5, 3, 7), + (9, 206, 5, 3, 8), + (9, 207, 5, 3, 9), + (9, 208, 5, 3, 10), + (9, 209, 5, 3, 11), + (9, 210, 5, 3, 12), + (9, 211, 5, 3, 13), + (9, 212, 5, 3, 14), + (9, 213, 5, 3, 15), + (9, 214, 5, 3, 16), + (9, 217, 5, 3, 19), + (9, 218, 5, 3, 20), + (9, 219, 5, 3, 21), + (10, 220, 2, 5, 0), + (10, 222, 2, 5, 2), + (10, 224, 2, 5, 4), + (10, 225, 2, 5, 5), + (10, 228, 2, 5, 8), + (10, 229, 2, 5, 9), + (10, 230, 2, 5, 10), + (10, 231, 2, 5, 11), + (10, 233, 2, 5, 13), + (10, 235, 2, 5, 15), + (10, 236, 2, 5, 16), + (10, 239, 2, 5, 19), + (10, 240, 2, 5, 20), + (10, 241, 2, 5, 21), + (11, 242, 3, 2, 0), + (11, 244, 3, 2, 2), + (11, 245, 3, 2, 3), + (11, 246, 3, 2, 4), + (11, 247, 3, 2, 5), + (11, 250, 3, 2, 8), + (11, 252, 3, 2, 10), + (11, 253, 3, 2, 11), + (11, 255, 3, 2, 13), + (11, 256, 3, 2, 14), + (11, 257, 3, 2, 15), + (11, 258, 3, 2, 16), + (11, 261, 3, 2, 19), + (11, 263, 3, 2, 21), + (12, 264, 5, 5, 0), + (12, 266, 5, 5, 2), + (12, 268, 5, 5, 4), + (12, 269, 5, 5, 5), + (12, 271, 5, 5, 7), + (12, 272, 5, 5, 8), + (12, 274, 5, 5, 10), + (12, 275, 5, 5, 11), + (12, 277, 5, 5, 13), + (12, 279, 5, 5, 15), + (12, 280, 5, 5, 16), + (12, 283, 5, 5, 19), + (12, 285, 5, 5, 21), + (13, 286, 1, 4, 0), + (13, 288, 1, 4, 2), + (13, 289, 1, 4, 3), + (13, 290, 1, 4, 4), + (13, 291, 1, 4, 5), + (13, 294, 1, 4, 8), + (13, 296, 1, 4, 10), + (13, 297, 1, 4, 11), + (13, 299, 1, 4, 13), + (13, 301, 1, 4, 15), + (13, 302, 1, 4, 16), + (13, 305, 1, 4, 19), + (14, 308, 5, 2, 0), + (14, 310, 5, 2, 2), + (14, 312, 5, 2, 4), + (14, 313, 5, 2, 5), + (14, 316, 5, 2, 8), + (14, 319, 5, 2, 11), + (14, 321, 5, 2, 13), + (14, 322, 5, 2, 14), + (14, 323, 5, 2, 15), + (14, 324, 5, 2, 16), + (14, 327, 5, 2, 19), + (14, 329, 5, 2, 21), + (15, 330, 3, 6, 0), + (15, 332, 3, 6, 2), + (15, 333, 3, 6, 3), + (15, 334, 3, 6, 4), + (15, 335, 3, 6, 5), + (15, 338, 3, 6, 8), + (15, 341, 3, 6, 11), + (15, 343, 3, 6, 13), + (15, 345, 3, 6, 15), + (15, 346, 3, 6, 16), + (15, 349, 3, 6, 19), + (16, 352, 2, 2, 0), + (16, 354, 2, 2, 2), + (16, 356, 2, 2, 4), + (16, 357, 2, 2, 5), + (16, 360, 2, 2, 8), + (16, 362, 2, 2, 10), + (16, 363, 2, 2, 11), + (16, 365, 2, 2, 13), + (16, 366, 2, 2, 14), + (16, 367, 2, 2, 15), + (16, 368, 2, 2, 16), + (16, 369, 2, 2, 17), + (16, 371, 2, 2, 19), + (17, 374, 5, 4, 0), + (17, 376, 5, 4, 2), + (17, 378, 5, 4, 4), + (17, 379, 5, 4, 5), + (17, 382, 5, 4, 8), + (17, 385, 5, 4, 11), + (17, 387, 5, 4, 13), + (17, 388, 5, 4, 14), + (17, 389, 5, 4, 15), + (17, 390, 5, 4, 16), + (17, 393, 5, 4, 19), + (18, 396, 1, 6, 0), + (18, 398, 1, 6, 2), + (18, 400, 1, 6, 4), + (18, 401, 1, 6, 5), + (18, 404, 1, 6, 8), + (18, 406, 1, 6, 10), + (18, 407, 1, 6, 11), + (18, 409, 1, 6, 13), + (18, 411, 1, 6, 15), + (18, 412, 1, 6, 16), + (18, 413, 1, 6, 17), + (18, 415, 1, 6, 19), + (19, 418, 4, 1, 0), + (19, 419, 4, 1, 1), + (19, 420, 4, 1, 2), + (19, 421, 4, 1, 3), + (19, 422, 4, 1, 4), + (19, 423, 4, 1, 5), + (19, 426, 4, 1, 8), + (19, 428, 4, 1, 10), + (19, 429, 4, 1, 11), + (19, 433, 4, 1, 15), + (19, 434, 4, 1, 16), + (19, 437, 4, 1, 19), + (20, 440, 4, 6, 0), + (20, 442, 4, 6, 2), + (20, 444, 4, 6, 4), + (20, 445, 4, 6, 5), + (20, 448, 4, 6, 8), + (20, 450, 4, 6, 10), + (20, 451, 4, 6, 11), + (20, 453, 4, 6, 13), + (20, 454, 4, 6, 14), + (20, 455, 4, 6, 15), + (20, 456, 4, 6, 16), + (20, 459, 4, 6, 19), + (21, 462, 1, 3, 0), + (21, 463, 1, 3, 1), + (21, 464, 1, 3, 2), + (21, 465, 1, 3, 3), + (21, 466, 1, 3, 4), + (21, 467, 1, 3, 5), + (21, 470, 1, 3, 8), + (21, 473, 1, 3, 11), + (21, 477, 1, 3, 15), + (21, 478, 1, 3, 16), + (21, 481, 1, 3, 19), + (22, 484, 6, 2, 0), + (22, 486, 6, 2, 2), + (22, 488, 6, 2, 4), + (22, 489, 6, 2, 5), + (22, 492, 6, 2, 8), + (22, 495, 6, 2, 11), + (22, 496, 6, 2, 12), + (22, 497, 6, 2, 13), + (22, 498, 6, 2, 14), + (22, 499, 6, 2, 15), + (22, 500, 6, 2, 16), + (22, 503, 6, 2, 19), + (23, 506, 2, 6, 0), + (23, 510, 2, 6, 4), + (23, 511, 2, 6, 5), + (23, 514, 2, 6, 8), + (23, 517, 2, 6, 11), + (23, 520, 2, 6, 14), + (23, 521, 2, 6, 15), + (23, 522, 2, 6, 16), + (23, 525, 2, 6, 19), + (24, 528, 3, 1, 0), + (24, 531, 3, 1, 3), + (24, 532, 3, 1, 4), + (24, 536, 3, 1, 8), + (24, 539, 3, 1, 11), + (24, 543, 3, 1, 15), + (24, 547, 3, 1, 19), + (25, 550, 6, 5, 0), + (25, 554, 6, 5, 4), + (25, 558, 6, 5, 8), + (25, 561, 6, 5, 11), + (25, 565, 6, 5, 15), + (25, 569, 6, 5, 19), + (26, 572, 1, 5, 0), + (26, 576, 1, 5, 4), + (26, 580, 1, 5, 8), + (26, 583, 1, 5, 11), + (26, 587, 1, 5, 15), + (26, 591, 1, 5, 19), + (27, 594, 5, 1, 0), + (27, 598, 5, 1, 4), + (27, 602, 5, 1, 8), + (27, 605, 5, 1, 11), + (27, 609, 5, 1, 15), + (27, 613, 5, 1, 19), + (28, 616, 3, 7, 0), + (28, 620, 3, 7, 4), + (28, 624, 3, 7, 8), + (28, 627, 3, 7, 11), + (28, 631, 3, 7, 15), + (28, 635, 3, 7, 19), + (29, 638, 1, 2, 0), + (29, 642, 1, 2, 4), + (29, 646, 1, 2, 8), + (29, 649, 1, 2, 11), + (29, 653, 1, 2, 15), + (29, 657, 1, 2, 19), + (30, 660, 6, 3, 0), + (30, 664, 6, 3, 4), + (30, 668, 6, 3, 8), + (30, 671, 6, 3, 11), + (30, 675, 6, 3, 15), + (30, 679, 6, 3, 19), + (31, 682, 1, 7, 0), + (31, 686, 1, 7, 4), + (31, 690, 1, 7, 8), + (31, 693, 1, 7, 11), + (31, 697, 1, 7, 15), + (31, 701, 1, 7, 19), + (32, 704, 4, 0, 0), + (32, 708, 4, 0, 4), + (32, 712, 4, 0, 8), + (32, 715, 4, 0, 11), + (32, 719, 4, 0, 15), + (32, 723, 4, 0, 19), + (33, 726, 5, 6, 0), + (33, 730, 5, 6, 4), + (33, 734, 5, 6, 8), + (33, 737, 5, 6, 11), + (33, 741, 5, 6, 15), + (33, 745, 5, 6, 19), + (34, 748, 0, 4, 0), + (34, 752, 0, 4, 4), + (34, 756, 0, 4, 8), + (34, 759, 0, 4, 11), + (34, 763, 0, 4, 15), + (34, 767, 0, 4, 19), + (35, 770, 6, 1, 0), + (35, 774, 6, 1, 4), + (35, 778, 6, 1, 8), + (35, 781, 6, 1, 11), + (35, 785, 6, 1, 15), + (35, 789, 6, 1, 19), + (36, 792, 2, 7, 0), + (36, 796, 2, 7, 4), + (36, 800, 2, 7, 8), + (36, 803, 2, 7, 11), + (36, 807, 2, 7, 15), + (36, 811, 2, 7, 19), + (37, 814, 2, 1, 0), + (37, 818, 2, 1, 4), + (37, 822, 2, 1, 8), + (37, 825, 2, 1, 11), + (37, 829, 2, 1, 15), + (37, 833, 2, 1, 19), + (38, 836, 6, 4, 0), + (38, 840, 6, 4, 4), + (38, 844, 6, 4, 8), + (38, 847, 6, 4, 11), + (38, 851, 6, 4, 15), + (38, 855, 6, 4, 19), + (39, 862, 0, 6, 4), + (39, 866, 0, 6, 8), + (39, 873, 0, 6, 15), + (39, 877, 0, 6, 19), + (40, 884, 5, 0, 4), + (40, 888, 5, 0, 8), + (40, 895, 5, 0, 15), + (40, 899, 5, 0, 19), + (41, 906, 4, 7, 4), + (41, 910, 4, 7, 8), + (41, 917, 4, 7, 15), + (41, 921, 4, 7, 19), + (42, 928, 0, 3, 4), + (42, 932, 0, 3, 8), + (42, 939, 0, 3, 15), + (42, 943, 0, 3, 19), + (43, 950, 7, 2, 4), + (43, 961, 7, 2, 15), + (44, 972, 1, 8, 4), + (44, 983, 1, 8, 15), + (45, 994, 3, 0, 4), + (45, 1005, 3, 0, 15), + (46, 1016, 6, 6, 4), + (46, 1027, 6, 6, 15), + (47, 1038, 0, 5, 4), + (47, 1049, 0, 5, 15), + (48, 1060, 6, 0, 4), + (48, 1071, 6, 0, 15), + (49, 1082, 3, 8, 4), +) diff --git a/src/graphnet/models/data_representation/images/mappings/grid_definition.py b/src/graphnet/models/data_representation/images/mappings/grid_definition.py new file mode 100644 index 000000000..046b1d9b4 --- /dev/null +++ b/src/graphnet/models/data_representation/images/mappings/grid_definition.py @@ -0,0 +1,465 @@ +"""Detector-specific grid layouts: lookup tables, shapes, and scatter into tensors. + +Each :class:`GridDefinition` is bound to a :class:`~graphnet.models.detector.Detector`. +:class:`~graphnet.models.data_representation.images.image_representation.ImageRepresentation` +calls :meth:`GridDefinition.forward` to place pixel rows into image tensor(s). +""" + +from abc import abstractmethod +from typing import List + +import numpy as np +import pandas as pd +import torch +from torch_geometric.data import Data + +from graphnet.models import Model +from graphnet.models.detector import Detector + +from .cnn_mapping_tables import ( + MAT_AX0_COL, + MAT_AX1_COL, + MAT_AX2_COL, + PROMETHEUS_CNN_MAPPING_ROWS, + build_ic86_cnn_mapping, +) + + +class GridDefinition(Model): + """Detector-specific orthonormal image grid(s). + + Holds tensor shapes and tables that map pixel keys to voxel indices. + """ + + def __init__( + self, + detector: Detector, + pixel_feature_names: List[str], + ) -> None: + """Construct `GridDefinition`. + + Args: + detector: Geometry this grid is defined for (CNN grids are + detector-specific). + pixel_feature_names: Column names expected on each pixel row, + including keys listed in :attr:`map_pixels_by`. + """ + super().__init__(name=__name__, class_name=self.__class__.__name__) + self._detector = detector + self._set_image_feature_names(pixel_feature_names) + + @property + def detector(self) -> Detector: + """Detector instance this grid belongs to.""" + return self._detector + + @property + @abstractmethod + def map_pixels_by(self) -> List[str]: + """Feature columns that join pixel rows to :meth:`mappings`.""" + + @abstractmethod + def mappings(self) -> List[pd.DataFrame]: + """DataFrame(s) keyed by ``map_pixels_by`` with voxel index columns.""" + + @abstractmethod + def forward(self, data: Data, data_feature_names: List[str]) -> Data: + """Scatter pixel features into shaped tensors (see subclasses).""" + + @abstractmethod + def _set_image_feature_names(self, input_feature_names: List[str]) -> None: + """Set the final image feature names.""" + raise NotImplementedError + + @property + @abstractmethod + def shape( + self, + ) -> List[List[int]]: + """Return the shape of the output images as a list of tuples. + + In the dimensions (F,D,H,W) where F is the number of features + per pixel. And D,H,W are the dimension of the image + """ + pass + + +class IC86GridDefinition(GridDefinition): + """IceCube-86 CNN grid (main array + DeepCore) layouts and lookup table.""" + + def __init__( + self, + detector: Detector, + dtype: torch.dtype, + pixel_feature_names: List[str], + string_label: str = "string", + dom_number_label: str = "dom_number", + include_main_array: bool = True, + include_lower_dc: bool = True, + include_upper_dc: bool = True, + ) -> None: + """Construct `IC86GridDefinition`. + + The mapping from (string, dom_number) to a position in the + resulting images is generated programmatically at instantiation + time from the IceCube86 detector geometry, so no auxiliary file + is required. + + Args: + detector: ``IceCube86`` instance (grid is fixed to that geometry). + dtype: data type used for node features. e.g. ´torch.float´ + string_label: Name of the feature corresponding + to the DOM string number. Values Integers between 1 - 86 + dom_number_label: Name of the feature corresponding + to the DOM number (1 - 60). Values Integers between 1 - 60 + where 1 is the dom with the highest z coordinate. + pixel_feature_names: Names of each column in expected input data + that will be built into a image. + include_main_array: If True, the main array will be included. + include_lower_dc: If True, the lower DeepCore will be included. + include_upper_dc: If True, the upper DeepCore will be included. + + Raises: + ValueError: If no array type is included. + + NOTE: Expects input data to be DOMs with aggregated features. + """ + if not np.any( + [include_main_array, include_lower_dc, include_upper_dc] + ): + raise ValueError("Include at least one array type.") + + self._dtype = dtype + self._string_label = string_label + self._dom_number_label = dom_number_label + self._pixel_feature_names = pixel_feature_names + + self._set_indices(pixel_feature_names, dom_number_label, string_label) + + self._nb_cnn_features = ( + len(pixel_feature_names) - 2 + ) # 2 for string and dom_number + + self._include_main_array = include_main_array + self._include_lower_dc = include_lower_dc + self._include_upper_dc = include_upper_dc + + self._mapping = build_ic86_cnn_mapping( + string_label=string_label, + dom_number_label=dom_number_label, + ) + super().__init__( + detector=detector, pixel_feature_names=pixel_feature_names + ) + + @property + def map_pixels_by(self) -> List[str]: + """String and DOM identifiers used for voxel lookup.""" + return [self._string_label, self._dom_number_label] + + def mappings(self) -> List[pd.DataFrame]: + """Return the single combined lookup table for all sub-images.""" + return [self._mapping] + + def _set_indices( + self, + feature_names: List[str], + dom_number_label: str, + string_label: str, + ) -> None: + """Set the indices for the features.""" + self._cnn_features_idx = [] + for feature in feature_names: + if feature == dom_number_label: + self._dom_number_idx = feature_names.index(feature) + elif feature == string_label: + self._string_idx = feature_names.index(feature) + else: + self._cnn_features_idx.append(feature_names.index(feature)) + + def forward(self, data: Data, data_feature_names: List[str]) -> Data: + """Scatter pixel rows into IceCube-86 image tensor(s).""" + # Initialize output arrays + if self._include_main_array: + main_arr = torch.zeros( + (self._nb_cnn_features, 10, 10, 60), + dtype=self._dtype, + ) + if self._include_upper_dc: + upper_dc_arr = torch.zeros( + (self._nb_cnn_features, 8, 10), + dtype=self._dtype, + ) + if self._include_lower_dc: + lower_dc_arr = torch.zeros( + (self._nb_cnn_features, 8, 50), + dtype=self._dtype, + ) + + # data.x is a tensor with shape (N, F) where N is the number of + # pixels (DOMs) and F is the number of features. Each row + # represents a single pixel. + x = data.x + + string_dom_number = x[ + :, [self._string_idx, self._dom_number_idx] + ].int() + batch_row_features = x[:, self._cnn_features_idx] + + # Look up the pixel position in each sub-image for every (string, + # dom_number) pair. Columns are referenced via the configurable + # `string_label`/`dom_number_label` and the internal axis-column + # constants so the lookup does not depend on free-form strings. + match_indices = self._mapping.loc[ + zip(*string_dom_number.t().tolist()) + ][ + [ + self._string_label, + self._dom_number_label, + MAT_AX0_COL, + MAT_AX1_COL, + MAT_AX2_COL, + ] + ].values.astype( + int + ) + + # Copy CNN features to the appropriate arrays + for i, row in enumerate(match_indices): + # Select appropriate array and indexing + if row[0] < 79: # Main Array + if self._include_main_array: + main_arr[ + :, + row[2], # mat_ax0 + row[3], # mat_ax1 + row[4], # mat_ax2 + ] = batch_row_features[i] + + elif row[1] < 11: # Upper DeepCore + if self._include_upper_dc: + upper_dc_arr[ + :, + row[2], # mat_ax0 + row[3], # mat_ax1 + ] = batch_row_features[i] + + else: # Lower DeepCore + if self._include_lower_dc: + lower_dc_arr[ + :, + row[2], # mat_ax0 + row[3], # mat_ax1 + ] = batch_row_features[i] + + # unsqueeze to add dimension for batching + # with collate_fn Batch.from_data_list + ret: List[torch.Tensor] = [] + if self._include_main_array: + ret.append(main_arr.unsqueeze(0)) + if self._include_upper_dc: + ret.append(upper_dc_arr.unsqueeze(0)) + if self._include_lower_dc: + ret.append(lower_dc_arr.unsqueeze(0)) + + # Set list of images as data.x + data.x = ret + return data + + def _set_image_feature_names(self, input_feature_names: List[str]) -> None: + """Set the final output feature names.""" + # string and dom_number are only used for mapping + # and will not be included in the output features. + self.image_feature_names = [ + infeature + for infeature in input_feature_names + if infeature not in [self._string_label, self._dom_number_label] + ] + + @property + def shape( + self, + ) -> List[List[int]]: + """Return the shape of the output images as a list of tuples.""" + ret = [] + if self._include_main_array: + ret.append([self._nb_cnn_features, 10, 10, 60]) + if self._include_upper_dc: + ret.append([self._nb_cnn_features, 1, 8, 10]) + if self._include_lower_dc: + ret.append([self._nb_cnn_features, 1, 8, 50]) + return ret + + +class ExamplePrometheusGridDefinition(GridDefinition): + """Example single-image grid for Prometheus-style layouts.""" + + def __init__( + self, + detector: Detector, + dtype: torch.dtype, + pixel_feature_names: List[str], + string_label: str = "sensor_string_id", + sensor_number_label: str = "sensor_id", + ) -> None: + """Construct grid. + + Args: + detector: Typically ``ORCA150`` in the example scripts. + dtype: data type used for node features. e.g. ´torch.float´ + string_label: Name of the feature corresponding + to the sensor string number. + sensor_number_label: Name of the feature corresponding + to the sensor number + pixel_feature_names: Names of each column in expected input data + that will be built into a image. + + Raises: + ValueError: If no array type is included. + + NOTE: Expects input data to be sensors with aggregated features. + """ + self._dtype = dtype + self._string_label = string_label + self._sensor_number_label = sensor_number_label + self._pixel_feature_names = pixel_feature_names + + self._set_indices( + pixel_feature_names, sensor_number_label, string_label + ) + + self._nb_cnn_features = ( + len(pixel_feature_names) - 2 + ) # 2 for string and sensor number + + # Hand-crafted layout for the example ORCA image; stored as an + # embedded table (was ``prometheus_CNN_mapping.parquet``). + df = pd.DataFrame( + PROMETHEUS_CNN_MAPPING_ROWS, + columns=[ + "sensor_string_id", + "sensor_id", + MAT_AX0_COL, + MAT_AX1_COL, + MAT_AX2_COL, + ], + ) + df.rename( + columns={ + "sensor_string_id": string_label, + "sensor_id": sensor_number_label, + }, + inplace=True, + ) + df.sort_values( + by=[string_label, sensor_number_label], + ascending=[True, True], + inplace=True, + ) + + df.set_index( + [string_label, sensor_number_label], + inplace=True, + drop=False, + ) + + self._mapping = df + super().__init__( + detector=detector, pixel_feature_names=pixel_feature_names + ) + + @property + def map_pixels_by(self) -> List[str]: + """String and sensor identifiers used for voxel lookup.""" + return [self._string_label, self._sensor_number_label] + + def mappings(self) -> List[pd.DataFrame]: + """Return the embedded lookup table for the example layout.""" + return [self._mapping] + + def _set_indices( + self, + feature_names: List[str], + sensor_number_label: str, + string_label: str, + ) -> None: + """Set the indices for the features.""" + self._cnn_features_idx = [] + for feature in feature_names: + if feature == sensor_number_label: + self._sensor_number_idx = feature_names.index(feature) + elif feature == string_label: + self._string_idx = feature_names.index(feature) + else: + self._cnn_features_idx.append(feature_names.index(feature)) + + def forward(self, data: Data, data_feature_names: List[str]) -> Data: + """Scatter pixel rows into the example 3D image tensor.""" + # Initialize output arrays + image_tensor = torch.zeros( + (self._nb_cnn_features, 8, 9, 22), + dtype=self._dtype, + ) + + # data.x is expected to be a tensor with shape (N, F) + # where N is the number of nodes and F is the number of features. + x = data.x + + # Direct coordinate and feature extraction + string_sensor_number = x[ + :, [self._string_idx, self._sensor_number_idx] + ].int() + batch_row_features = x[:, self._cnn_features_idx] + + # Look up the pixel position in the image for every (string, + # sensor_id) pair. Column references go through the configurable + # labels and the internal axis-column constants so this method + # does not rely on hard-coded column names from the data file. + match_indices = self._mapping.loc[ + zip(*string_sensor_number.t().tolist()) + ][ + [ + self._string_label, + self._sensor_number_label, + MAT_AX0_COL, + MAT_AX1_COL, + MAT_AX2_COL, + ] + ].values.astype( + int + ) + + # Copy CNN features to the appropriate arrays + for i, row in enumerate(match_indices): + # Select appropriate array and indexing + image_tensor[ + :, + row[2], # mat_ax0 + row[3], # mat_ax1 + row[4], # mat_ax2 + ] = batch_row_features[i] + + # unsqueeze to add dimension for batching + # with collate_fn Batch.from_data_list + ret: List[torch.Tensor] = [image_tensor.unsqueeze(0)] + + # Set list of images as data.x + data.x = ret + return data + + def _set_image_feature_names(self, input_feature_names: List[str]) -> None: + """Set the final output feature names.""" + # string and sensor_number are only used for mapping + # and will not be included in the output features. + self.image_feature_names = [ + infeature + for infeature in input_feature_names + if infeature not in [self._string_label, self._sensor_number_label] + ] + + @property + def shape( + self, + ) -> List[List[int]]: + """Return the shape of the output images as a list of tuples.""" + return [[self._nb_cnn_features, 8, 9, 22]] diff --git a/tests/models/ic86_grid_testdata.py b/tests/models/ic86_grid_testdata.py new file mode 100644 index 000000000..af204315e --- /dev/null +++ b/tests/models/ic86_grid_testdata.py @@ -0,0 +1,26 @@ +"""Synthetic IC86 pixel tables for CNN grid tests (no parquet).""" + +import numpy as np + + +def ic86_full_detector_pixel_table() -> np.ndarray: + """Return all IceCube-86 DOM rows with redundant id columns. + + Columns are ``string``, ``dom_number``, ``redundant_string``, + ``redundant_dom_number`` (``redundant_*`` copy ``string`` / ``dom_number``), + in row-major order over strings ``1..86`` and DOMs ``1..60``. + + Shape is ``(5160, 4)`` with ``float32`` values, matching the tables that + were previously loaded from packaged parquet. + """ + s = np.repeat(np.arange(1, 87, dtype=np.float32), 60) + d = np.tile(np.arange(1, 61, dtype=np.float32), 86) + return np.column_stack([s, d, s.copy(), d.copy()]) + + +IC86_TEST_PIXEL_COLUMNS = [ + "string", + "dom_number", + "redundant_string", + "redundant_dom_number", +] diff --git a/tests/models/test_grid_definition.py b/tests/models/test_grid_definition.py new file mode 100644 index 000000000..51079cc5c --- /dev/null +++ b/tests/models/test_grid_definition.py @@ -0,0 +1,185 @@ +"""Tests for IceCube-86 :class:`IC86GridDefinition`.""" + +import numpy as np +import torch +from torch_geometric.data import Data +from copy import deepcopy +from graphnet.models.data_representation.images import IC86GridDefinition +from graphnet.models.detector import IceCube86 +from graphnet.constants import ( + TEST_IC86MAIN_IMAGE, + TEST_IC86UPPERDC_IMAGE, + TEST_IC86LOWERDC_IMAGE, +) +import pytest + +from tests.models.ic86_grid_testdata import ( + IC86_TEST_PIXEL_COLUMNS, + ic86_full_detector_pixel_table, +) + + +def basic_checks_picture(picture: Data, dtype: torch.dtype) -> None: + """Basic checks for grid scatter output.""" + assert isinstance( + picture, Data + ), f"Output should be a Data object got {type(picture)}" + assert isinstance( + picture.x, list + ), f"x should be a list of tensors got {type(picture.x)}" + assert np.all( + [isinstance(picture.x[i], torch.Tensor) for i in range(len(picture.x))] + ), ( + "All tensors in x should be torch.Tensors", + f"got {[type(picture.x[i]) for i in range(len(picture.x))]}", + ) + assert np.all( + [picture.x[i].dtype == dtype for i in range(len(picture.x))] + ), ( + "All tensors in x should have the dtype specified on the grid", + f"got {[picture.x[i].dtype for i in range(len(picture.x))]}", + ) + + +def test_ic86_grid_definition() -> None: + """End-to-end scatter for IC86 main + DeepCore grids.""" + dtype = torch.float32 + pixel_feature_names = ["string", "dom_number", "data1", "data2"] + string_label = "string" + dom_number_label = "dom_number" + + dummy_data = Data( + x=torch.tensor( + [[1, 2, 5.8, 1e-4], [79, 46, 3.7, 1e-18], [84, 9, 6.87, 2e5]], + dtype=dtype, + ), + ) + + detector = IceCube86(replace_with_identity=pixel_feature_names) + grid_definition = IC86GridDefinition( + detector=detector, + dtype=dtype, + pixel_feature_names=pixel_feature_names, + string_label=string_label, + dom_number_label=dom_number_label, + include_lower_dc=True, + include_upper_dc=True, + ) + + picture = grid_definition(dummy_data, pixel_feature_names) + new_features = grid_definition.image_feature_names + n_features = len(new_features) + + basic_checks_picture(picture, dtype) + + assert ( + len(grid_definition.shape) == 3 + ), f"Expected shape to be 3 got {len(grid_definition.shape)}" + assert grid_definition.shape == [ + [n_features, 10, 10, 60], + [n_features, 1, 8, 10], + [n_features, 1, 8, 50], + ], ( + f"Expected shape to be [[{n_features},10,10,60], " + f"[{n_features},1,8,10], [{n_features},1,8,50]] got " + f"{grid_definition.shape}" + ) + assert isinstance( + new_features, list + ), f"Output should be a list of feature names got {type(new_features)}" + assert new_features == [ + "data1", + "data2", + ], f"Expected feature to be ['data1', 'data2'] names got: {new_features}" + assert len(picture.x) == 3, ( + "There should be three tensors in x ", + f"got list with length {len(picture.x)}" + "(main array, upper DeepCore, lower DeepCore)", + ) + assert picture.x[0].size() == torch.Size( + [1, 2, 10, 10, 60] + ), f"Main array should have shape (1,2,10,10,60) got {picture.x[0].size()}" + assert picture.x[1].size() == torch.Size( + [1, 2, 8, 10] + ), f"upper DeepCore should have shape (1,2,8,10) got {picture.x[1].size()}" + assert picture.x[2].size() == torch.Size( + [1, 2, 8, 50] + ), f"lower DeepCore should have shape (1,2,8,50) got {picture.x[2].size()}" + assert not torch.all( + picture.x[0] == 0 + ), "Main array should not be all zeros, got all zeros." + assert not torch.all( + picture.x[1] == 0 + ), "Upper DeepCore should not be all zeros, got all zeros." + assert not torch.all( + picture.x[2] == 0 + ), "Lower DeepCore should not be all zeros, got all zeros." + + dummy_data = Data( + x=torch.tensor( + [ + [100, 5, 5.8, 1e-4], + [54, 230, 3.7, 1e-18], + [1294, 500, 6.87, 2e5], + ], + dtype=dtype, + ), + ) + + with pytest.raises(KeyError): + grid_definition(dummy_data, pixel_feature_names) + + +def test_ic86_grid_segments() -> None: + """IC86 grid: single sub-image at a time vs reference arrays.""" + dtype = torch.float32 + string_label = "string" + dom_number_label = "dom_number" + pixel_feature_names = IC86_TEST_PIXEL_COLUMNS + + grid_tensor = Data( + x=torch.tensor(ic86_full_detector_pixel_table(), dtype=dtype) + ) + + detector = IceCube86(replace_with_identity=pixel_feature_names) + + for image, inc_main, inc_upc, inc_lowdc, label in zip( + [TEST_IC86MAIN_IMAGE, TEST_IC86UPPERDC_IMAGE, TEST_IC86LOWERDC_IMAGE], + [True, False, False], + [False, True, False], + [False, False, True], + ["main array", "upper deepcore", "lower deepcore"], + ): + tmp = deepcopy(grid_tensor) + grid_definition = IC86GridDefinition( + detector=detector, + dtype=dtype, + pixel_feature_names=pixel_feature_names, + string_label=string_label, + dom_number_label=dom_number_label, + include_main_array=inc_main, + include_lower_dc=inc_lowdc, + include_upper_dc=inc_upc, + ) + picture = grid_definition(tmp, pixel_feature_names) + tensor_image: torch.Tensor = torch.tensor( + np.load(image), dtype=dtype + ).unsqueeze(0) + + basic_checks_picture(picture, dtype) + + assert len(picture.x) == 1, ( + "There should be one tensor in x ", + f"got list with length {len(picture.x)}", + ) + assert picture.x[0].size() == tensor_image.size(), ( + f"{label} should have shape {tensor_image.size()} " + f"got {picture.x[0].size()}" + ) + assert not torch.all( + picture.x[0] == 0 + ), f"{label} should not be all zeros, got all zeros." + assert torch.equal(tensor_image, picture.x[0]), ( + f"{label} should match the expected" + " main array from IC86 DNN mapping." + ) diff --git a/tests/models/test_image_representation.py b/tests/models/test_image_representation.py new file mode 100644 index 000000000..7cda8eaab --- /dev/null +++ b/tests/models/test_image_representation.py @@ -0,0 +1,88 @@ +from graphnet.models.data_representation import IC86Image +from graphnet.models.data_representation import NodesAsPulses +from graphnet.models.detector import IceCube86 +import torch +from torch_geometric.data import Data +import numpy as np +from graphnet.constants import ( + TEST_IC86MAIN_IMAGE, + TEST_IC86UPPERDC_IMAGE, + TEST_IC86LOWERDC_IMAGE, +) + +from tests.models.ic86_grid_testdata import ( + IC86_TEST_PIXEL_COLUMNS, + ic86_full_detector_pixel_table, +) + + +def test_ic86_image_representation() -> None: + """Pipeline IC86Image: pixel definition + grid → list of tensors.""" + dtype = torch.float32 + columns = IC86_TEST_PIXEL_COLUMNS + pixel_table = ic86_full_detector_pixel_table() + + pixel_def = NodesAsPulses( + input_feature_names=columns, + ) + + detector = IceCube86(replace_with_identity=columns) + + image_representation = IC86Image( + pixel_definition=pixel_def, + input_feature_names=columns, + include_lower_dc=True, + include_upper_dc=True, + string_label="string", + dom_number_label="dom_number", + dtype=dtype, + detector=detector, + ) + + assert ( + image_representation.nb_outputs == 2 + ), "Expected 2 outputs, got {}".format(image_representation.nb_outputs) + + output_feature_names = columns.copy() + output_feature_names.remove("string") + output_feature_names.remove("dom_number") + + assert image_representation.output_feature_names == output_feature_names, ( + f"Output feature names do not match expected output: " + f"{image_representation.output_feature_names} != {output_feature_names}" + ) + + image = image_representation( + pixel_table, + input_feature_names=columns, + ) + + assert isinstance( + image, Data + ), "Expected output to be a torch_geometric.data.Data object" + assert isinstance(image.x, list), "Expected image.x to be a list" + assert np.all( + [isinstance(x, torch.Tensor) for x in image.x] + ), "Expected all elements in image.x to be torch.Tensor" + assert ( + len(image.x) == 3 + ), "Expected image.x to have 3 elements, got {}".format(len(image.x)) + assert ( + "num_nodes" in image.keys() + ), "Expected 'num_nodes' in image attributes" + + image_list = [ + TEST_IC86MAIN_IMAGE, + TEST_IC86UPPERDC_IMAGE, + TEST_IC86LOWERDC_IMAGE, + ] + for i, img in enumerate(image_list): + expected_image = torch.tensor(np.load(img), dtype=dtype).unsqueeze(0) + assert image.x[i].size() == expected_image.size(), ( + f"Image at index {i} size mismatch: " + f"expected {torch.tensor(expected_image).size()}," + f"got {image.x[i].size()}" + ) + assert torch.equal( + image.x[i], expected_image + ), f"Image at index {i} does not match expected image"