Experiment
The training and evaluation functions.

NOTE: This has to be in a different file than `main` because we want to allow registering different variants of the `train_and_evaluate` function for different algorithms with `functools.singledispatch`. If we have everything in `main.py`, the registration doesn't happen correctly.
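For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how `functools.singledispatch` registration works; `MyAlgorithm` and both function bodies are made up for illustration:

```python
import functools


@functools.singledispatch
def train_and_evaluate(algorithm, /, *, config=None):
    # Default implementation: runs when no overload is registered
    # for type(algorithm).
    print(f"generic path for {type(algorithm).__name__}")


class MyAlgorithm:
    """Hypothetical algorithm class, only for this example."""


@train_and_evaluate.register(MyAlgorithm)
def _(algorithm: MyAlgorithm, /, *, config=None):
    # Runs instead of the default whenever `algorithm` is a MyAlgorithm.
    print("custom path for MyAlgorithm")


train_and_evaluate(object())       # -> generic path for object
train_and_evaluate(MyAlgorithm())  # -> custom path for MyAlgorithm
```

A likely reason for the note above: when `main.py` is executed as a script, Python imports it as the `__main__` module, so a `train_and_evaluate` defined there and the one seen by modules that `import main` would be two distinct objects with two distinct dispatch registries.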
train_and_evaluate
```python
train_and_evaluate(
    algorithm,
    /,
    *,
    datamodule: LightningDataModule | None = None,
    config: Config,
)
```
Generic function that trains and evaluates a learning algorithm.

By default, this assumes that the algorithm is a `LightningModule`, but it can be extended to implement specific training / evaluation procedures for different algorithms.

The default implementation here does roughly the same thing as https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py:
- Instantiates the experiment components from the Hydra configuration:
    - algorithm (already instantiated)
    - trainer
    - datamodule (optional)
- Calls `trainer.fit` to train the algorithm
- Calls `trainer.evaluate` or `trainer.test` to evaluate the model
- Returns the metrics (a rough sketch follows this list).
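As a rough mental model, the default (LightningModule) path could look like the sketch below; `instantiate_trainer` and `evaluate_lightning` are the helpers documented further down this page, but this exact body and the `config.trainer` attribute are assumptions, not the actual source:

```python
def train_and_evaluate_default(algorithm, /, *, datamodule=None, config):
    # Illustrative sketch only.
    trainer = instantiate_trainer(config.trainer)  # assumed call signature
    trainer.fit(algorithm, datamodule=datamodule)
    metric_name, error, metrics = evaluate_lightning(
        algorithm, trainer=trainer, datamodule=datamodule
    )
    return metrics
```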
Extending to other algorithms or training procedures
For example, if your algorithm has to be trained in two distinct phases, or if you want to use a different kind of Trainer that does something other than just call `.fit` and `.evaluate`, you could do something like this:
```python
@train_and_evaluate.register(MyAlgorithm)
def train_and_evaluate_my_algo(algorithm: MyAlgorithm, /, *, trainer, datamodule):
    # Making this up: this isn't doable with any of the datamodules at the moment.
    datamodule.set_task(1)
    trainer.fit(algorithm, datamodule)
    datamodule.set_task(2)
    trainer.fit(algorithm, datamodule)
```
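Because `functools.singledispatch` dispatches on the type of the first positional argument, call sites don't change: `train_and_evaluate(algorithm, ...)` picks the overload registered for `type(algorithm)` (or one of its parent classes), and falls back to the generic implementation otherwise.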
evaluate_lightning
```python
evaluate_lightning(
    algorithm: LightningModule,
    /,
    *,
    trainer: Trainer,
    datamodule: LightningDataModule | None = None,
) -> tuple[str, float | None, dict]
```
Evaluates the algorithm and returns the metrics.

By default, if validation is to be performed, returns the validation error. Returns the training error when `trainer.overfit_batches != 0` (e.g. when debugging or testing). Otherwise, if `trainer.limit_val_batches == 0`, returns the test error.
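The rules above boil down to a small decision tree. Here is a hedged sketch of that selection logic; the helper name and body are illustrative, not the actual implementation:

```python
from lightning import Trainer


def _which_split(trainer: Trainer) -> str:
    # Illustrative only: mirrors the rules described above.
    if trainer.overfit_batches != 0:
        return "train"  # overfitting on a few batches: report the training error
    if trainer.limit_val_batches == 0:
        return "test"   # validation disabled: fall back to the test error
    return "val"        # default: report the validation error
```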
instantiate_trainer
Instantiates the callbacks and loggers first, then creates the trainer from its config.
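A minimal sketch of what that ordering could look like with plain Hydra, assuming the trainer config is a `DictConfig` whose `callbacks` and `logger` entries are dicts of sub-configs; this is an assumption for illustration, not the actual implementation:

```python
import hydra.utils
from omegaconf import DictConfig


def instantiate_trainer_sketch(trainer_config: DictConfig):
    # Illustrative only.
    # 1. Instantiate the callbacks and loggers first (what `instantiate_values`
    #    does, per its documentation below).
    callbacks = [
        hydra.utils.instantiate(c)
        for c in (trainer_config.get("callbacks") or {}).values()
    ]
    loggers = [
        hydra.utils.instantiate(c)
        for c in (trainer_config.get("logger") or {}).values()
    ]
    # 2. Then create the Trainer from its config, passing the built objects in
    #    (keyword arguments to `instantiate` override the config's own entries).
    return hydra.utils.instantiate(
        trainer_config, callbacks=callbacks, logger=loggers
    )
```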
instantiate_values
Instantiates each value in the given dict of configs and returns the resulting objects as a list.

This is used for the config of the `trainer/logger` and `trainer/callbacks` fields, where we can combine multiple config groups by adding entries in a dict. For example, using `trainer/logger=wandb` and `trainer/logger=tensorboard` would result in a dict with `wandb` and `tensorboard` as keys, and the corresponding config groups as values. This would then return a list with the instantiated WandbLogger and TensorBoardLogger objects.
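A hedged sketch of what such a helper could look like; the body is an assumption based on the description above, not the actual implementation:

```python
import hydra.utils
from omegaconf import DictConfig


def instantiate_values_sketch(value_configs: DictConfig | None) -> list:
    # Illustrative only. E.g. {"wandb": {...}, "tensorboard": {...}}
    # -> [WandbLogger(...), TensorBoardLogger(...)].
    if not value_configs:
        return []
    return [hydra.utils.instantiate(config) for config in value_configs.values()]
```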