
Algorithms

ExampleAlgorithm

Bases: LightningModule

Example learning algorithm for image classification.

__init__

__init__(
    datamodule: ImageClassificationDataModule,
    network: _Config[Module],
    optimizer: _PartialConfig[Optimizer] = AdamConfig(
        lr=0.0003
    ),
    init_seed: int = 42,
)

Create a new instance of the algorithm.

Parameters:

  • datamodule (ImageClassificationDataModule, required): Object used to load the train/val/test data. See the Lightning documentation for LightningDataModule for more info.
  • network (_Config[Module], required): The config of the network to instantiate and train.
  • optimizer (_PartialConfig[Optimizer], default: AdamConfig(lr=0.0003)): The config for the Optimizer. Instantiating it returns a function (a functools.partial) with the hyper-parameters already bound, which creates the Optimizer when called with the network's parameters.
  • init_seed (int, default: 42): The seed used when initializing the network's weights.
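The `optimizer` parameter follows a partial-config pattern: instantiating the config yields a `functools.partial` with the hyper-parameters bound, and the algorithm supplies the parameters to optimize later. The sketch below shows only that mechanism using the standard library; `ToyAdam` is a hypothetical stand-in for `torch.optim.Adam`, and the string list is a placeholder for `network.parameters()`.

```python
import functools
from dataclasses import dataclass


@dataclass
class ToyAdam:
    """Hypothetical stand-in for torch.optim.Adam (same calling convention)."""

    params: list
    lr: float = 1e-3
    betas: tuple = (0.9, 0.999)


# Roughly what instantiating a partial optimizer config produces:
# the hyper-parameters are bound, but the parameters to optimize are not.
optimizer_factory = functools.partial(ToyAdam, lr=0.0003)

# Later (e.g. in configure_optimizers), the algorithm supplies the parameters.
network_parameters = ["weight", "bias"]  # placeholder for network.parameters()
optimizer = optimizer_factory(network_parameters)
```

Deferring the parameters is what lets the optimizer config be built before the network itself has been instantiated from its own config.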

forward

forward(input: Tensor) -> Tensor

Forward pass of the network.

configure_optimizers

configure_optimizers()

Creates the optimizers.

See lightning.pytorch.core.LightningModule.configure_optimizers for more information.

configure_callbacks

configure_callbacks() -> Sequence[Callback] | Callback

Creates callbacks to be used by default during training.

HFExample

Bases: LightningModule

Example of a Lightning module used to train a Hugging Face model.

configure_optimizers

configure_optimizers()

Prepares the optimizer and the learning-rate schedule (linear warmup and decay).
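A linear warmup-and-decay schedule is usually expressed as a multiplier on the base learning rate, computed from the step index. The function below is a minimal sketch of that shape, modeled on the multiplier used by Hugging Face's `get_linear_schedule_with_warmup`; that `HFExample` uses exactly this curve, and the step counts in the usage line, are assumptions.

```python
def linear_warmup_decay(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Multiplier applied to the base learning rate at a given step.

    Rises linearly from 0 to 1 during warmup, then falls linearly to 0
    at `num_training_steps` (and stays at 0 afterwards).
    """
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# Example: 10 warmup steps out of 100 total training steps.
multipliers = [linear_warmup_decay(s, 10, 100) for s in (0, 5, 10, 55, 100)]
```

In PyTorch, such a function (with the step counts bound) can be passed as the `lr_lambda` of `torch.optim.lr_scheduler.LambdaLR`.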

JaxExample

Bases: LightningModule

Example of a learning algorithm (a LightningModule) that uses JAX.

In this case, the network is a flax.linen.Module whose forward and backward passes are written in JAX, while the loss function is computed in PyTorch.

HParams dataclass

Hyper-parameters of the algorithm.
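Grouping hyper-parameters in a frozen dataclass keeps them typed, immutable and hashable; hashability matters in JAX, where static arguments to `jax.jit` must be hashable. The field names and defaults below are hypothetical, since the actual fields of `JaxExample.HParams` are not listed on this page.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class HParams:
    """Hypothetical hyper-parameters; field names are illustrative only."""

    lr: float = 1e-3
    batch_size: int = 64
    seed: int = 123


defaults = HParams()
# Frozen dataclasses cannot be mutated; overrides produce a new instance.
tuned = dataclasses.replace(defaults, lr=3e-4)
```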

JaxRLExample

Bases: PyTreeNode, JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics], Generic[TEnvState, TEnvParams]

Example of an RL algorithm written in JAX: PPO, based on rejax.PPO.

Differences w.r.t. rejax.PPO:

  • The state and hyper-parameters are split into separate, fully-typed structs:
    • The algorithm state is in a typed PPOState struct (vs. an untyped, dynamically-generated struct in rejax).
    • The hyper-parameters are in a typed PPOHParams struct.
    • The state variables related to collecting data from the environment are grouped in a TrajectoryCollectionState instead of being bundled together with everything else.
      • This makes it easier to call the collect_episodes function with only what it needs.
  • The seeds for the networks and the environment data collection are separated.

The logic is otherwise identical: the losses and updates are computed exactly as in rejax.PPO.
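The split of the algorithm state can be pictured as nested, immutable records, so that a data-collection function receives only the collection sub-state. The sketch below uses `typing.NamedTuple` (which JAX treats as a pytree). Every field name and the `collect_step` helper are hypothetical and far simpler than the real `PPOState` and `TrajectoryCollectionState`, and plain integers stand in for JAX PRNG keys.

```python
from typing import NamedTuple


class TrajectoryCollectionState(NamedTuple):
    """Hypothetical: only what is needed to collect environment data."""

    env_seed: int  # stand-in for the PRNG key used for data collection
    last_obs: float
    global_step: int


class PPOState(NamedTuple):
    """Hypothetical: full algorithm state, with a separate seed per concern."""

    network_seed: int  # stand-in for the PRNG key used by the networks
    params: dict
    data_collection_state: TrajectoryCollectionState


def collect_step(state: TrajectoryCollectionState) -> TrajectoryCollectionState:
    """Hypothetical collector: sees only the collection state, returns a new one."""
    return state._replace(last_obs=state.last_obs + 1.0, global_step=state.global_step + 1)


ts = PPOState(
    network_seed=0,
    params={"w": 0.5},
    data_collection_state=TrajectoryCollectionState(env_seed=1, last_obs=0.0, global_step=0),
)
# Only the sub-state is passed to the collector; the rest of `ts` is untouched.
ts = ts._replace(data_collection_state=collect_step(ts.data_collection_state))
```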

training_step

training_step(
    batch_idx: int,
    ts: PPOState[TEnvState],
    batch: TrajectoryWithLastObs,
)

Training step in pure JAX.

train

train(
    rng: Array,
    train_state: PPOState[TEnvState] | None = None,
    skip_initial_evaluation: bool = False,
) -> tuple[PPOState[TEnvState], EvalMetrics]

Full training loop in JAX.

This method exists only to match the API of rejax.PPO.train, of which it is an unfolded version. It is not called when using the JaxTrainer, since JaxTrainer.fit does the same thing while also supporting some JaxCallbacks (as well as some lightning.Callbacks).

fused_training_step

fused_training_step(
    iteration: int, ts: PPOState[TEnvState]
)

Fused training step in JAX (data collection and training joined into a single step).

Much faster than going through PyTorch Lightning, but callbacks and similar Lightning features are not available.
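Fusing data collection and the update into one pure function of `(iteration, state)` means the whole loop can be compiled; in JAX, a function with that signature can serve as the body of `jax.lax.fori_loop` inside a jitted function. The plain-Python sketch below only illustrates the control flow, with no compilation: `State`, `collect` and `update` are hypothetical placeholders, not this project's code.

```python
from functools import reduce
from typing import NamedTuple


class State(NamedTuple):
    """Hypothetical, heavily simplified training state."""

    params: float
    num_transitions: int


def collect(iteration: int, state: State) -> list[float]:
    """Hypothetical data collection: returns a batch of fake rewards."""
    return [float(iteration)] * 4


def update(state: State, batch: list[float]) -> State:
    """Hypothetical update: nudges params towards the mean reward."""
    mean_reward = sum(batch) / len(batch)
    return State(
        params=state.params + 0.1 * (mean_reward - state.params),
        num_transitions=state.num_transitions + len(batch),
    )


def fused_training_step(iteration: int, state: State) -> State:
    """Collect data and train on it in one pure function (no callbacks)."""
    batch = collect(iteration, state)
    return update(state, batch)


def train(state: State, num_iterations: int) -> State:
    # Same (i, state) -> state shape as the body of jax.lax.fori_loop.
    return reduce(lambda s, i: fused_training_step(i, s), range(num_iterations), state)


final_state = train(State(params=0.0, num_transitions=0), num_iterations=3)
```

Because nothing outside the pure step runs between iterations, there is no hook point for callbacks, which is the trade-off noted above.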

NoOp

Bases: LightningModule

No-op algorithm that does no learning; it is used to benchmark data-loading speed.
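Because a no-op algorithm does no work on the batches, the time spent in its loop is dominated by the data loader. The sketch below measures that cost for any iterable of batches; `measure_throughput` is a hypothetical helper, not part of this project, and `range(...)` stands in for a real `DataLoader`.

```python
import time
from typing import Iterable, Optional, Tuple


def measure_throughput(batches: Iterable, max_batches: Optional[int] = None) -> Tuple[int, float]:
    """Consume batches without doing any work on them.

    Returns the number of batches seen and the elapsed wall-clock seconds.
    """
    num_batches = 0
    start = time.perf_counter()
    for _ in batches:
        num_batches += 1  # no forward/backward pass: only loading cost is timed
        if max_batches is not None and num_batches >= max_batches:
            break
    elapsed = time.perf_counter() - start
    return num_batches, elapsed


count, seconds = measure_throughput(range(1000), max_batches=100)
```

Batches per second is then `count / seconds`, assuming enough batches were consumed for the elapsed time to be measurable.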