
Jax image classifier

JaxCNN

Bases: Module

A simple CNN model.

Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network

JaxImageClassifier

Bases: LightningModule

Example of a learning algorithm (LightningModule) that uses JAX.

Here the network is a flax.linen.Module whose forward and backward passes are written in JAX, while the loss function is computed in PyTorch.
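The sketch below shows one way to mix the two frameworks: run the network forward in JAX with `jax.vjp`, compute the loss in PyTorch, then pull the PyTorch gradient of the logits back through the JAX network. This is an illustration under stated assumptions, not necessarily how `JaxImageClassifier` is implemented. The helper names `linear_apply` and `loss_and_grads` are hypothetical. `linear_apply` is a one-layer stand-in for the Flax network. Cross-entropy is an assumed choice of loss. Arrays are copied through NumPy rather than shared via DLPack.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn.functional as F


def linear_apply(params, x):
    # Hypothetical stand-in for a Flax `model.apply`: a single dense layer in JAX.
    return x @ params["w"] + params["b"]


def loss_and_grads(apply_fn, params, x, y):
    """Sketch: JAX forward/backward, loss (assumed cross-entropy) in PyTorch."""
    # 1. Forward pass in JAX; vjp_fn maps d(loss)/d(logits) to d(loss)/d(params).
    logits, vjp_fn = jax.vjp(lambda p: apply_fn(p, x), params)
    # 2. Copy the logits into a PyTorch leaf tensor that tracks gradients.
    logits_t = torch.from_numpy(np.array(logits)).requires_grad_(True)
    labels_t = torch.from_numpy(np.array(y)).long()
    loss = F.cross_entropy(logits_t, labels_t)
    # 3. Backward through the PyTorch loss only; fills logits_t.grad.
    loss.backward()
    # 4. Pull the logits gradient back through the JAX network.
    (param_grads,) = vjp_fn(jnp.asarray(logits_t.grad.numpy()))
    return loss.item(), param_grads
```

Splitting at the logits keeps each framework's autodiff self-contained: PyTorch only differentiates the loss with respect to the logits, and JAX handles everything upstream of them.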

configure_optimizers

configure_optimizers()

Creates the optimizers.

See lightning.pytorch.core.LightningModule.configure_optimizers for more information.