Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/graphnet/models/easy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,23 @@ def __init__(
scheduler_class: Optional[type] = None,
scheduler_kwargs: Optional[Dict] = None,
scheduler_config: Optional[Dict] = None,
also_log_train_loss_per_step: bool = False,
) -> None:
"""Construct `StandardModel`."""
"""Construct `StandardModel`.

Args:
tasks: Task(s) appended as the head(s) of the model, defining
the prediction target(s) and loss(es).
optimizer_class: Optimizer class used during training.
optimizer_kwargs: Keyword arguments passed to `optimizer_class`.
scheduler_class: Learning-rate scheduler class. If `None`, no
scheduler is used.
scheduler_kwargs: Keyword arguments passed to `scheduler_class`.
scheduler_config: Additional configuration for how the scheduler
is invoked by PyTorch Lightning (e.g. `interval`, `frequency`).
also_log_train_loss_per_step: If `True`, additionally logs the
per-batch training loss under `train_loss_step`.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)

Expand All @@ -52,6 +67,7 @@ def __init__(
self._scheduler_class = scheduler_class
self._scheduler_kwargs = scheduler_kwargs or dict()
self._scheduler_config = scheduler_config or dict()
self._also_log_train_loss_per_step = also_log_train_loss_per_step

self.validate_tasks()

Expand Down Expand Up @@ -243,15 +259,26 @@ def training_step(
if isinstance(train_batch, Data):
train_batch = [train_batch]
loss = self.shared_step(train_batch, batch_idx)
batch_size = self._get_batch_size(train_batch)
self.log(
"train_loss",
loss,
batch_size=self._get_batch_size(train_batch),
batch_size=batch_size,
prog_bar=True,
on_epoch=True,
on_step=False,
sync_dist=True,
)
if self._also_log_train_loss_per_step:
self.log(
"train_loss_step",
loss,
batch_size=batch_size,
prog_bar=False,
on_epoch=False,
on_step=True,
sync_dist=False,
)

current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
self.log("lr", current_lr, prog_bar=True, on_step=True)
Expand Down
Loading