Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/tests/images/IC86lower_deepcore_test.npy
Binary file not shown.
Binary file added data/tests/images/IC86main_array_test.npy
Binary file not shown.
Binary file added data/tests/images/IC86upper_deepcore_test.npy
Binary file not shown.
8 changes: 8 additions & 0 deletions src/graphnet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@
TEST_PARQUET_DATA = os.path.join(
TEST_DATA_DIR, "parquet", _test_dataset_name, "merged"
)
TEST_IMAGE_DIR = os.path.join(TEST_DATA_DIR, "images")
TEST_IC86MAIN_IMAGE = os.path.join(TEST_IMAGE_DIR, "IC86main_array_test.npy")
TEST_IC86LOWERDC_IMAGE = os.path.join(
TEST_IMAGE_DIR, "IC86lower_deepcore_test.npy"
)
TEST_IC86UPPERDC_IMAGE = os.path.join(
TEST_IMAGE_DIR, "IC86upper_deepcore_test.npy"
)

# Example data
EXAMPLE_DATA_DIR = os.path.join(DATA_DIR, "examples")
Expand Down
8 changes: 8 additions & 0 deletions src/graphnet/models/data_representation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@
NodeAsDOMTimeSeries,
IceMixNodes,
)
from .images import (
ExamplePrometheusGridDefinition,
ExamplePrometheusImage,
GridDefinition,
IC86GridDefinition,
IC86Image,
ImageRepresentation,
)
9 changes: 9 additions & 0 deletions src/graphnet/models/data_representation/images/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""CNN images: ``ImageRepresentation`` + ``GridDefinition``."""

from .image_representation import ImageRepresentation
from .images import IC86Image, ExamplePrometheusImage
from .mappings import (
ExamplePrometheusGridDefinition,
GridDefinition,
IC86GridDefinition,
)
145 changes: 145 additions & 0 deletions src/graphnet/models/data_representation/images/image_representation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Image CNN input pipeline: pulses → pixels → grids / tensors."""

from typing import List, Optional, Dict, Tuple, Union, Any, Callable
import torch
import numpy as np
from numpy.random import Generator

from graphnet.models.data_representation import DataRepresentation
from graphnet.models.data_representation.graphs import NodeDefinition
from torch_geometric.data import Data
from .mappings import GridDefinition


class ImageRepresentation(DataRepresentation):
"""Compose a pixel definition with a detector grid for CNN inputs.

A :class:`~graphnet.models.data_representation.graphs.nodes.NodeDefinition`
acts as **pixel definition**: pulses ``X`` are aggregated into unordered
pixel rows ``P`` (the same abstraction as graph nodes, without requiring
graph terminology for CNN users).

A :class:`GridDefinition` defines detector-bound orthonormal grid shape(s)
and lookup table(s); its :meth:`~GridDefinition.forward` scatters ``P``
into image tensor(s).

The :class:`~graphnet.models.detector.detector.Detector` is taken from
``grid_definition.detector`` so the grid matches the preprocessing geometry.
"""

def __init__(
self,
pixel_definition: NodeDefinition,
grid_definition: GridDefinition,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
add_inactive_sensors: bool = False,
sensor_mask: Optional[List[int]] = None,
string_mask: Optional[List[int]] = None,
):
"""Construct `ImageRepresentation`.

Args:
pixel_definition: Pulse-level features → one row per pixel/DOM.
grid_definition: Pixel keys + voxel indices + scatter into images.
input_feature_names: Column names in raw pulse tables. If omitted,
the detector's feature list is used.
dtype: Feature dtype (e.g. ``torch.float``).
perturbation_dict: Optional feature noise (see ``DataRepresentation``).
seed: RNG for perturbations.
add_inactive_sensors: Pad inactive sensors when True.
sensor_mask: Drop these sensor IDs.
string_mask: Drop these string IDs.

Note:
``pixel_definition`` output columns must match what
``grid_definition`` expects (including key fields in
:attr:`GridDefinition.map_pixels_by`).
"""
super().__init__(
detector=grid_definition.detector,
input_feature_names=input_feature_names,
dtype=dtype,
perturbation_dict=perturbation_dict,
seed=seed,
add_inactive_sensors=add_inactive_sensors,
sensor_mask=sensor_mask,
string_mask=string_mask,
repeat_labels=False,
)
self._pixel_definition = pixel_definition
self._grid_definition = grid_definition

@property
def shape(self) -> List[List[int]]:
"""Channel-spatial layout per image tensor (see ``GridDefinition``)."""
return self._grid_definition.shape

def single_image_spatial_shape(self) -> Tuple[int, int, int]:
"""Return spatial size as ``(height, width, depth)`` for one 3D image.

Raises:
ValueError: If ``shape`` does not describe exactly one four-axis
layout (channels plus three spatial axes).
"""
layouts = self.shape
if len(layouts) != 1:
raise ValueError(
"Expected a single-image data representation (one shape "
f"entry), got {len(layouts)}. For multi-image inputs, build "
"the backbone explicitly for each tensor."
)
layout = layouts[0]
if len(layout) != 4:
raise ValueError(
"Expected each image layout as "
"[num_channels, height, width, depth]; "
f"got {layout!r}."
)
return (layout[1], layout[2], layout[3])

def _set_output_feature_names(
self, input_feature_names: List[str]
) -> List[str]:
"""Sync pixel columns and grid output names."""
self._pixel_definition.set_output_feature_names(input_feature_names)
self._grid_definition._set_image_feature_names(
self._pixel_definition._output_feature_names
)
return self._grid_definition.image_feature_names

def forward( # type: ignore
self,
input_features: np.ndarray,
input_feature_names: List[str],
truth_dicts: Optional[List[Dict[str, Any]]] = None,
custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
loss_weight_column: Optional[str] = None,
loss_weight: Optional[float] = None,
loss_weight_default_value: Optional[float] = None,
data_path: Optional[str] = None,
) -> Data:
"""Build a ``Data`` object with image tensor(s) on ``x``."""
data = super().forward(
input_features=input_features,
input_feature_names=input_feature_names,
truth_dicts=truth_dicts,
custom_label_functions=custom_label_functions,
loss_weight_column=loss_weight_column,
loss_weight=loss_weight,
loss_weight_default_value=loss_weight_default_value,
data_path=data_path,
)
data.x = self._pixel_definition(data.x)
data.x = data.x.type(self.dtype)
data = self._grid_definition(data, self.output_feature_names)
if not isinstance(data.x, list):
data.x = [data.x]
nb_nodes = []
for i, x in enumerate(data.x):
data.x[i] = x.type(self.dtype)
nb_nodes.append(np.prod(list(data.x[i].size()[2:])))
data.num_nodes = torch.tensor(np.sum(nb_nodes))
return data
123 changes: 123 additions & 0 deletions src/graphnet/models/data_representation/images/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Concrete :class:`ImageRepresentation` subclasses for common detectors."""

from typing import List, Optional, Any
import torch

from graphnet.models.data_representation.graphs import NodeDefinition
from graphnet.models.detector import Detector, IceCube86, ORCA150

from .image_representation import ImageRepresentation
from .mappings import IC86GridDefinition, ExamplePrometheusGridDefinition


class IC86Image(ImageRepresentation):
"""IceCube-86 images (main array + optional DeepCore tensors)."""

def __init__(
self,
pixel_definition: NodeDefinition,
input_feature_names: List[str],
include_lower_dc: bool = True,
include_upper_dc: bool = True,
string_label: str = "string",
dom_number_label: str = "dom_number",
dtype: Optional[torch.dtype] = torch.float,
detector: Optional[Detector] = None,
**kwargs: Any,
) -> None:
"""Construct `IC86Image`.

Args:
pixel_definition: Pulse → DOM row features (:class:`NodeDefinition`).
input_feature_names: Raw input column names.
include_lower_dc: Include lower DeepCore grid.
include_upper_dc: Include upper DeepCore grid.
string_label: DOM string column in pixel rows.
dom_number_label: DOM index column in pixel rows.
dtype: Tensor dtype for images.
detector: ``IceCube86``; default standardizes all but coordinates.
"""
if detector is None:
detector = IceCube86(
replace_with_identity=input_feature_names,
)
else:
assert isinstance(detector, IceCube86)
pixel_definition.set_output_feature_names(input_feature_names)
assert (
string_label in input_feature_names
), f"String label '{string_label}' not in input feature names"
assert (
dom_number_label in input_feature_names
), f"DOM number label '{dom_number_label}' not in input feature names"

grid_definition = IC86GridDefinition(
detector=detector,
dtype=dtype,
string_label=string_label,
dom_number_label=dom_number_label,
pixel_feature_names=pixel_definition._output_feature_names,
include_lower_dc=include_lower_dc,
include_upper_dc=include_upper_dc,
)

super().__init__(
pixel_definition=pixel_definition,
grid_definition=grid_definition,
input_feature_names=input_feature_names,
add_inactive_sensors=False,
**kwargs,
)


class ExamplePrometheusImage(ImageRepresentation):
"""Example Prometheus-style single-image layout (tutorial scripts only)."""

def __init__(
self,
pixel_definition: NodeDefinition,
input_feature_names: List[str],
string_label: str = "sensor_string_id",
dom_number_label: str = "sensor_id",
dtype: Optional[torch.dtype] = torch.float,
detector: Optional[Detector] = None,
**kwargs: Any,
) -> None:
"""Construct `ExamplePrometheusImage`.

Args:
pixel_definition: Pulse → sensor row features (:class:`NodeDefinition`).
input_feature_names: Raw input column names.
string_label: String id column in pixel rows.
dom_number_label: Sensor id column (internal grid key name).
dtype: Tensor dtype for images.
detector: ``ORCA150`` by default.
"""
if detector is None:
detector = ORCA150(
replace_with_identity=input_feature_names,
)

pixel_definition.set_output_feature_names(input_feature_names)
assert (
string_label in input_feature_names
), f"String label '{string_label}' not in input feature names"
assert (
dom_number_label in input_feature_names
), f"DOM number label '{dom_number_label}' not in input feature names"

grid_definition = ExamplePrometheusGridDefinition(
detector=detector,
dtype=dtype,
string_label=string_label,
sensor_number_label=dom_number_label,
pixel_feature_names=pixel_definition._output_feature_names,
)

super().__init__(
pixel_definition=pixel_definition,
grid_definition=grid_definition,
input_feature_names=input_feature_names,
add_inactive_sensors=False,
**kwargs,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Detector grids: shapes, lookup tables, scatter into image tensors."""

from .grid_definition import (
ExamplePrometheusGridDefinition,
GridDefinition,
IC86GridDefinition,
)
Loading
Loading