From 470d077119ff8176a69ce965797df605ef1af275 Mon Sep 17 00:00:00 2001 From: Severin Magel Date: Sun, 3 May 2026 20:30:59 -0400 Subject: [PATCH 1/3] Add log_train_loss_on_step toggle to EasySyntax Adds an opt-in `log_train_loss_on_step` constructor argument that, when enabled, logs the per-batch training loss under `train_loss_step` in addition to the epoch-aggregated `train_loss`. Default is False so existing behavior is unchanged. --- src/graphnet/models/easy_model.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index b43b7d63e..e28093bf1 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -36,6 +36,7 @@ def __init__( scheduler_class: Optional[type] = None, scheduler_kwargs: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, + log_train_loss_on_step: bool = False, ) -> None: """Construct `StandardModel`.""" # Base class constructor @@ -52,6 +53,7 @@ def __init__( self._scheduler_class = scheduler_class self._scheduler_kwargs = scheduler_kwargs or dict() self._scheduler_config = scheduler_config or dict() + self._log_train_loss_on_step = log_train_loss_on_step self.validate_tasks() @@ -243,15 +245,26 @@ def training_step( if isinstance(train_batch, Data): train_batch = [train_batch] loss = self.shared_step(train_batch, batch_idx) + batch_size = self._get_batch_size(train_batch) self.log( "train_loss", loss, - batch_size=self._get_batch_size(train_batch), + batch_size=batch_size, prog_bar=True, on_epoch=True, on_step=False, sync_dist=True, ) + if self._log_train_loss_on_step: + self.log( + "train_loss_step", + loss, + batch_size=batch_size, + prog_bar=False, + on_epoch=False, + on_step=True, + sync_dist=True, + ) current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] self.log("lr", current_lr, prog_bar=True, on_step=True) From d4e3508f314fba2e7b5f079b453e41dec55b6f8e Mon Sep 17 00:00:00 2001 From: Christian Locatelli <97306084+christianlocatelli@users.noreply.github.com> Date: Tue, 26 May 2026 10:55:48 +0200 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Christian Locatelli <97306084+christianlocatelli@users.noreply.github.com> --- src/graphnet/models/easy_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index e28093bf1..ca967741c 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -36,7 +36,7 @@ def __init__( scheduler_class: Optional[type] = None, scheduler_kwargs: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, - log_train_loss_on_step: bool = False, + also_log_train_loss_per_step: bool = False, ) -> None: """Construct `StandardModel`.""" # Base class constructor @@ -53,7 +53,7 @@ def __init__( self._scheduler_class = scheduler_class self._scheduler_kwargs = scheduler_kwargs or dict() self._scheduler_config = scheduler_config or dict() - self._log_train_loss_on_step = log_train_loss_on_step + self._also_log_train_loss_per_step = also_log_train_loss_per_step self.validate_tasks() @@ -255,7 +255,7 @@ def training_step( on_step=False, sync_dist=True, ) - if self._log_train_loss_on_step: +if self._also_log_train_loss_on_step: self.log( "train_loss_step", loss, @@ -263,7 +263,7 @@ def training_step( prog_bar=False, on_epoch=False, on_step=True, - sync_dist=True, + sync_dist=False, ) current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] From 1152aa78a8900873fc62d863baf050469b727bb2 Mon Sep 17 00:00:00 2001 From: Severin Magel Date: Tue, 26 May 2026 21:56:23 -0400 Subject: [PATCH 3/3] Fix botched code-review suggestion apply The previous commit ("Apply suggestions from code review") was created via GitHub's batch-suggestion apply, which mangled the indentation and left a name mismatch, so the module no longer imported: - under-indented `also_log_train_loss_per_step` parameter and attribute - top-level `if self._also_log_train_loss_on_step:` referencing an attribute that is never set (`_on_step` vs `_per_step`) Re-apply the reviewer's intent cleanly: rename to `also_log_train_loss_per_step`, log the per-step metric with `sync_dist=False`, and document all `__init__` arguments. Co-Authored-By: Claude Opus 4.7 --- src/graphnet/models/easy_model.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index ca967741c..90d48b448 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -36,9 +36,23 @@ def __init__( scheduler_class: Optional[type] = None, scheduler_kwargs: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, - also_log_train_loss_per_step: bool = False, + also_log_train_loss_per_step: bool = False, ) -> None: - """Construct `StandardModel`.""" + """Construct `StandardModel`. + + Args: + tasks: Task(s) appended as the head(s) of the model, defining + the prediction target(s) and loss(es). + optimizer_class: Optimizer class used during training. + optimizer_kwargs: Keyword arguments passed to `optimizer_class`. + scheduler_class: Learning-rate scheduler class. If `None`, no + scheduler is used. + scheduler_kwargs: Keyword arguments passed to `scheduler_class`. + scheduler_config: Additional configuration for how the scheduler + is invoked by PyTorch Lightning (e.g. `interval`, `frequency`). + also_log_train_loss_per_step: If `True`, additionally logs the + per-batch training loss under `train_loss_step`. + """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -53,7 +67,7 @@ def __init__( self._scheduler_class = scheduler_class self._scheduler_kwargs = scheduler_kwargs or dict() self._scheduler_config = scheduler_config or dict() - self._also_log_train_loss_per_step = also_log_train_loss_per_step + self._also_log_train_loss_per_step = also_log_train_loss_per_step self.validate_tasks() @@ -255,7 +269,7 @@ def training_step( on_step=False, sync_dist=True, ) -if self._also_log_train_loss_on_step: + if self._also_log_train_loss_per_step: self.log( "train_loss_step", loss,