Experiment

The training and evaluation functions.

NOTE: This has to be in a different file than main because we want to allow registering different variants of the train_and_evaluate function for different algorithms with functools.singledispatch. If we have everything in main.py, the registration doesn't happen correctly.

train_and_evaluate #

train_and_evaluate(
    algorithm,
    /,
    *,
    datamodule: LightningDataModule | None = None,
    config: Config,
)

Generic function that trains and evaluates a learning algorithm.

By default, this assumes that the algorithm is a LightningModule, but it can be extended to implement specific training / evaluation procedures for different algorithms.

The default implementation here does roughly the same thing as https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py

  1. Instantiates the experiment components from the Hydra configuration:
    • algorithm (already instantiated)
    • trainer
    • datamodule (optional)
  2. Calls trainer.fit to train the algorithm
  3. Calls trainer.validate or trainer.test (via evaluate_lightning, documented below) to evaluate the model
  4. Returns the metrics.
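
As a rough sketch, the default flow could look like the following. This is not the exact implementation: `config.trainer` is an assumed field name, and `instantiate_trainer` / `evaluate_lightning` are the helpers documented further down this page.

```python
from lightning import LightningDataModule, LightningModule


def _default_train_and_evaluate(
    algorithm: LightningModule,
    /,
    *,
    datamodule: LightningDataModule | None = None,
    config,  # the experiment's Hydra `Config`
):
    # 1. The algorithm is already instantiated; build the trainer from its
    #    config (`config.trainer` is an assumed field name here).
    trainer = instantiate_trainer(config.trainer)
    # 2. Train the algorithm.
    trainer.fit(algorithm, datamodule=datamodule)
    # 3.-4. Evaluate the model and return the metrics.
    return evaluate_lightning(algorithm, trainer=trainer, datamodule=datamodule)
```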

Extending to other algorithms or training procedures #

For example, if your algorithm has to be trained in two distinct phases, or if you want a training procedure that does something other than just call .fit and then evaluate, you could register a specialized implementation like this:

```python
@train_and_evaluate.register(MyAlgorithm)
def train_and_evaluate_my_algo(
    algorithm: MyAlgorithm, /, *, datamodule=None, config
):
    trainer = instantiate_trainer(config.trainer)
    # Made up for illustration: none of the current datamodules
    # have a `set_task` method.
    datamodule.set_task(1)
    trainer.fit(algorithm, datamodule=datamodule)

    datamodule.set_task(2)
    trainer.fit(algorithm, datamodule=datamodule)
```
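
Because functools.singledispatch dispatches on the type of the first positional argument, this override runs whenever the instantiated algorithm is a MyAlgorithm; all other algorithms fall back to the default implementation. Note that the keyword-only parameters must match those of the generic function, since callers pass them by name.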

evaluate_lightning #

evaluate_lightning(
    algorithm: LightningModule,
    /,
    *,
    trainer: Trainer,
    datamodule: LightningDataModule | None = None,
) -> tuple[str, float | None, dict]

Evaluates the algorithm and returns the metrics.

Returns the training error when trainer.overfit_batches != 0 (e.g. when debugging or testing); the test error when trainer.limit_val_batches == 0 (i.e. validation is disabled); and the validation error otherwise.
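
A minimal sketch of that selection logic is below. The `*/loss` metric names are hypothetical; note that when `overfit_batches` is set, Lightning runs the validation loop on training batches, which is why `trainer.validate` reports the training error in that case.

```python
from lightning import LightningDataModule, LightningModule, Trainer


def _evaluate_lightning_sketch(
    algorithm: LightningModule,
    /,
    *,
    trainer: Trainer,
    datamodule: LightningDataModule | None = None,
) -> tuple[str, float | None, dict]:
    if trainer.overfit_batches != 0:
        # With `overfit_batches`, the validation loop runs on training
        # batches, so this reports the training error.
        phase = "train"
        results = trainer.validate(algorithm, datamodule=datamodule)
    elif trainer.limit_val_batches != 0:
        # The usual case: report the validation error.
        phase = "val"
        results = trainer.validate(algorithm, datamodule=datamodule)
    else:
        # Validation is disabled: fall back to the test error.
        phase = "test"
        results = trainer.test(algorithm, datamodule=datamodule)
    metrics = dict(results[0]) if results else {}
    metric_name = f"{phase}/loss"  # hypothetical metric name
    return metric_name, metrics.get(metric_name), metrics
```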

instantiate_trainer #

instantiate_trainer(
    trainer_config: dict | DictConfig,
) -> Trainer | Any

Instantiates the callbacks and loggers first, then creates the trainer from its config.
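
Schematically, this amounts to something like the sketch below (not the exact implementation). Hydra's `instantiate` lets keyword arguments override config entries, which is how the already-built callbacks and loggers get passed in.

```python
import hydra
from lightning import Trainer
from omegaconf import DictConfig


def _instantiate_trainer_sketch(trainer_config: dict | DictConfig) -> Trainer:
    # Build the nested objects first, using `instantiate_values` (below)...
    callbacks = instantiate_values(trainer_config.get("callbacks"))
    loggers = instantiate_values(trainer_config.get("logger"))
    # ...then build the Trainer itself, overriding the raw config entries
    # with the instantiated objects.
    return hydra.utils.instantiate(
        trainer_config, callbacks=callbacks, logger=loggers
    )
```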

instantiate_values #

instantiate_values(
    config_dict: DictConfig | None,
) -> list[Any] | None

Instantiates and returns the objects stored as the values of this dict of configs.

This is used for the config of the trainer/logger and trainer/callbacks fields, where we can combine multiple config groups by adding entries in a dict.

For example, using trainer/logger=wandb and trainer/logger=tensorboard would result in a dict with wandb and tensorboard as keys, and the corresponding config groups as values.

This would then return a list with the instantiated WandbLogger and TensorBoardLogger objects.
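
A sketch of what this amounts to, assuming each value in the dict is itself an instantiable config:

```python
from typing import Any

import hydra
from omegaconf import DictConfig


def _instantiate_values_sketch(config_dict: DictConfig | None) -> list[Any] | None:
    if not config_dict:
        return None
    # The keys (e.g. "wandb", "tensorboard") only serve to group and
    # deduplicate entries; only the values get instantiated.
    return [hydra.utils.instantiate(value) for value in config_dict.values()]
```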