Algorithms
ExampleAlgorithm
Bases: LightningModule
Example learning algorithm for image classification.
__init__
    __init__(
        datamodule: ImageClassificationDataModule,
        network: _Config[Module],
        optimizer: _PartialConfig[Optimizer] = AdamConfig(lr=0.0003),
        init_seed: int = 42,
    )
Create a new instance of the algorithm.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
datamodule | ImageClassificationDataModule | Object used to load train/val/test data. See the Lightning docs for LightningDataModule for more info. | required |
network | _Config[Module] | The config of the network to instantiate and train. | required |
optimizer | _PartialConfig[Optimizer] | The config for the Optimizer. Instantiating this returns a function (a functools.partial) that creates the Optimizer given the parameters to optimize. | AdamConfig(lr=0.0003) |
init_seed | int | The seed to use when initializing the weights of the network. | 42 |
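A minimal sketch of the functools.partial pattern described for the optimizer parameter. It assumes the instantiated AdamConfig(lr=0.0003) behaves like functools.partial(Adam, lr=0.0003); FakeAdam is a hypothetical stand-in for torch.optim.Adam (same call shape: parameters first, then hyper-parameters) so the example runs without PyTorch.

```python
import functools
from dataclasses import dataclass


@dataclass
class FakeAdam:
    """Hypothetical stand-in for torch.optim.Adam: params first, then hyper-parameters."""
    params: list
    lr: float = 0.001
    betas: tuple = (0.9, 0.999)


# Assumed result of instantiating the optimizer config: the hyper-parameters
# are bound up front, and only the parameters to optimize are left to supply.
optimizer_partial = functools.partial(FakeAdam, lr=0.0003)

# Later (e.g. in configure_optimizers), once the network exists, the partial
# is called with the network's parameters to create the actual optimizer.
network_parameters = [[0.1, 0.2], [0.3]]  # placeholder for network.parameters()
optimizer = optimizer_partial(network_parameters)

print(optimizer.lr)      # 0.0003
print(optimizer.params)  # [[0.1, 0.2], [0.3]]
```

Deferring construction this way lets the config fix the hyper-parameters before the network (and therefore its parameters) exists.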
configure_optimizers
Creates the optimizers. See lightning.pytorch.core.LightningModule.configure_optimizers for more information.
HFExample
Bases: LightningModule
Example of a Lightning module used to train a Hugging Face model.
configure_optimizers
Prepares the optimizer and the learning-rate schedule (linear warmup and decay).
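As a sketch of what a linear warmup and decay schedule computes, the multiplier below ramps the learning rate from 0 up to its base value over num_warmup_steps, then decays it linearly back to 0 at num_training_steps. It mirrors the formula of transformers.get_linear_schedule_with_warmup; whether HFExample uses exactly that helper is an assumption, and the step counts are illustrative.

```python
def linear_warmup_decay(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Multiplicative factor applied to the base learning rate at a given step."""
    if step < num_warmup_steps:
        # Warmup: grow linearly from 0 to 1.
        return step / max(1, num_warmup_steps)
    # Decay: shrink linearly from 1 (end of warmup) to 0 (last training step), never below 0.
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


base_lr = 2e-5  # illustrative value
for step in [0, 50, 100, 550, 1000]:
    factor = linear_warmup_decay(step, num_warmup_steps=100, num_training_steps=1000)
    print(step, factor, base_lr * factor)
```

With 100 warmup steps out of 1000, the factor is 0.5 at step 50, peaks at 1.0 at step 100, and is back to 0.5 at step 550. In PyTorch, a function like this is typically wrapped in torch.optim.lr_scheduler.LambdaLR.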
JaxExample
Bases: LightningModule
Example of a learning algorithm (LightningModule) that uses Jax.
In this case, the network is a flax.linen.Module whose forward and backward passes are written in Jax, while the loss function is written in PyTorch.
HParams (dataclass)
Hyper-parameters of the algorithm.
JaxRLExample
Bases: PyTreeNode, JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics], Generic[TEnvState, TEnvParams]
Example of an RL algorithm written in Jax: PPO, based on rejax.PPO.
Differences w.r.t. rejax.PPO:
- The state / hparams are split into different, fully-typed structs:
  - The algorithm state is in a typed PPOState struct (vs. an untyped, dynamically-generated struct in rejax).
  - The hyper-parameters are in a typed PPOHParams struct.
  - The state variables related to the collection of data from the environment are in a TrajectoryCollectionState, instead of everything being bunched up together.
    - This makes it easier to call the collect_episodes function with just what it needs.
- The seeds for the networks and the environment data collection are separated.
The logic is otherwise exactly the same: the losses and updates are computed in the same way.
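The struct split described for JaxRLExample can be sketched with frozen dataclasses. The names PPOHParams, TrajectoryCollectionState, PPOState and collect_episodes come from this page, but every field and the body of collect_episodes are hypothetical; the real structs are presumably Jax pytrees holding arrays and PRNG keys, and here plain integer seeds with random.Random stand in for Jax PRNG keys so the sketch runs with the standard library only.

```python
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class PPOHParams:
    # Hypothetical subset of PPO hyper-parameters.
    num_envs: int = 4
    num_steps: int = 8
    learning_rate: float = 3e-4
    clip_eps: float = 0.2


@dataclass(frozen=True)
class TrajectoryCollectionState:
    # Hypothetical fields: only what is needed to collect data from the environment.
    env_seed: int
    last_obs: tuple
    global_step: int = 0


@dataclass(frozen=True)
class PPOState:
    # Hypothetical fields: the network seed is kept apart from the collection state.
    network_seed: int
    data_collection_state: TrajectoryCollectionState


def collect_episodes(state: TrajectoryCollectionState, hp: PPOHParams) -> TrajectoryCollectionState:
    """Hypothetical collector: receives only the collection state, not the full PPOState."""
    rng = random.Random(state.env_seed)
    last_obs = tuple(rng.random() for _ in range(hp.num_envs))  # stand-in observations
    next_seed = rng.randrange(2**32)  # analogous to splitting a PRNG key for the next call
    return replace(
        state,
        env_seed=next_seed,
        last_obs=last_obs,
        global_step=state.global_step + hp.num_envs * hp.num_steps,
    )


hp = PPOHParams()
# Separate seeds: network initialization and environment data collection
# use independent streams, so changing one does not affect the other.
state = PPOState(
    network_seed=123,
    data_collection_state=TrajectoryCollectionState(env_seed=456, last_obs=()),
)
new_collection_state = collect_episodes(state.data_collection_state, hp)
state = replace(state, data_collection_state=new_collection_state)
print(state.data_collection_state.global_step)  # 32
```

Because the structs are immutable, each step returns a new state, which matches the functional style Jax code requires.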
training_step
    training_step(
        batch_idx: int,
        ts: PPOState[TEnvState],
        batch: TrajectoryWithLastObs,
    )
Training step in pure Jax.
train
    train(
        rng: Array,
        train_state: PPOState[TEnvState] | None = None,
        skip_initial_evaluation: bool = False,
    ) -> tuple[PPOState[TEnvState], EvalMetrics]
Full training loop in Jax.
This is only here to match the API of rejax.PPO.train. It doesn't get called when using the JaxTrainer, since JaxTrainer.fit already does the same thing, but also with support for some JaxCallbacks (as well as some lightning.Callbacks!).
Unfolded version of rejax.PPO.train.
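The overall shape of such an unfolded training loop can be sketched in plain Python, following the train signature. Only the parameter names rng, train_state and skip_initial_evaluation come from this page; ToyState, init_state, collect, evaluate, num_iterations and eval_every are hypothetical stand-ins, and random.Random replaces a Jax PRNG key. A Jax implementation would typically express the loop with jax.lax.scan or jax.lax.fori_loop under jax.jit rather than a Python for-loop.

```python
import random
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyState:
    # Hypothetical stand-in for PPOState.
    params: float
    step: int = 0


def init_state(rng: random.Random) -> ToyState:
    return ToyState(params=rng.random())


def collect(rng: random.Random) -> list[float]:
    # Hypothetical stand-in for collecting a TrajectoryWithLastObs batch.
    return [rng.random() for _ in range(4)]


def training_step(batch_idx: int, ts: ToyState, batch: list[float]) -> ToyState:
    # Pure update: returns a new state instead of mutating the old one.
    return replace(ts, params=ts.params + 0.01 * sum(batch), step=ts.step + 1)


def evaluate(ts: ToyState) -> float:
    return ts.params  # hypothetical scalar metric


def train(
    rng: random.Random,
    train_state: Optional[ToyState] = None,
    skip_initial_evaluation: bool = False,
    num_iterations: int = 10,
    eval_every: int = 5,
) -> tuple[ToyState, list[float]]:
    # Start from the given state, or initialize a fresh one.
    ts = train_state if train_state is not None else init_state(rng)
    metrics = [] if skip_initial_evaluation else [evaluate(ts)]
    for batch_idx in range(num_iterations):
        batch = collect(rng)
        ts = training_step(batch_idx, ts, batch)
        if (batch_idx + 1) % eval_every == 0:
            metrics.append(evaluate(ts))
    return ts, metrics


final_state, eval_metrics = train(random.Random(0))
print(final_state.step, len(eval_metrics))  # 10 3
```

Passing a previous state as train_state resumes training from it, which is one reason the loop keeps the state explicit instead of storing it on the module.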
NoOp
Bases: LightningModule
No-op algorithm that does no learning and is used to benchmark the dataloading speed.
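The idea behind NoOp can be sketched without Lightning: iterate over the data while doing no computation per batch, so the measured throughput reflects data loading alone. fake_dataloader, its simulated per-batch delay and the batch contents are hypothetical; in practice the batches would come from the datamodule's real dataloader.

```python
import time
from typing import Iterable, Iterator


def fake_dataloader(num_batches: int, delay_s: float = 0.001) -> Iterator[list[int]]:
    """Hypothetical dataloader: sleeps to simulate the cost of loading each batch."""
    for i in range(num_batches):
        time.sleep(delay_s)
        yield [i] * 32


def noop_training_step(batch) -> None:
    """Does no learning: the batch is received and ignored."""
    return None


def benchmark(dataloader: Iterable) -> tuple[int, float]:
    # Time the full pass; with a no-op step, nearly all of it is spent loading data.
    start = time.perf_counter()
    num_batches = 0
    for batch in dataloader:
        noop_training_step(batch)
        num_batches += 1
    elapsed = time.perf_counter() - start
    return num_batches, num_batches / elapsed


n, batches_per_s = benchmark(fake_dataloader(num_batches=50))
print(f"{n} batches, {batches_per_s:.1f} batches/s")
```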