Callbacks

Callback

Bases: Callback, Generic[BatchType, StepOutputType]

Adds typing information and shared hook methods to the PyTorch Lightning Callback class.

Adds the following typing information:
- The type of inputs that the algorithm takes.
- The type of outputs that are returned by the algorithm's [training/validation/test]_step methods.

Adds the following methods:
- on_shared_batch_start: called by on_[train/validation/test]_batch_start
- on_shared_batch_end: called by on_[train/validation/test]_batch_end
- on_shared_epoch_start: called by on_[train/validation/test]_epoch_start
- on_shared_epoch_end: called by on_[train/validation/test]_epoch_end
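
The forwarding described above can be sketched as follows. This is an illustration only, not the class's actual source: SharedHooksSketch and CallRecorder are hypothetical names, the sketch deliberately does not subclass Lightning's Callback so that it runs without Lightning installed, and the phase-specific signatures are assumed to mirror Lightning 2.x hooks. Only the batch-end path is shown; the other three shared hooks would presumably follow the same pattern.

```python
from __future__ import annotations

from typing import Any, Generic, Literal, TypeVar

BatchType = TypeVar("BatchType")
StepOutputType = TypeVar("StepOutputType")
Phase = Literal["train", "val", "test"]


class SharedHooksSketch(Generic[BatchType, StepOutputType]):
    """Hypothetical stand-in; the documented class subclasses Lightning's Callback."""

    def on_shared_batch_end(
        self,
        trainer: Any,
        pl_module: Any,
        outputs: StepOutputType,
        batch: BatchType,
        batch_index: int,
        phase: Phase,
        dataloader_idx: int | None = None,
    ) -> None:
        """Override point: does nothing by default."""

    # Each phase-specific hook forwards to the shared hook with a phase tag.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        self.on_shared_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, phase="train"
        )

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ) -> None:
        self.on_shared_batch_end(
            trainer, pl_module, outputs, batch, batch_idx,
            phase="val", dataloader_idx=dataloader_idx,
        )

    def on_test_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ) -> None:
        self.on_shared_batch_end(
            trainer, pl_module, outputs, batch, batch_idx,
            phase="test", dataloader_idx=dataloader_idx,
        )


class CallRecorder(SharedHooksSketch[dict, dict]):
    """Overrides only the shared hook and records (phase, batch_index) pairs."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, int]] = []

    def on_shared_batch_end(
        self, trainer, pl_module, outputs, batch, batch_index, phase, dataloader_idx=None
    ) -> None:
        self.calls.append((phase, batch_index))


# Simulate what a trainer would do; None stands in for the trainer and module.
recorder = CallRecorder()
recorder.on_train_batch_end(None, None, outputs={}, batch={}, batch_idx=0)
recorder.on_validation_batch_end(None, None, outputs={}, batch={}, batch_idx=3)
```

A subclass that overrides a single shared hook receives calls from all three phases, and can tell them apart with the phase argument.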

on_shared_batch_start

on_shared_batch_start(
    trainer: Trainer,
    pl_module: LightningModule,
    batch: BatchType,
    batch_index: int,
    phase: Literal["train", "val", "test"],
    dataloader_idx: int | None = None,
)

Shared hook, called by on_[train/validation/test]_batch_start.

Use this if you want to do something at the start of batches in more than one phase.

on_shared_batch_end

on_shared_batch_end(
    trainer: Trainer,
    pl_module: LightningModule,
    outputs: StepOutputType,
    batch: BatchType,
    batch_index: int,
    phase: Literal["train", "val", "test"],
    dataloader_idx: int | None = None,
)

Shared hook, called by on_[train/validation/test]_batch_end.

Use this if you want to do something at the end of batches in more than one phase.
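
As a rough illustration of such a hook, the sketch below keeps a running mean loss per phase. It assumes that the step outputs are a mapping with a "loss" entry, which this page does not state. RunningLossSketch and mean_loss are hypothetical names, and the class is written without a base class so it can run without Lightning; in real use it would presumably subclass the Callback documented here.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any, Literal, Mapping


class RunningLossSketch:
    """Hypothetical example: tracks a running mean of the loss for each phase."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = defaultdict(float)
        self.counts: dict[str, int] = defaultdict(int)

    def on_shared_batch_end(
        self,
        trainer: Any,
        pl_module: Any,
        outputs: Mapping[str, Any],  # assumption: outputs carry a "loss" entry
        batch: Any,
        batch_index: int,
        phase: Literal["train", "val", "test"],
        dataloader_idx: int | None = None,
    ) -> None:
        # One code path serves all three phases; `phase` selects the bucket.
        self.totals[phase] += float(outputs["loss"])
        self.counts[phase] += 1

    def mean_loss(self, phase: str) -> float:
        return self.totals[phase] / self.counts[phase]


# Simulated calls; None stands in for the trainer, module, and batch.
tracker = RunningLossSketch()
tracker.on_shared_batch_end(None, None, {"loss": 1.0}, batch=None, batch_index=0, phase="train")
tracker.on_shared_batch_end(None, None, {"loss": 3.0}, batch=None, batch_index=1, phase="train")
tracker.on_shared_batch_end(None, None, {"loss": 0.5}, batch=None, batch_index=0, phase="val")
```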

on_shared_epoch_start

on_shared_epoch_start(
    trainer: Trainer,
    pl_module: LightningModule,
    phase: Literal["train", "val", "test"],
) -> None

Shared hook, called by on_[train/validation/test]_epoch_start.

Use this if you want to do something at the start of epochs in more than one phase.

on_shared_epoch_end

on_shared_epoch_end(
    trainer: Trainer,
    pl_module: LightningModule,
    phase: Literal["train", "val", "test"],
) -> None

Shared hook, called by on_[train/validation/test]_epoch_end.

Use this if you want to do something at the end of epochs in more than one phase.
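
The two epoch hooks pair naturally. As a minimal sketch, the following measures how long each epoch takes in every phase. EpochTimerSketch is a hypothetical name, and the class omits the base class so the example runs without Lightning installed; a real callback would presumably subclass the Callback documented here.

```python
from __future__ import annotations

import time
from typing import Any, Literal

Phase = Literal["train", "val", "test"]


class EpochTimerSketch:
    """Hypothetical example: records the wall-clock duration of each epoch per phase."""

    def __init__(self) -> None:
        self._started: dict[str, float] = {}
        self.durations: dict[str, list[float]] = {"train": [], "val": [], "test": []}

    def on_shared_epoch_start(self, trainer: Any, pl_module: Any, phase: Phase) -> None:
        self._started[phase] = time.perf_counter()

    def on_shared_epoch_end(self, trainer: Any, pl_module: Any, phase: Phase) -> None:
        # Ignore an end event that has no matching start.
        start = self._started.pop(phase, None)
        if start is not None:
            self.durations[phase].append(time.perf_counter() - start)


# Simulated epoch boundaries; None stands in for the trainer and module.
timer = EpochTimerSketch()
timer.on_shared_epoch_start(None, None, phase="train")
timer.on_shared_epoch_end(None, None, phase="train")
timer.on_shared_epoch_end(None, None, phase="val")  # no matching start: ignored
```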

ClassificationMetricsCallback

Bases: Callback[BatchType, ClassificationOutputs]

Callback that adds classification metrics to a LightningModule.