Callbacks
Callback
Bases: Callback, Generic[BatchType, StepOutputType]
Adds type parameters and shared hook methods to the PyTorch Lightning Callback class.
Adds the following typing information:
- BatchType: the type of the batches (inputs) that the algorithm takes
- StepOutputType: the type of the outputs returned by the algorithm's [training/validation/test]_step methods
Adds the following methods:
- on_shared_batch_start: called by on_[train/validation/test]_batch_start
- on_shared_batch_end: called by on_[train/validation/test]_batch_end
- on_shared_epoch_start: called by on_[train/validation/test]_epoch_start
- on_shared_epoch_end: called by on_[train/validation/test]_epoch_end
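The dispatch described above can be sketched without Lightning installed. This is a minimal, self-contained approximation, not the library's implementation: the class name SharedHooksCallback and the subclass PhaseRecorder are hypothetical, trainer and pl_module are typed as Any instead of Trainer and LightningModule, and the phase-specific signatures are assumed to follow Lightning 2.x (batch_idx everywhere, dataloader_idx only for validation and test). Only the batch hooks are shown, because the epoch hooks' shared signatures are not documented on this page.

```python
from __future__ import annotations

from typing import Any, Generic, Literal, TypeVar

BatchType = TypeVar("BatchType")
StepOutputType = TypeVar("StepOutputType")
Phase = Literal["train", "val", "test"]


class SharedHooksCallback(Generic[BatchType, StepOutputType]):
    """Hypothetical stand-in for the Callback class documented here."""

    # Shared hooks: no-ops by default, meant to be overridden by subclasses.
    def on_shared_batch_start(self, trainer: Any, pl_module: Any, batch: BatchType,
                              batch_index: int, phase: Phase,
                              dataloader_idx: int | None = None) -> None:
        pass

    def on_shared_batch_end(self, trainer: Any, pl_module: Any, outputs: StepOutputType,
                            batch: BatchType, batch_index: int, phase: Phase,
                            dataloader_idx: int | None = None) -> None:
        pass

    # Phase-specific hooks (Lightning's names) forward to the shared ones with a phase tag.
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self.on_shared_batch_start(trainer, pl_module, batch, batch_idx, phase="train")

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        self.on_shared_batch_start(trainer, pl_module, batch, batch_idx, phase="val",
                                   dataloader_idx=dataloader_idx)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        self.on_shared_batch_start(trainer, pl_module, batch, batch_idx, phase="test",
                                   dataloader_idx=dataloader_idx)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.on_shared_batch_end(trainer, pl_module, outputs, batch, batch_idx, phase="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                                dataloader_idx=0):
        self.on_shared_batch_end(trainer, pl_module, outputs, batch, batch_idx, phase="val",
                                 dataloader_idx=dataloader_idx)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                          dataloader_idx=0):
        self.on_shared_batch_end(trainer, pl_module, outputs, batch, batch_idx, phase="test",
                                 dataloader_idx=dataloader_idx)


class PhaseRecorder(SharedHooksCallback[list, dict]):
    """Example subclass: a single override covers all three phases."""

    def __init__(self) -> None:
        self.seen: list[tuple[str, int]] = []

    def on_shared_batch_start(self, trainer, pl_module, batch, batch_index, phase,
                              dataloader_idx=None):
        self.seen.append((phase, batch_index))


recorder = PhaseRecorder()
recorder.on_train_batch_start(None, None, batch=[1, 2], batch_idx=0)
recorder.on_validation_batch_start(None, None, batch=[3], batch_idx=5)
print(recorder.seen)  # [('train', 0), ('val', 5)]
```

The design choice this illustrates: subclasses override one shared method and receive the phase as an argument, instead of duplicating the same body across three phase-specific hooks.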
on_shared_batch_start
on_shared_batch_start(
trainer: Trainer,
pl_module: LightningModule,
batch: BatchType,
batch_index: int,
phase: Literal["train", "val", "test"],
dataloader_idx: int | None = None,
)
Shared hook, called by on_[train/validation/test]_batch_start.
Use this if you want to do something at the start of batches in more than one phase.
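As an illustration of a start-of-batch check that should run in every phase, the sketch below validates that each batch has as many labels as inputs. It is hypothetical: BatchShapeCheck is an invented name, batches are assumed to be (inputs, labels) pairs, and it is written as a plain class with the documented on_shared_batch_start signature so it runs without Lightning; in real code it would subclass the Callback documented above.

```python
from __future__ import annotations

from typing import Any, Literal, Sequence

Phase = Literal["train", "val", "test"]


class BatchShapeCheck:
    """Hypothetical callback body: checks every batch, whatever the phase."""

    def on_shared_batch_start(self, trainer: Any, pl_module: Any,
                              batch: tuple[Sequence, Sequence], batch_index: int,
                              phase: Phase, dataloader_idx: int | None = None) -> None:
        inputs, labels = batch  # assumption: batches are (inputs, labels) pairs
        if len(inputs) != len(labels):
            # Mention the dataloader only when one is given (validation/test).
            loader = "" if dataloader_idx is None else f" (dataloader {dataloader_idx})"
            raise ValueError(
                f"{phase} batch {batch_index}{loader}: "
                f"{len(inputs)} inputs but {len(labels)} labels"
            )


check = BatchShapeCheck()
check.on_shared_batch_start(None, None, ([0.1, 0.2], [0, 1]), 0, phase="train")  # passes
try:
    check.on_shared_batch_start(None, None, ([0.1, 0.2], [0]), 3, phase="val", dataloader_idx=1)
except ValueError as err:
    print(err)  # val batch 3 (dataloader 1): 2 inputs but 1 labels
```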
on_shared_batch_end
on_shared_batch_end(
trainer: Trainer,
pl_module: LightningModule,
outputs: StepOutputType,
batch: BatchType,
batch_index: int,
phase: Literal["train", "val", "test"],
dataloader_idx: int | None = None,
)
Shared hook, called by on_[train/validation/test]_batch_end.
Use this if you want to do something at the end of batches in more than one phase.
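A typical end-of-batch use is accumulating a statistic from the step outputs separately for each phase. The sketch below is hypothetical: MeanLossTracker is an invented name, outputs are assumed to be a dict with a float "loss" entry (the real StepOutputType may differ), and it is a plain class with the documented on_shared_batch_end signature so it runs without Lightning.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any, Literal

Phase = Literal["train", "val", "test"]


class MeanLossTracker:
    """Hypothetical example: average loss per phase, accumulated at batch end."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = defaultdict(float)
        self.counts: dict[str, int] = defaultdict(int)

    def on_shared_batch_end(self, trainer: Any, pl_module: Any, outputs: dict[str, float],
                            batch: Any, batch_index: int, phase: Phase,
                            dataloader_idx: int | None = None) -> None:
        # Assumption: outputs carries the batch loss under the "loss" key.
        self.totals[phase] += float(outputs["loss"])
        self.counts[phase] += 1

    def mean(self, phase: Phase) -> float:
        return self.totals[phase] / self.counts[phase]


tracker = MeanLossTracker()
for i, loss in enumerate([1.0, 0.5, 0.0]):
    tracker.on_shared_batch_end(None, None, {"loss": loss}, batch=None, batch_index=i,
                                phase="train")
tracker.on_shared_batch_end(None, None, {"loss": 0.75}, batch=None, batch_index=0, phase="val")
print(tracker.mean("train"), tracker.mean("val"))  # 0.5 0.75
```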
ClassificationMetricsCallback
Bases: Callback[BatchType, ClassificationOutputs]
Callback that adds classification metrics to a LightningModule.
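This page does not say which metrics the callback computes or which fields ClassificationOutputs contains. Purely as an illustration of the idea, the sketch below computes top-1 accuracy per phase from the batch-end hook. Everything here is an assumption: AccuracySketch and ClassificationOutputsSketch are invented names, outputs are assumed to hold per-sample class logits under "logits" and integer labels under "y", and plain lists stand in for tensors so it runs without PyTorch or Lightning.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any, Literal, Sequence, TypedDict

Phase = Literal["train", "val", "test"]


class ClassificationOutputsSketch(TypedDict):
    # Assumed shape: one row of class logits per sample, plus integer labels.
    logits: Sequence[Sequence[float]]
    y: Sequence[int]


class AccuracySketch:
    """Hypothetical, Lightning-free illustration of a per-phase top-1 accuracy metric."""

    def __init__(self) -> None:
        self.correct: dict[str, int] = defaultdict(int)
        self.total: dict[str, int] = defaultdict(int)

    def on_shared_batch_end(self, trainer: Any, pl_module: Any,
                            outputs: ClassificationOutputsSketch, batch: Any,
                            batch_index: int, phase: Phase,
                            dataloader_idx: int | None = None) -> None:
        for row, label in zip(outputs["logits"], outputs["y"]):
            predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
            self.correct[phase] += int(predicted == label)
            self.total[phase] += 1

    def accuracy(self, phase: Phase) -> float:
        return self.correct[phase] / self.total[phase]


metrics = AccuracySketch()
outputs: ClassificationOutputsSketch = {
    "logits": [[2.0, 0.1], [0.3, 0.9], [1.5, 1.0], [0.0, 4.0]],
    "y": [0, 1, 1, 1],
}
metrics.on_shared_batch_end(None, None, outputs, batch=None, batch_index=0, phase="val")
print(metrics.accuracy("val"))  # 0.75
```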