Main
Training script using Hydra.
This does the following:

1. Parses the config using Hydra;
2. Instantiates the components (trainer / algorithm), and optionally the datamodule and network;
3. Trains the model;
4. Optionally runs an evaluation loop.
main #
main(dict_config: DictConfig) -> dict
Main entry point for training a model.
This does roughly the same thing as https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py

- Instantiates the experiment components from the Hydra configuration:
  - trainer
  - algorithm
  - datamodule (optional)
- Calls `train` to train the algorithm
- Calls `evaluation` to evaluate the model
- Returns the evaluation metrics.
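The flow above can be sketched in plain Python. This is an illustrative stand-in, not the project's actual code: `FakeTrainer` and the dict-of-factories config are hypothetical placeholders for the objects that `hydra.utils.instantiate` would build from the Hydra configuration in the real script.

```python
from dataclasses import dataclass


@dataclass
class FakeTrainer:
    """Hypothetical stand-in for the instantiated Lightning Trainer."""

    fitted: bool = False

    def fit(self, algorithm, datamodule=None):
        # The real trainer would run the training loop here.
        self.fitted = True

    def validate(self, algorithm, datamodule=None):
        # The real trainer would return the validation metrics.
        return {"val/loss": 0.0}


def main(config: dict) -> dict:
    # 1. "Instantiate" the components from the config
    #    (hydra.utils.instantiate in the real script).
    trainer = config["trainer"]()
    algorithm = config["algorithm"]()
    datamodule = config.get("datamodule", lambda: None)()
    # 2. Train the algorithm.
    trainer.fit(algorithm, datamodule=datamodule)
    # 3. Evaluate the model and return the metrics.
    return trainer.validate(algorithm, datamodule=datamodule)
```

The stand-in keeps only the shape of the entry point: build components, train, evaluate, return metrics.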
instantiate_values #
Returns the list of objects instantiated from the values of this dict of configs.

This is used for the `trainer/logger` and `trainer/callbacks` fields of the config, where we can combine multiple config groups by adding entries to a dict.

For example, using `trainer/logger=wandb` and `trainer/logger=tensorboard` would result in a dict with `wandb` and `tensorboard` as keys, and the corresponding config groups as values. This function would then return a list with the instantiated WandbLogger and TensorBoardLogger objects.
evaluate_lightningmodule #
evaluate_lightningmodule(
algorithm: LightningModule,
trainer: Trainer,
datamodule: LightningDataModule | None,
) -> tuple[MetricName, float | None, dict]
Evaluates the algorithm and returns the metrics.
By default, if validation is to be performed, returns the validation error. Returns the training error when `trainer.overfit_batches != 0` (e.g. when debugging or testing). Otherwise, if `trainer.limit_val_batches == 0`, returns the test error.
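The split-selection rule above can be written down directly. This is a hedged sketch of only the decision logic; the real function also runs the corresponding `trainer.validate` / `trainer.test` loop and extracts the metric value.

```python
def choose_evaluation_split(overfit_batches: float, limit_val_batches: float) -> str:
    """Pick which split's error to report, per the rule described above."""
    # Overfitting on a few batches (debugging/testing): report the training error.
    if overfit_batches != 0:
        return "train"
    # Validation disabled entirely: fall back to the test error.
    if limit_val_batches == 0:
        return "test"
    # Default: report the validation error.
    return "val"
```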