Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/graphnet/models/data_representation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
PercentileClusters,
NodeAsDOMTimeSeries,
IceMixNodes,
ClusterSummaryFeatures,
)
47 changes: 24 additions & 23 deletions src/graphnet/models/data_representation/graphs/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""Class(es) for building/connecting graphs."""

from typing import List, Tuple, Optional, Dict, Union
from abc import abstractmethod
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch_geometric.data import Data

from graphnet.utilities.decorators import final
from graphnet.models import Model
from graphnet.models.data_representation.graphs.utils import (
cluster_and_pad,
ice_transparency,
identify_indices,
lex_sort,
ice_transparency,
)
from copy import deepcopy

import numpy as np
from graphnet.utilities.decorators import final


class NodeDefinition(Model): # pylint: disable=too-few-public-methods
Expand All @@ -34,7 +32,7 @@ def __init__(
)

@final
def forward(self, x: torch.tensor) -> torch.tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Construct nodes from raw node features.

Args:
Expand All @@ -43,11 +41,15 @@ def forward(self, x: torch.tensor) -> torch.tensor:
node_feature_names: list of names for each column in ´x´.

Returns:
graph: a graph without edges
Node feature tensor of shape ´[num_nodes, num_features]´.
"""
data = self._construct_nodes(x=x)

return data
if x.shape[0] == 0:
raise ValueError(
f"{self.__class__.__name__} received an empty input tensor "
"(no pulses). An empty data object cannot be turned into "
"nodes; ensure events are non-empty before node construction."
)
return self._construct_nodes(x=x)

@property
def _output_feature_names(self) -> List[str]:
Expand Down Expand Up @@ -109,7 +111,7 @@ def _define_output_feature_names(
"""

@abstractmethod
def _construct_nodes(self, x: torch.tensor) -> torch.tensor:
def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
"""Construct nodes from raw node features ´x´.

Args:
Expand All @@ -119,8 +121,7 @@ def _construct_nodes(self, x: torch.tensor) -> torch.tensor:
order of appearance. Length `d`.

Returns:
graph: graph without edges.
new_node_features: A list of node features names.
Node feature tensor of shape ´[num_nodes, num_features]´.
"""


Expand Down Expand Up @@ -271,8 +272,6 @@ def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
"""Construct nodes from raw node features ´x´."""
# Cast to Numpy
x = x.numpy()
if x.shape[0] == 0:
return Data(x=torch.tensor(np.column_stack([x, []])))
# if there is no charge column add a dummy column
# of zeros with the same shape as the time column
if self._charge_index is None:
Expand Down Expand Up @@ -495,9 +494,13 @@ class ClusterSummaryFeatures(NodeDefinition):
- number of pulses per clusters
feature name: `counts`

For more details on some of the features see
Theo Glauchs thesis (chapter 5.3):
For more details on most of the listed features see
Theo Glauch's thesis (chapter 5.3):
https://mediatum.ub.tum.de/node?id=1584755

NOTE: The `counts` feature (number of pulses per cluster) is an
addition introduced in this implementation and is not part of
the feature set described in the referenced thesis.
"""

def __init__(
Expand Down Expand Up @@ -602,7 +605,7 @@ def _define_output_feature_names(
new_feature_names.append("counts")
return new_feature_names

def _construct_nodes(self, x: torch.Tensor) -> Data:
def _construct_nodes(self, x: torch.Tensor) -> torch.Tensor:
"""Construct nodes from raw node features ´x´."""
# Cast to Numpy
x = x.numpy()
Expand Down Expand Up @@ -720,9 +723,7 @@ def _standardize_features(
f"but got {standardization}"
)

def _verify_standardization(
self,
) -> torch.Tensor:
def _verify_standardization(self) -> None:
"""Verify settings of standardization of the features."""
if not isinstance(self._charge_standardization, float):
if isinstance(self._charge_standardization, str):
Expand Down
Loading