diff --git a/src/graphnet/models/easy_model.py b/src/graphnet/models/easy_model.py index b43b7d63e..90d48b448 100644 --- a/src/graphnet/models/easy_model.py +++ b/src/graphnet/models/easy_model.py @@ -36,8 +36,23 @@ def __init__( scheduler_class: Optional[type] = None, scheduler_kwargs: Optional[Dict] = None, scheduler_config: Optional[Dict] = None, + also_log_train_loss_per_step: bool = False, ) -> None: - """Construct `StandardModel`.""" + """Construct `StandardModel`. + + Args: + tasks: Task(s) appended as the head(s) of the model, defining + the prediction target(s) and loss(es). + optimizer_class: Optimizer class used during training. + optimizer_kwargs: Keyword arguments passed to `optimizer_class`. + scheduler_class: Learning-rate scheduler class. If `None`, no + scheduler is used. + scheduler_kwargs: Keyword arguments passed to `scheduler_class`. + scheduler_config: Additional configuration for how the scheduler + is invoked by PyTorch Lightning (e.g. `interval`, `frequency`). + also_log_train_loss_per_step: If `True`, additionally logs the + per-batch training loss under `train_loss_step`. + """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) @@ -52,6 +67,7 @@ def __init__( self._scheduler_class = scheduler_class self._scheduler_kwargs = scheduler_kwargs or dict() self._scheduler_config = scheduler_config or dict() + self._also_log_train_loss_per_step = also_log_train_loss_per_step self.validate_tasks() @@ -243,15 +259,26 @@ def training_step( if isinstance(train_batch, Data): train_batch = [train_batch] loss = self.shared_step(train_batch, batch_idx) + batch_size = self._get_batch_size(train_batch) self.log( "train_loss", loss, - batch_size=self._get_batch_size(train_batch), + batch_size=batch_size, prog_bar=True, on_epoch=True, on_step=False, sync_dist=True, ) + if self._also_log_train_loss_per_step: + self.log( + "train_loss_step", + loss, + batch_size=batch_size, + prog_bar=False, + on_epoch=False, + on_step=True, + sync_dist=False, + ) current_lr = self.trainer.optimizers[0].param_groups[0]["lr"] self.log("lr", current_lr, prog_bar=True, on_step=True)