
Jax example

CNN

Bases: flax.linen.Module

A simple CNN model.

Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network

JaxExample

Bases: LightningModule

Example of a learning algorithm (LightningModule) that uses Jax.

In this case, the network is a flax.linen.Module whose forward and backward passes are written in JAX, while the loss function is computed in PyTorch.
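One way to combine the two frameworks, sketched under assumptions: wrap the JAX computation in a custom torch.autograd.Function whose forward runs in JAX and whose backward calls the vector-Jacobian product returned by jax.vjp, so that PyTorch's autograd can backpropagate a PyTorch loss through it. The single linear layer jax_logits and the JaxLogits wrapper are hypothetical, and data is copied through NumPy on the CPU for simplicity (a real implementation would more likely share device memory, e.g. via DLPack). This is not necessarily how JaxExample itself does it.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.nn.functional as F


def jax_logits(w, x):
    # Hypothetical network: a single linear layer written in JAX.
    return x @ w


class JaxLogits(torch.autograd.Function):
    """Runs jax_logits forward in JAX and uses jax.vjp for the backward pass."""

    @staticmethod
    def forward(ctx, w_t, x_t):
        # Copy torch tensors to JAX arrays (CPU round trip via NumPy).
        w = jnp.asarray(w_t.detach().cpu().numpy())
        x = jnp.asarray(x_t.detach().cpu().numpy())
        out, vjp_fn = jax.vjp(jax_logits, w, x)
        ctx.vjp_fn = vjp_fn  # keep the JAX backward closure for later
        return torch.from_numpy(np.array(out))

    @staticmethod
    def backward(ctx, grad_out):
        g = jnp.asarray(grad_out.detach().cpu().numpy())
        grad_w, grad_x = ctx.vjp_fn(g)  # vector-Jacobian product computed by JAX
        grad_w_t = torch.from_numpy(np.array(grad_w)) if ctx.needs_input_grad[0] else None
        grad_x_t = torch.from_numpy(np.array(grad_x)) if ctx.needs_input_grad[1] else None
        return grad_w_t, grad_x_t


torch.manual_seed(0)
w = torch.randn(4, 3, requires_grad=True)
x = torch.randn(5, 4)
y = torch.tensor([0, 2, 1, 0, 2])

logits = JaxLogits.apply(w, x)
loss = F.cross_entropy(logits, y)  # the loss itself is computed in PyTorch
loss.backward()
print(w.grad.shape)  # torch.Size([4, 3])
```

Caching vjp_fn on ctx avoids re-running the JAX forward pass during backward, at the cost of keeping JAX's intermediate values alive until backward runs.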

HParams dataclass

Hyper-parameters of the algorithm.
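Hyper-parameters like these are typically grouped in a dataclass so they can be logged and overridden from a configuration. The sketch below assumes a frozen dataclass with hypothetical fields lr, seed and debug; the actual fields of JaxExample.HParams are not listed on this page:

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class HParams:
    """Hyper-parameters of the algorithm (hypothetical fields for illustration)."""

    lr: float = 1e-3  # hypothetical: learning rate for the optimizer
    seed: int = 123  # hypothetical: random seed, e.g. for jax.random.PRNGKey
    debug: bool = False  # hypothetical: enable extra checks or logging


hp = HParams()
fast = replace(hp, lr=1e-2)  # frozen dataclasses are "updated" by making a modified copy
print(hp.lr, fast.lr)  # 0.001 0.01
```

Making the dataclass frozen means a given set of hyper-parameters cannot be mutated accidentally during training; changes are expressed as new instances via dataclasses.replace.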