Main

Training script using Hydra.

This does the following:

  1. Parses the config using Hydra.
  2. Instantiates the components: the trainer and the algorithm, and optionally the datamodule and network.
  3. Trains the model.
  4. Optionally runs an evaluation loop.

main

main(dict_config: DictConfig) -> dict

Main entry point for training a model.

This does roughly the same thing as https://github.com/ashleve/lightning-hydra-template/blob/main/src/train.py

  1. Instantiates the experiment components from the Hydra configuration:
    • trainer
    • algorithm
    • datamodule (optional)
  2. Calls train to train the algorithm.
  3. Calls evaluation to evaluate the model.
  4. Returns the evaluation metrics.
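
The four steps above can be sketched in plain Python. This is only an illustration of the control flow, not the project's implementation: `ToyTrainer`, `ToyAlgorithm`, `main_sketch`, and the dict-based config are hypothetical stand-ins, and the real `main` receives a Hydra `DictConfig` rather than a dict of factories.

```python
from typing import Any, Callable, Dict, Optional


class ToyAlgorithm:
    """Hypothetical stand-in for a LightningModule-like algorithm."""

    def __init__(self, datamodule: Optional[Any] = None) -> None:
        self.datamodule = datamodule
        self.trained = False


class ToyTrainer:
    """Hypothetical stand-in for a trainer; not the real Trainer API."""

    def fit(self, algorithm: ToyAlgorithm, datamodule: Optional[Any] = None) -> None:
        # A real trainer would run the optimization loop here.
        algorithm.trained = True

    def evaluate(
        self, algorithm: ToyAlgorithm, datamodule: Optional[Any] = None
    ) -> Dict[str, float]:
        # A real trainer would compute metrics on held-out data here.
        return {"val/accuracy": 1.0 if algorithm.trained else 0.0}


def main_sketch(config: Dict[str, Callable[..., Any]]) -> Dict[str, float]:
    # 1. Instantiate the components; the datamodule is optional.
    datamodule_factory = config.get("datamodule")
    datamodule = datamodule_factory() if datamodule_factory is not None else None
    algorithm = config["algorithm"](datamodule=datamodule)
    trainer = config["trainer"]()

    # 2. Train the algorithm.
    trainer.fit(algorithm, datamodule=datamodule)

    # 3. Evaluate the model.
    metrics = trainer.evaluate(algorithm, datamodule=datamodule)

    # 4. Return the evaluation metrics.
    return metrics


metrics = main_sketch({"trainer": ToyTrainer, "algorithm": ToyAlgorithm})
print(metrics)
```

Returning the metrics from the entry point keeps the result available to the caller, for example a hyperparameter sweep that needs a value to optimize.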

instantiate_algorithm

instantiate_algorithm(
    config: Config,
    datamodule: LightningDataModule | None = None,
) -> LightningModule | JaxModule

Function used to instantiate the algorithm.

It is suggested that your algorithm (LightningModule) accept the datamodule and network as constructor arguments, which makes it easier to swap in different networks and datamodules between experiments.

The instantiated datamodule and network will be passed to the algorithm's constructor.
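
As a rough illustration of this convention, the sketch below uses `functools.partial` to stand in for a partially configured algorithm whose remaining arguments are filled in at instantiation time. The names `ExampleAlgorithm` and `instantiate_algorithm_sketch`, the `lr` hyperparameter, and the extra `network` parameter are assumptions made for the example; the actual function takes a `Config` object and may work differently.

```python
from functools import partial
from typing import Any, Callable, Optional


class ExampleAlgorithm:
    """Hypothetical algorithm that receives its datamodule and network."""

    def __init__(
        self, datamodule: Optional[Any], network: Any, lr: float = 1e-3
    ) -> None:
        self.datamodule = datamodule
        self.network = network
        self.lr = lr


def instantiate_algorithm_sketch(
    algorithm_factory: Callable[..., ExampleAlgorithm],
    datamodule: Optional[Any] = None,
    network: Any = None,
) -> ExampleAlgorithm:
    # The factory already carries the hyperparameters from the config;
    # the already-instantiated datamodule and network are supplied here.
    return algorithm_factory(datamodule=datamodule, network=network)


# Hyperparameters are bound first, as a config would do.
algorithm_factory = partial(ExampleAlgorithm, lr=3e-4)

# Placeholder strings stand in for real datamodule and network objects.
algorithm = instantiate_algorithm_sketch(
    algorithm_factory, datamodule="mnist-datamodule", network="small-mlp"
)
print(algorithm.lr, algorithm.network)
```

Because the algorithm only sees the objects it is handed, changing the network or datamodule in the config requires no change to the algorithm's code.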