Main

Training script using Hydra.

This does the following:

  1. Parses the config using Hydra;
  2. Instantiates the components (trainer / algorithm), and optionally the datamodule and network;
  3. Trains the model;
  4. Optionally runs an evaluation loop.
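
For illustration, a config with the shape this script consumes might look like the following (a sketch only; the _target_ paths and key names are assumptions, not the project's actual schema):

```python
from omegaconf import OmegaConf

# Hypothetical config shape; the `_target_` paths are examples only.
dict_config = OmegaConf.create(
    {
        "trainer": {"_target_": "lightning.pytorch.Trainer", "max_epochs": 10},
        "algorithm": {"_target_": "project.algorithms.ExampleAlgorithm"},  # hypothetical
        "datamodule": {"_target_": "project.datamodules.ExampleDataModule"},  # optional
    }
)
print(OmegaConf.to_yaml(dict_config))
```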

main #

main(dict_config: DictConfig) -> dict

Main entry point for training a model.

This does roughly the same thing as https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py

  1. Instantiates the experiment components from the Hydra configuration:
    • trainer
    • algorithm
    • datamodule (optional)
  2. Calls train to train the algorithm.
  3. Calls evaluation to evaluate the model.
  4. Returns the evaluation metrics.
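
A minimal sketch of these steps, assuming standard Hydra and Lightning APIs (the evaluation call here is a stand-in for the real train and evaluation helpers):

```python
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(dict_config: DictConfig) -> dict:
    # 1. Instantiate the experiment components from their `_target_` entries.
    trainer = instantiate(dict_config.trainer)
    algorithm = instantiate(dict_config.algorithm)
    datamodule = (
        instantiate(dict_config.datamodule) if "datamodule" in dict_config else None
    )
    # 2. Train the algorithm.
    trainer.fit(algorithm, datamodule=datamodule)
    # 3./4. Evaluate and return the metrics (one dict per validation dataloader).
    return trainer.validate(algorithm, datamodule=datamodule)[0]


if __name__ == "__main__":
    main()
```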

instantiate_values #

instantiate_values(
    config_dict: DictConfig | None,
) -> list[Any] | None

Returns the list of objects instantiated from the values of this dict of configs.

This is used for the trainer/logger and trainer/callbacks config fields, where multiple config groups can be combined by adding entries to a dict.

For example, selecting both trainer/logger=wandb and trainer/logger=tensorboard would result in a dict with wandb and tensorboard as keys and the corresponding logger configs as values.

This would then return a list with the instantiated WandbLogger and TensorBoardLogger objects.
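
A minimal sketch of this behavior, assuming each value is a config with a _target_ entry (the real implementation may differ):

```python
from typing import Any

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf


def instantiate_values(config_dict: DictConfig | None) -> list[Any] | None:
    """Instantiate the object configured at each value of the dict."""
    if not config_dict:
        return None
    # The keys (e.g. "wandb", "tensorboard") only serve to merge config
    # groups; only the values are instantiated.
    return [instantiate(value) for value in config_dict.values()]


loggers = instantiate_values(
    OmegaConf.create(
        {
            "wandb": {"_target_": "lightning.pytorch.loggers.WandbLogger"},
            "tensorboard": {
                "_target_": "lightning.pytorch.loggers.TensorBoardLogger",
                "save_dir": "logs",
            },
        }
    )
)
# -> [WandbLogger(...), TensorBoardLogger(...)]
```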

evaluate_lightningmodule #

evaluate_lightningmodule(
    algorithm: LightningModule,
    trainer: Trainer,
    datamodule: LightningDataModule | None,
) -> tuple[MetricName, float | None, dict]

Evaluates the algorithm and returns the metrics.

Returns the validation error by default, when validation is performed. When trainer.overfit_batches != 0 (e.g. when debugging or testing), returns the training error instead. Otherwise, when trainer.limit_val_batches == 0, returns the test error.
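
The selection logic reads roughly like this sketch (simplified; the metric key names are assumptions, not the project's actual keys):

```python
from lightning import LightningDataModule, LightningModule, Trainer


def evaluate_lightningmodule(
    algorithm: LightningModule,
    trainer: Trainer,
    datamodule: LightningDataModule | None,
) -> tuple[str, float | None, dict]:
    if trainer.overfit_batches != 0:
        # Overfit/debug run: report the training error from the logged metrics.
        metrics = dict(trainer.callback_metrics)
        name = "train/error"  # hypothetical metric key
    elif trainer.limit_val_batches != 0:
        # Default case: run the validation loop and report the validation error.
        metrics = trainer.validate(algorithm, datamodule=datamodule)[0]
        name = "val/error"  # hypothetical metric key
    else:
        # Validation is disabled: fall back to the test error.
        metrics = trainer.test(algorithm, datamodule=datamodule)[0]
        name = "test/error"  # hypothetical metric key
    value = metrics.get(name)
    return name, float(value) if value is not None else None, metrics
```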

get_error_from_metrics #

get_error_from_metrics(
    metrics: _MetricsT,
) -> tuple[MetricName, float, dict]

Returns the main metric name, its value, and the full metrics dictionary.
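
A possible shape for this helper, assuming metric keys like those sketched above (the names and the preference order are assumptions):

```python
MetricName = str  # assumed type alias


def get_error_from_metrics(
    metrics: dict[MetricName, float],
) -> tuple[MetricName, float, dict]:
    # Prefer the validation error, then fall back to the test or training error.
    for name in ("val/error", "test/error", "train/error"):
        if name in metrics:
            return name, metrics[name], metrics
    raise KeyError(f"No known error metric in {sorted(metrics)}")
```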