Jax example
CNN
Bases: Module
A simple CNN model.
Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network
JaxExample
Bases: LightningModule
Example of a learning algorithm (LightningModule) that uses Jax.
In this case, the network is a flax.linen.Module whose forward and backward passes are written in Jax, while the loss function is written in PyTorch.
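One way to split the work this way is to let Jax produce both the network output and a function that maps an output gradient back to parameter gradients, using `jax.vjp`. The sketch below uses only Jax: the `forward` function, the parameter shapes, and the `grad_logits` array (a stand-in for the gradient that a PyTorch loss would supply) are all hypothetical, and the actual class may bridge the two frameworks differently.

```python
import jax
import jax.numpy as jnp


def forward(params: dict, x: jax.Array) -> jax.Array:
    # Hypothetical linear "network" standing in for the flax module.
    return x @ params["w"] + params["b"]


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
x = jnp.ones((4, 3))

# Forward pass in Jax; vjp_fn is the matching backward pass.
logits, vjp_fn = jax.vjp(forward, params, x)

# Stand-in for d(loss)/d(logits), which would come from the PyTorch loss.
grad_logits = jnp.ones_like(logits)

# Backward pass in Jax: one gradient per argument of `forward`.
grad_params, grad_x = vjp_fn(grad_logits)
```

The loss never needs to be traced by Jax in this arrangement; only the gradient with respect to the network output crosses the framework boundary.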
HParams dataclass
Hyper-parameters of the algorithm.
jit
Small type hint fix for jax's jit (preserves the signature of the callable).
value_and_grad
value_and_grad(
    fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
    argnums: Literal[0] = 0,
    has_aux: Literal[True] = True,
) -> Callable[
    Concatenate[In, P], tuple[tuple[Out, Aux], In]
]
Small type hint fix for jax's value_and_grad (preserves the signature of the callable).