From 1a5900314638a05bd6c2f34d6850624b35f7ea82 Mon Sep 17 00:00:00 2001 From: Severin Magel Date: Mon, 18 May 2026 17:17:39 -0400 Subject: [PATCH 1/4] Clean up typing and docstrings in NodeDefinition Replace lowercase `torch.tensor` annotations with `torch.Tensor`, fix return-type docstrings on `NodeDefinition.forward` and `_construct_nodes` (they return a tensor, not a graph), correct the `ClusterSummaryFeatures._construct_nodes` return annotation from `Data` to `torch.Tensor`, and fix `_verify_standardization`'s return annotation (it returns None, not a tensor). Reorder imports to match the project's standard grouping. Also clarify the `ClusterSummaryFeatures` docstring about which features come from Glauch's thesis and which (counts) do not. --- .../data_representation/graphs/nodes/nodes.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/graphnet/models/data_representation/graphs/nodes/nodes.py b/src/graphnet/models/data_representation/graphs/nodes/nodes.py index 8fa4b74cb..10aeffb5b 100644 --- a/src/graphnet/models/data_representation/graphs/nodes/nodes.py +++ b/src/graphnet/models/data_representation/graphs/nodes/nodes.py @@ -1,22 +1,21 @@ """Class(es) for building/connecting graphs.""" -from typing import List, Tuple, Optional, Dict, Union from abc import abstractmethod +from copy import deepcopy +from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch from torch_geometric.data import Data -from graphnet.utilities.decorators import final from graphnet.models import Model from graphnet.models.data_representation.graphs.utils import ( cluster_and_pad, + ice_transparency, identify_indices, lex_sort, - ice_transparency, ) -from copy import deepcopy - -import numpy as np +from graphnet.utilities.decorators import final class NodeDefinition(Model): # pylint: disable=too-few-public-methods @@ -34,7 +33,7 @@ def __init__( ) @final - def forward(self, x: torch.tensor) -> torch.tensor: + def 
forward(self, x: torch.Tensor) -> torch.Tensor: """Construct nodes from raw node features. Args: @@ -43,11 +42,9 @@ def forward(self, x: torch.tensor) -> torch.tensor: node_feature_names: list of names for each column in ´x´. Returns: - graph: a graph without edges + Node feature tensor of shape ´[num_nodes, num_features]´. """ - data = self._construct_nodes(x=x) - - return data + return self._construct_nodes(x=x) @property def _output_feature_names(self) -> List[str]: @@ -109,7 +106,7 @@ def _define_output_feature_names( """ @abstractmethod - def _construct_nodes(self, x: torch.tensor) -> torch.tensor: + def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor: """Construct nodes from raw node features ´x´. Args: @@ -119,8 +116,7 @@ def _construct_nodes(self, x: torch.tensor) -> torch.tensor: order of appearance. Length `d`. Returns: - graph: graph without edges. - new_node_features: A list of node features names. + Node feature tensor of shape ´[num_nodes, num_features]´. """ @@ -495,9 +491,13 @@ class ClusterSummaryFeatures(NodeDefinition): - number of pulses per clusters feature name: `counts` - For more details on some of the features see - Theo Glauchs thesis (chapter 5.3): + For more details on most of the listed features see + Theo Glauch's thesis (chapter 5.3): https://mediatum.ub.tum.de/node?id=1584755 + + NOTE: The `counts` feature (number of pulses per cluster) is an + addition introduced in this implementation and is not part of + the feature set described in the referenced thesis. 
""" def __init__( @@ -602,7 +602,7 @@ def _define_output_feature_names( new_feature_names.append("counts") return new_feature_names - def _construct_nodes(self, x: torch.Tensor) -> Data: + def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor: """Construct nodes from raw node features ´x´.""" # Cast to Numpy x = x.numpy() @@ -720,9 +720,7 @@ def _standardize_features( f"but got {standardization}" ) - def _verify_standardization( - self, - ) -> torch.Tensor: + def _verify_standardization(self) -> None: """Verify settings of standardization of the features.""" if not isinstance(self._charge_standardization, float): if isinstance(self._charge_standardization, str): From b50bd56590600a07624d37e23842defd595a629f Mon Sep 17 00:00:00 2001 From: Severin Magel Date: Mon, 18 May 2026 17:20:51 -0400 Subject: [PATCH 2/4] Fix NodeAsDOMTimeSeries empty-input handling When the input pulse tensor was empty, `_construct_nodes` returned a `torch_geometric.data.Data` object built from `np.column_stack([x, []])` instead of a `torch.Tensor` shaped `[0, num_features]`. Downstream code expects a tensor, so this could crash or silently produce a malformed graph. Return a correctly-shaped empty tensor instead, and drop the now-unused `Data` import. 
--- .../models/data_representation/graphs/nodes/nodes.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/data_representation/graphs/nodes/nodes.py b/src/graphnet/models/data_representation/graphs/nodes/nodes.py index 10aeffb5b..0e0b216f6 100644 --- a/src/graphnet/models/data_representation/graphs/nodes/nodes.py +++ b/src/graphnet/models/data_representation/graphs/nodes/nodes.py @@ -6,7 +6,6 @@ import numpy as np import torch -from torch_geometric.data import Data from graphnet.models import Model from graphnet.models.data_representation.graphs.utils import ( @@ -265,10 +264,16 @@ def _define_output_feature_names( def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor: """Construct nodes from raw node features ´x´.""" + if x.shape[0] == 0: + n_features = len(self._keys) + ( + 1 if self._charge_index is None else 0 + ) + # `new_node_col` appended below for non-empty events + n_features += 1 + return torch.empty((0, n_features), dtype=x.dtype, device=x.device) + # Cast to Numpy x = x.numpy() - if x.shape[0] == 0: - return Data(x=torch.tensor(np.column_stack([x, []]))) # if there is no charge column add a dummy column # of zeros with the same shape as the time column if self._charge_index is None: From 5576fdd425513f8a94eead23abde656a0c0b0965 Mon Sep 17 00:00:00 2001 From: Severin Magel Date: Mon, 18 May 2026 17:23:58 -0400 Subject: [PATCH 3/4] Export ClusterSummaryFeatures from data_representation The class was already importable through `graphnet.models.data_representation.graphs` but missing from the top-level `data_representation` package's public re-exports. 
--- src/graphnet/models/data_representation/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graphnet/models/data_representation/__init__.py b/src/graphnet/models/data_representation/__init__.py index 84dd64331..52c5502e8 100644 --- a/src/graphnet/models/data_representation/__init__.py +++ b/src/graphnet/models/data_representation/__init__.py @@ -17,4 +17,5 @@ PercentileClusters, NodeAsDOMTimeSeries, IceMixNodes, + ClusterSummaryFeatures, ) From d355af172563feeea62080a6d8344d653ffe7952 Mon Sep 17 00:00:00 2001 From: Severin Magel Date: Tue, 26 May 2026 19:54:55 -0400 Subject: [PATCH 4/4] Raise ValueError for empty input in NodeDefinition.forward Reject empty pulse tensors once, in `NodeDefinition.forward`, by raising a `ValueError` before `_construct_nodes` is called. This replaces the empty-tensor early return added to `NodeAsDOMTimeSeries._construct_nodes` in patch 2/4, which is removed here. --- .../data_representation/graphs/nodes/nodes.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/graphnet/models/data_representation/graphs/nodes/nodes.py b/src/graphnet/models/data_representation/graphs/nodes/nodes.py index 0e0b216f6..dae036b9b 100644 --- a/src/graphnet/models/data_representation/graphs/nodes/nodes.py +++ b/src/graphnet/models/data_representation/graphs/nodes/nodes.py @@ -43,6 +43,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: Node feature tensor of shape ´[num_nodes, num_features]´. """ + if x.shape[0] == 0: + raise ValueError( + f"{self.__class__.__name__} received an empty input tensor " + "(no pulses). An empty data object cannot be turned into " + "nodes; ensure events are non-empty before node construction." + ) return self._construct_nodes(x=x) @property @@ -264,14 +270,6 @@ def _define_output_feature_names( def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor: """Construct nodes from raw node features ´x´.""" - if x.shape[0] == 0: - n_features = len(self._keys) + ( - 1 if self._charge_index is None else 0 - ) - # `new_node_col` appended below for non-empty events - n_features += 1 - return torch.empty((0, n_features), dtype=x.dtype, device=x.device) - # Cast to Numpy x = x.numpy() # if there is no charge column add a dummy column