JAX image classifier
JaxCNN
Bases: flax.linen.Module
A simple CNN model.
Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network
JaxImageClassifier
Bases: LightningModule
Example of a learning algorithm (LightningModule) that uses JAX.
In this case, the network is a flax.linen.Module whose forward and backward passes are written in JAX, while the loss function is computed in PyTorch.
configure_optimizers
Creates the optimizers.
See lightning.pytorch.core.LightningModule.configure_optimizers for more information.