
Jax example

JaxExample

Bases: LightningModule

Example of a learning algorithm (LightningModule) that uses Jax.

Here, the network is a flax.linen.Module whose forward and backward passes are written in Jax, while the loss function is written in PyTorch.

HParams dataclass

Hyper-parameters of the algorithm.
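A hyper-parameter dataclass of this kind could look like the minimal sketch below. The field names lr and seed, their defaults, and the choice of frozen=True are hypothetical and for illustration only; this page does not list the actual fields of HParams.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class HParams:
    """Hyper-parameters of the algorithm (hypothetical fields, for illustration only)."""

    lr: float = 1e-3  # hypothetical: learning rate of the optimizer
    seed: int = 123  # hypothetical: seed for the Jax PRNG key used to initialize the network


# Frozen dataclasses cannot be mutated in place; derive variants with dataclasses.replace.
default = HParams()
tuned = dataclasses.replace(default, lr=3e-4)
```

Making the dataclass frozen (an assumption here) keeps a configuration from being changed mid-run, and dataclasses.replace gives a cheap way to build variants for sweeps.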

jit

jit(fn: Callable[P, Out]) -> Callable[P, Out]

Type-hint fix for jax.jit that preserves the signature of fn: the compiled callable accepts the same arguments and returns the same type as fn.

value_and_grad

value_and_grad(
    fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
    argnums: Literal[0] = 0,
    has_aux: Literal[True] = True,
) -> Callable[
    Concatenate[In, P], tuple[tuple[Out, Aux], In]
]

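Type-hint fix for jax.value_and_grad that preserves the signature of fn. The returned callable gives ((value, aux), grads), where grads has the same type as the first argument of fn; the hints only cover argnums=0 and has_aux=True.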