Skip to content

Reinforcement Learning in Jax

This example follows the same overall structure as the other examples.

However, there are some very important differences:

  • There is no "datamodule". The algorithm accepts an Environment (gymnax.Environment) as input.
  • The "Trainer" is a JaxTrainer, instead of a lightning.Trainer.
  • The full training loop is written in Jax.
  • Some (but not all) PyTorch-Lightning callbacks can still be used with the JaxTrainer.
  • The JaxRLExample class is an algorithm based on rejax.PPO.

JaxRLExample

The JaxRLExample is based on rejax.PPO. It follows the structure of a JaxModule, and is trained with a JaxTrainer.

Click to show the code for JaxRLExample
class JaxRLExample(
    flax.struct.PyTreeNode,
    JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics],
    Generic[TEnvState, TEnvParams],
):
    """Example of an RL algorithm written in Jax: PPO, based on `rejax.PPO`.

    ## Differences w.r.t. rejax.PPO:

    - The state / hparams are split into different, fully-typed structs:
        - The algorithm state is in a typed `PPOState` struct (vs an untyped,
            dynamically-generated struct in rejax).
        - The hyper-parameters are in a typed `PPOHParams` struct.
        - The state variables related to the collection of data from the environment are in a
            `TrajectoryCollectionState` instead of everything being bunched up together.
            - This makes it easier to call the `collect_trajectories` method with just what it
              needs.
    - The seeds for the networks and the environment data collection are separated.

    The logic is exactly the same: The losses / updates are computed in the exact same way.
    """

    # `env`, `actor` and `critic` are static (non-pytree) fields of this struct.
    env: Environment[TEnvState, TEnvParams] = flax.struct.field(pytree_node=False)
    env_params: TEnvParams
    actor: flax.linen.Module = flax.struct.field(pytree_node=False)
    critic: flax.linen.Module = flax.struct.field(pytree_node=False)
    hp: PPOHParams

    @classmethod
    def create(
        cls,
        env_id: str | None = None,
        env: Environment[TEnvState, TEnvParams] | None = None,
        env_params: TEnvParams | None = None,
        hp: PPOHParams | None = None,
    ) -> JaxRLExample[TEnvState, TEnvParams]:
        """Create the algorithm from either an environment id or an environment instance.

        Args:
            env_id: Name of a brax environment (one of `brax.envs._envs`) or of a gymnax
                environment (passed to `gymnax.make`). When given, the environment and its
                parameters are created from it, and `env` / `env_params` are not used.
            env: The environment to use when `env_id` is None.
            env_params: Parameters of `env`. Defaults to `env.default_params`.
            hp: The PPO hyper-parameters. Defaults to `PPOHParams()`.

        Returns:
            A new `JaxRLExample` with default actor and critic networks.
        """
        from brax.envs import _envs as brax_envs
        from rejax.compat.brax2gymnax import create_brax

        if env_id is None:
            assert env is not None, "Either `env_id` or `env` must be provided."
            if env_params is None:
                env_params = env.default_params  # type: ignore
        elif env_id in brax_envs:
            env, env_params = create_brax(  # type: ignore
                env_id,
                episode_length=1000,
                action_repeat=1,
                auto_reset=True,
                batch_size=None,
                backend="generalized",
            )
        elif isinstance(env_id, str):
            env, env_params = gymnax.make(env_id=env_id)  # type: ignore
        else:
            raise NotImplementedError(env_id)

        assert env is not None
        assert env_params is not None
        return cls(
            env=env,
            env_params=env_params,
            actor=cls.create_actor(env, env_params),
            critic=cls.create_critic(),
            hp=hp or PPOHParams(),
        )

    @classmethod
    def create_networks(
        cls,
        env: Environment[gymnax.EnvState, TEnvParams],
        env_params: TEnvParams,
        config: _NetworkConfig,
    ):
        """Create the actor and critic networks from a rejax-style network config.

        Equivalent to `rejax.PPO.create_agent(config, env, env_params)`: the same
        `config["agent_kwargs"]` are used for both networks.

        Returns:
            A dict with the "actor" (a policy network) and "critic" (a `VNetwork`).
        """
        return {
            "actor": cls.create_actor(env, env_params, **config["agent_kwargs"]),
            # The critic is a value network, not a second policy network.
            "critic": cls.create_critic(**config["agent_kwargs"]),
        }

    _TEnvParams = TypeVar("_TEnvParams", bound=gymnax.EnvParams, covariant=True)
    _TEnvState = TypeVar("_TEnvState", bound=gymnax.EnvState, covariant=True)

    @classmethod
    def create_actor(
        cls,
        env: Environment[_TEnvState, _TEnvParams],
        env_params: _TEnvParams,
        activation: str | Callable[[jax.Array], jax.Array] = "swish",
        hidden_layer_sizes: Sequence[int] = (64, 64),
        **actor_kwargs,
    ) -> DiscretePolicy | GaussianPolicy:
        """Create the policy network matching the environment's action space.

        Args:
            env: The environment.
            env_params: The environment parameters (used to get the action space).
            activation: Name of an activation function in `flax.linen` (e.g. "swish"), or the
                activation function itself.
            hidden_layer_sizes: Sizes of the hidden layers.
            **actor_kwargs: Extra keyword arguments for the policy constructor.

        Returns:
            A `DiscretePolicy` for a `Discrete` action space. Otherwise (`Box` action space),
            a `GaussianPolicy` with `prod(action_space.shape)` outputs, given the space's
            `low` / `high` bounds.
        """
        activation_fn: Callable[[jax.Array], jax.Array] = (
            getattr(flax.linen, activation) if isinstance(activation, str) else activation
        )
        hidden_layer_sizes = tuple(hidden_layer_sizes)
        action_space = env.action_space(env_params)

        if isinstance(action_space, gymnax.environments.spaces.Discrete):
            return DiscretePolicy(
                action_space.n,
                activation=activation_fn,
                hidden_layer_sizes=hidden_layer_sizes,
                **actor_kwargs,
            )
        assert isinstance(action_space, gymnax.environments.spaces.Box)
        return GaussianPolicy(
            np.prod(action_space.shape),
            (action_space.low, action_space.high),  # type: ignore
            activation=activation_fn,
            hidden_layer_sizes=hidden_layer_sizes,
            **actor_kwargs,
        )

    @classmethod
    def create_critic(
        cls,
        activation: str | Callable[[jax.Array], jax.Array] = "swish",
        hidden_layer_sizes: Sequence[int] = (64, 64),
        **critic_kwargs,
    ) -> VNetwork:
        """Create the value network.

        Args:
            activation: Name of an activation function in `flax.linen`, or the function itself.
            hidden_layer_sizes: Sizes of the hidden layers.
            **critic_kwargs: Extra keyword arguments for `VNetwork`.
        """
        activation_fn: Callable[[jax.Array], jax.Array] = (
            getattr(flax.linen, activation) if isinstance(activation, str) else activation
        )
        hidden_layer_sizes = tuple(hidden_layer_sizes)
        return VNetwork(
            hidden_layer_sizes=hidden_layer_sizes, activation=activation_fn, **critic_kwargs
        )

    def init_train_state(self, rng: chex.PRNGKey) -> PPOState[TEnvState]:
        """Create the initial PPO state.

        This initializes the actor / critic parameters and their Adam optimizer states (both
        use `hp.learning_rate`). It also resets `hp.num_envs` environments and sets up the
        data collection state. Separate keys derived from `rng` are used for the networks and
        for the environments.
        """
        rng, networks_rng, env_rng = jax.random.split(rng, 3)

        rng_actor, rng_critic = jax.random.split(networks_rng, 2)

        # Placeholder observation with a leading batch dimension of 1, used for `init`.
        obs_ph = jnp.empty([1, *self.env.observation_space(self.env_params).shape])

        actor_params = self.actor.init(rng_actor, obs_ph, rng_actor)
        critic_params = self.critic.init(rng_critic, obs_ph)

        tx = optax.adam(learning_rate=self.hp.learning_rate)
        # TODO: Why isn't the `apply_fn` not set in rejax?
        actor_ts = TrainState.create(apply_fn=self.actor.apply, params=actor_params, tx=tx)
        critic_ts = TrainState.create(apply_fn=self.critic.apply, params=critic_params, tx=tx)

        # Reset `num_envs` environments in parallel, each with its own key.
        env_rng, reset_rng = jax.random.split(env_rng)
        obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
            jax.random.split(reset_rng, self.hp.num_envs), self.env_params
        )

        collection_state = TrajectoryCollectionState(
            last_obs=obs,
            rms_state=RMSState.create(shape=obs_ph.shape),
            global_step=0,
            env_state=env_state,
            last_done=jnp.zeros(self.hp.num_envs, dtype=bool),
            rng=env_rng,
        )

        return PPOState(
            actor_ts=actor_ts,
            critic_ts=critic_ts,
            rng=rng,
            data_collection_state=collection_state,
        )

    # @jit
    def training_step(self, batch_idx: int, ts: PPOState[TEnvState], batch: TrajectoryWithLastObs):
        """Training step in pure jax: run `hp.num_epochs` PPO epochs on a batch of trajectories.

        Returns:
            The updated state and the actor / critic losses, with shape
            `(num_epochs, num_minibatches)`.
        """
        trajectories = batch

        ts, (actor_losses, critic_losses) = jax.lax.scan(
            functools.partial(self.ppo_update_epoch, trajectories=trajectories),
            init=ts,
            xs=jnp.arange(self.hp.num_epochs),  # type: ignore
            length=self.hp.num_epochs,
        )
        # todo: perhaps we could have a callback that updates a progress bar?
        # jax.debug.print("actor_losses {}: {}", iteration, actor_losses.mean())
        # jax.debug.print("critic_losses {}: {}", iteration, critic_losses.mean())

        return ts, TrainStepMetrics(actor_losses=actor_losses, critic_losses=critic_losses)

    # @jit
    def ppo_update_epoch(
        self, ts: PPOState[TEnvState], epoch_index: int, trajectories: TrajectoryWithLastObs
    ):
        """Run one PPO epoch (body of a `jax.lax.scan` over epochs).

        This computes the advantages and value targets with GAE, shuffles the transitions into
        `hp.num_minibatches` minibatches, and applies one `ppo_update` per minibatch.

        Returns:
            The updated state and the `(actor_losses, critic_losses)` stacked over minibatches.
        """
        num_steps = self.hp.num_steps
        num_envs = self.hp.num_envs
        num_minibatches = self.hp.num_minibatches
        # Every minibatch must contain the same number of transitions.
        assert (num_envs * num_steps) % num_minibatches == 0

        # A different shuffling key for each epoch.
        # NOTE(review): `ts.rng` is not updated by `training_step`, so the same shuffling order
        # is presumably reused at every training iteration (rejax splits its rng instead) —
        # confirm that this is intended.
        minibatch_rng = jax.random.fold_in(ts.rng, epoch_index)

        # Bootstrap value of the last observation (zero where the episode just ended).
        last_val = self.critic.apply(ts.critic_ts.params, ts.data_collection_state.last_obs)
        assert isinstance(last_val, jax.Array)
        last_val = jnp.where(ts.data_collection_state.last_done, 0, last_val)
        advantages, targets = calculate_gae(
            trajectories, last_val, gamma=self.hp.gamma, gae_lambda=self.hp.gae_lambda
        )
        batch = AdvantageMinibatch(trajectories.trajectories, advantages, targets)

        # Shuffle the data and split it into minibatches (done only once).
        minibatches = shuffle_and_split(batch, minibatch_rng, num_minibatches=num_minibatches)
        return jax.lax.scan(self.ppo_update, ts, minibatches, length=num_minibatches)

    # @jit
    def ppo_update(self, ts: PPOState[TEnvState], batch: AdvantageMinibatch):
        """Apply one gradient step to the actor and the critic using a single minibatch.

        Returns:
            The updated state and the `(actor_loss, critic_loss)` for this minibatch.
        """
        actor_loss, actor_grads = jax.value_and_grad(actor_loss_fn)(
            ts.actor_ts.params,
            actor=self.actor,
            batch=batch,
            clip_eps=self.hp.clip_eps,
            ent_coef=self.hp.ent_coef,
        )
        assert isinstance(actor_loss, jax.Array)
        critic_loss, critic_grads = jax.value_and_grad(critic_loss_fn)(
            ts.critic_ts.params,
            critic=self.critic,
            batch=batch,
            clip_eps=self.hp.clip_eps,
            vf_coef=self.hp.vf_coef,
        )
        assert isinstance(critic_loss, jax.Array)

        # TODO: to log the loss here?
        actor_ts = ts.actor_ts.apply_gradients(grads=actor_grads)
        critic_ts = ts.critic_ts.apply_gradients(grads=critic_grads)

        return ts.replace(actor_ts=actor_ts, critic_ts=critic_ts), (actor_loss, critic_loss)

    def eval_callback(
        self, ts: PPOState[TEnvState], rng: chex.PRNGKey | None = None
    ) -> EvalMetrics:
        """Evaluate the current actor on `hp.num_seeds_per_eval` episodes.

        Args:
            ts: The current training state.
            rng: Key used for the evaluation. Defaults to `ts.rng`.

        Returns:
            The episode lengths and cumulative rewards of the evaluation episodes.
        """
        if rng is None:
            rng = ts.rng
        actor = make_actor(ts=ts, hp=self.hp)
        ep_lengths, cum_rewards = evaluate(
            actor,
            rng,
            self.env,
            self.env_params,
            num_seeds=self.hp.num_seeds_per_eval,
            max_steps_in_episode=self.env_params.max_steps_in_episode,
        )
        return EvalMetrics(episode_length=ep_lengths, cumulative_reward=cum_rewards)

    def get_batch(
        self, ts: PPOState[TEnvState], batch_idx: int
    ) -> tuple[PPOState[TEnvState], TrajectoryWithLastObs]:
        """Collect a new batch of trajectories with the current networks.

        Returns:
            The state with its updated `data_collection_state`, and the trajectories.
        """
        data_collection_state, trajectories = self.collect_trajectories(
            ts.data_collection_state,
            actor_params=ts.actor_ts.params,
            critic_params=ts.critic_ts.params,
        )
        ts = ts.replace(data_collection_state=data_collection_state)
        return ts, trajectories

    # @jit
    def collect_trajectories(
        self,
        collection_state: TrajectoryCollectionState[TEnvState],
        actor_params: FrozenVariableDict,
        critic_params: FrozenVariableDict,
    ):
        """Roll out `hp.num_steps` steps in each of the `hp.num_envs` environments.

        Returns:
            The updated collection state, and the collected trajectories together with the
            last observation / done flags (needed to bootstrap the value estimates).
        """
        env_step_fn = functools.partial(
            self.env_step,
            actor_params=actor_params,
            critic_params=critic_params,
        )
        collection_state, trajectories = jax.lax.scan(
            env_step_fn,
            collection_state,
            xs=jnp.arange(self.hp.num_steps),
            length=self.hp.num_steps,
        )
        trajectories_with_last = TrajectoryWithLastObs(
            trajectories=trajectories,
            last_done=collection_state.last_done,
            last_obs=collection_state.last_obs,
        )
        return collection_state, trajectories_with_last

    # @jit
    def env_step(
        self,
        collection_state: TrajectoryCollectionState[TEnvState],
        step_index: jax.Array,
        actor_params: FrozenVariableDict,
        critic_params: FrozenVariableDict,
    ):
        """Take one step in every environment (body of a `jax.lax.scan`).

        This samples actions from the actor and, for continuous actions, clips them to the
        action-space bounds. It then steps the vmapped environments and optionally normalizes
        the new observations.

        Returns:
            The updated collection state and this step's `Trajectory`. The *unclipped* action
            is stored, since it is the action whose log-probability was computed.
        """
        # Get keys for sampling action and stepping environment
        # doing it this way to try to get *exactly* the same rngs as in rejax.PPO.
        rng, new_rngs = jax.random.split(collection_state.rng, 2)
        rng_steps, rng_action = jax.random.split(new_rngs, 2)
        rng_steps = jax.random.split(rng_steps, self.hp.num_envs)

        # Sample action
        unclipped_action, log_prob = self.actor.apply(
            actor_params, collection_state.last_obs, rng_action, method="action_log_prob"
        )
        assert isinstance(log_prob, jax.Array)
        value = self.critic.apply(critic_params, collection_state.last_obs)
        assert isinstance(value, jax.Array)

        # Clip action
        if self.discrete:
            action = unclipped_action
        else:
            low = self.env.action_space(self.env_params).low
            high = self.env.action_space(self.env_params).high
            action = jnp.clip(unclipped_action, low, high)

        # Step environment
        next_obs, env_state, reward, done, _ = jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
            rng_steps,
            collection_state.env_state,
            action,
            self.env_params,
        )

        if self.hp.normalize_observations:
            rms_state = _update_rms(collection_state.rms_state, obs=next_obs, batched=True)
            next_obs = _normalize_obs(rms_state, obs=next_obs)

            collection_state = collection_state.replace(rms_state=rms_state)

        # Return updated runner state and transition
        transition = Trajectory(
            collection_state.last_obs, unclipped_action, log_prob, reward, value, done
        )
        collection_state = collection_state.replace(
            env_state=env_state,
            last_obs=next_obs,
            last_done=done,
            global_step=collection_state.global_step + self.hp.num_envs,
            rng=rng,
        )
        return collection_state, transition

    @property
    def discrete(self) -> bool:
        """Whether the environment's action space is `Discrete`."""
        return isinstance(
            self.env.action_space(self.env_params), gymnax.environments.spaces.Discrete
        )

    def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey | None = None):
        """Render an episode with the current actor (via `render_episode`) to `gif_path`.

        `eval_rng` defaults to `ts.rng`.
        """
        actor = make_actor(ts=ts, hp=self.hp)
        render_episode(
            actor=actor,
            env=self.env,
            env_params=self.env_params,
            gif_path=Path(gif_path),
            rng=eval_rng if eval_rng is not None else ts.rng,
        )

    ## These here aren't currently used. They are here to mirror rejax.PPO where the training loop
    # is in the algorithm.

    @functools.partial(jit, static_argnames=["skip_initial_evaluation"])
    def train(
        self,
        rng: jax.Array,
        train_state: PPOState[TEnvState] | None = None,
        skip_initial_evaluation: bool = False,
    ) -> tuple[PPOState[TEnvState], EvalMetrics]:
        """Full training loop in jax.

        This is only here to match the API of `rejax.PPO.train`. This doesn't get called when using
        the `JaxTrainer`, since `JaxTrainer.fit` already does the same thing, but also with support
        for some `JaxCallback`s (as well as some `lightning.Callback`s!).

        Unfolded version of `rejax.PPO.train`.

        Raises:
            ValueError: If both `train_state` and `rng` are None.
        """
        if train_state is None and rng is None:
            raise ValueError("Either train_state or rng must be provided")

        ts = train_state if train_state is not None else self.init_train_state(rng)

        initial_evaluation: EvalMetrics | None = None
        if not skip_initial_evaluation:
            initial_evaluation = self.eval_callback(ts)

        # One evaluation every `eval_freq` timesteps.
        num_evals = np.ceil(self.hp.total_timesteps / self.hp.eval_freq).astype(int)
        ts, evaluation = jax.lax.scan(
            self._training_epoch,
            init=ts,
            xs=None,
            length=num_evals,
        )

        if not skip_initial_evaluation:
            assert initial_evaluation is not None
            # Prepend the initial evaluation to the stacked per-epoch evaluations.
            evaluation = jax.tree.map(
                lambda i, ev: jnp.concatenate((jnp.expand_dims(i, 0), ev)),
                initial_evaluation,
                evaluation,
            )
            assert isinstance(evaluation, EvalMetrics)

        return ts, evaluation

    # @jit
    def _training_epoch(
        self, ts: PPOState[TEnvState], epoch: int
    ) -> tuple[PPOState[TEnvState], EvalMetrics]:
        """Train for (at least) `hp.eval_freq` timesteps, then evaluate."""
        # Run a few training iterations
        iteration_steps = self.hp.num_envs * self.hp.num_steps
        num_iterations = np.ceil(self.hp.eval_freq / iteration_steps).astype(int)
        ts = jax.lax.fori_loop(
            0,
            num_iterations,
            # drop metrics for now
            lambda i, train_state_i: self._fused_training_step(i, train_state_i)[0],
            ts,
        )
        # Run evaluation
        return ts, self.eval_callback(ts)

    # @jit
    def _fused_training_step(self, iteration: int, ts: PPOState[TEnvState]):
        """Fused training step in jax (joined data collection + training).

        This is the equivalent of the training step from rejax.PPO. It is only used in tests to
        verify the correctness of the training step.
        """
        data_collection_state, trajectories = self.collect_trajectories(
            collection_state=ts.data_collection_state,
            actor_params=ts.actor_ts.params,
            critic_params=ts.critic_ts.params,
        )
        ts = ts.replace(data_collection_state=data_collection_state)
        return self.training_step(iteration, ts, trajectories)

JaxModule

The JaxModule class is designed to resemble the lightning.LightningModule class:

@runtime_checkable
class JaxModule(Protocol[Ts, _B, _MetricsT]):
    """Structural interface of an algorithm that the `JaxTrainer` knows how to train.

    Any object that implements the four methods below can be passed to `JaxTrainer.fit`.
    The `JaxRLExample` class is one such algorithm.

    Type parameters:
        Ts: The type of the training state.
        _B: The type of a batch of data.
        _MetricsT: The type of the evaluation metrics.
    """

    def init_train_state(self, rng: chex.PRNGKey) -> Ts:
        """Build and return the training state to start from, using the random key `rng`."""
        raise NotImplementedError()

    def get_batch(self, ts: Ts, batch_idx: int) -> tuple[Ts, _B]:
        """Return the (possibly updated) training state along with the next batch of data."""
        raise NotImplementedError()

    def training_step(
        self,
        batch_idx: int,
        ts: Ts,
        batch: _B,
    ) -> tuple[Ts, flax.struct.PyTreeNode]:
        """Use `batch` to update the training state; return it with the step's metrics."""
        raise NotImplementedError()

    def eval_callback(self, ts: Ts) -> _MetricsT:
        """Evaluate the algorithm in the state `ts` and return the resulting metrics."""
        raise NotImplementedError()

JaxTrainer

The JaxTrainer follows a structure roughly similar to that of the lightning.Trainer:

  • JaxTrainer.fit is called with a JaxModule to train the algorithm.

Click to show the code for JaxTrainer
class JaxTrainer(flax.struct.PyTreeNode):
    """A simplified version of the `lightning.Trainer` with a fully jitted training loop.

    ## Assumptions:

    - The algo object must match the `JaxModule` protocol (in other words, it should implement its
      methods).

    ## Training loop

    This is the training loop, which is fully jitted (an initial evaluation is also performed
    before the first epoch, unless `skip_initial_evaluation=True`):

    ```python
    ts = algo.init_train_state(rng)

    setup("fit")
    on_fit_start()
    on_train_start()

    eval_metrics = []
    for epoch in range(self.max_epochs):
        on_train_epoch_start()

        for step in range(self.training_steps_per_epoch):

            ts, batch = algo.get_batch(ts, step)

            on_train_batch_start()

            ts, metrics = algo.training_step(step, ts, batch)

            on_train_batch_end()

        on_train_epoch_end()

        # Evaluation "loop"
        on_validation_epoch_start()
        epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
        on_validation_epoch_end()

        eval_metrics.append(epoch_eval_metrics)

    on_fit_end()
    on_train_end()
    teardown("fit")

    return ts, eval_metrics
    ```

    ## Caveats

    - Some lightning callbacks can be used with this trainer and work well, but not all of them.
    - You can either use regular pytorch-lightning callbacks, or use `jax.vmap` on the `fit` method,
      but not both.
      - If you want to use [jax.vmap][] on the `fit` method, just remove the callbacks on the
        Trainer for now.

    ## TODOs / ideas

    - Add a checkpoint callback with orbax-checkpoint?
    """

    max_epochs: int = flax.struct.field(pytree_node=False)

    training_steps_per_epoch: int = flax.struct.field(pytree_node=False)

    # NOTE(review): not read by the training loop below; presumably kept for compatibility with
    # the `lightning.Trainer` attributes that some callbacks access.
    limit_val_batches: int = 0
    limit_test_batches: int = 0

    # TODO: Getting some errors with the schema generation for lightning.Callback and
    # lightning.pytorch.loggers.logger.Logger here if we keep the type annotation.
    callbacks: Sequence = dataclasses.field(metadata={"pytree_node": False}, default_factory=tuple)

    logger: Any | None = flax.struct.field(pytree_node=False, default=None)

    # accelerator: str = flax.struct.field(pytree_node=False, default="auto")
    # strategy: str = flax.struct.field(pytree_node=False, default="auto")
    # devices: int | str = flax.struct.field(pytree_node=False, default="auto")

    # path to output directory, created dynamically by hydra
    # path generation pattern is specified in `configs/hydra/default.yaml`
    # use it to store all files generated during the run, like checkpoints and metrics
    # NOTE: the default reads `HydraConfig`, so it presumably requires a Hydra run to be active;
    # pass `default_root_dir` explicitly otherwise.
    default_root_dir: str | Path | None = flax.struct.field(
        pytree_node=False,
        default_factory=lambda: HydraConfig.get().runtime.output_dir,
    )

    # State variables:
    # TODO: figure out how to cleanly store / update these.
    current_epoch: int = flax.struct.field(pytree_node=True, default=0)
    global_step: int = flax.struct.field(pytree_node=True, default=0)

    logged_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
    callback_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
    # todo: get the metrics from the callbacks?
    # lightning.pytorch.loggers.CSVLogger.log_metrics
    # TODO: Take a look at this method:
    # lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar.get_metrics
    # return lightning.Trainer._logger_connector.progress_bar_metrics
    progress_bar_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)

    verbose: bool = flax.struct.field(pytree_node=False, default=False)

    @functools.partial(jit, static_argnames=["skip_initial_evaluation"])
    def fit(
        self,
        algo: JaxModule[Ts, _B, _MetricsT],
        rng: chex.PRNGKey,
        train_state: Ts | None = None,
        skip_initial_evaluation: bool = False,
    ) -> tuple[Ts, _MetricsT]:
        """Full training loop in pure jax (a lot faster than when using pytorch-lightning).

        Unfolded version of `rejax.PPO.train`.

        Args:
            algo: The algorithm to train. It must implement the `JaxModule` methods.
            rng: Random key passed to `algo.init_train_state` (unused when `train_state` is
                given).
            train_state: Optional training state to start (or resume) from.
            skip_initial_evaluation: When False (default), `algo.eval_callback` is also run once
                before training, and its metrics are prepended to the per-epoch evaluations.

        Returns:
            The final training state and the evaluation metrics, stacked along a leading axis of
            size `max_epochs` (`max_epochs + 1` when the initial evaluation is included).

        Raises:
            ValueError: If both `train_state` and `rng` are None.
        """

        if train_state is None and rng is None:
            raise ValueError("Either train_state or rng must be provided")

        train_state = train_state if train_state is not None else algo.init_train_state(rng)

        if self.progress_bar_callback is not None:
            if self.verbose:
                jax.debug.print("Enabling the progress bar callback.")
            jax.experimental.io_callback(self.progress_bar_callback.enable, ())

        self._callback_hook("setup", self, algo, ts=train_state, partial_kwargs=dict(stage="fit"))
        self._callback_hook("on_fit_start", self, algo, ts=train_state)
        self._callback_hook("on_train_start", self, algo, ts=train_state)

        if self.logger:
            jax.experimental.io_callback(
                lambda algo: self.logger and self.logger.log_hyperparams(hparams_to_dict(algo)),
                (),
                algo,
                ordered=True,
            )

        initial_evaluation: _MetricsT | None = None
        if not skip_initial_evaluation:
            initial_evaluation = algo.eval_callback(train_state)

        # Run the epoch loop `self.max_epoch` times.
        train_state, evaluations = jax.lax.scan(
            functools.partial(self.epoch_loop, algo=algo),
            init=train_state,
            xs=jnp.arange(self.max_epochs),  # type: ignore
            length=self.max_epochs,
        )

        if not skip_initial_evaluation:
            assert initial_evaluation is not None
            # Prepend the initial evaluation to the stacked per-epoch evaluations.
            evaluations: _MetricsT = jax.tree.map(
                lambda i, ev: jnp.concatenate((jnp.expand_dims(i, 0), ev)),
                initial_evaluation,
                evaluations,
            )

        if self.logger is not None:
            jax.block_until_ready((train_state, evaluations))
            jax.experimental.io_callback(
                functools.partial(self.logger.finalize, status="success"), ()
            )

        self._callback_hook("on_fit_end", self, algo, ts=train_state)
        self._callback_hook("on_train_end", self, algo, ts=train_state)
        self._callback_hook(
            "teardown", self, algo, ts=train_state, partial_kwargs={"stage": "fit"}
        )

        return train_state, evaluations

    # @jit
    def epoch_loop(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]):
        """Run one training epoch followed by one evaluation (body of a `jax.lax.scan`).

        Returns:
            The updated training state and this epoch's evaluation metrics.
        """
        # todo: Some lightning callbacks try to get the "trainer.current_epoch".
        # FIXME: Hacky: Present a trainer with a different value of `self.current_epoch` to
        # the callbacks.
        # TODO: Can't just set current_epoch to `epoch` as `epoch` is a Traced value.
        # todo: need to have the callback take in the actual int value.

        self = self.replace(current_epoch=epoch)  # doesn't quite work?
        ts = self.training_epoch(ts=ts, epoch=epoch, algo=algo)
        eval_metrics = self.eval_epoch(ts=ts, epoch=epoch, algo=algo)
        return ts, eval_metrics

    # @jit
    def training_epoch(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]):
        """Run `training_steps_per_epoch` training steps, between the epoch start / end hooks."""
        self._callback_hook("on_train_epoch_start", self, algo, ts=ts)

        ts = jax.lax.fori_loop(
            0,
            self.training_steps_per_epoch,
            # drop training metrics for now.
            functools.partial(self.training_step, algo=algo),
            ts,
        )

        self._callback_hook("on_train_epoch_end", self, algo, ts=ts)
        return ts

    # @jit
    def eval_epoch(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]):
        """Evaluate the algorithm with `algo.eval_callback`, between the validation hooks."""
        self._callback_hook("on_validation_epoch_start", self, algo, ts=ts)

        # todo: split up into eval batch and eval step?
        eval_metrics = algo.eval_callback(ts=ts)

        self._callback_hook("on_validation_epoch_end", self, algo, ts=ts)

        return eval_metrics

    # @jit
    def training_step(self, batch_idx: int, ts: Ts, algo: JaxModule[Ts, _B, _MetricsT]):
        """Training step in pure jax (joined data collection + training).

        This gets a batch with `algo.get_batch`, updates the state with `algo.training_step`,
        logs the (mean) metrics to the logger if there is one, and calls the batch-level
        callback hooks.

        Returns:
            Only the updated training state (the metrics are dropped after being logged), so that
            this can be used as the body of a `jax.lax.fori_loop`.
        """
        # todo: rename to `get_training_batch`?
        ts, batch = algo.get_batch(ts, batch_idx=batch_idx)

        self._callback_hook("on_train_batch_start", self, algo, batch, batch_idx, ts=ts)

        ts, metrics = algo.training_step(batch_idx=batch_idx, ts=ts, batch=batch)

        if self.logger is not None:
            # todo: Clean this up. logs metrics.
            jax.experimental.io_callback(
                lambda metrics, batch_index: self.logger
                and self.logger.log_metrics(
                    jax.tree.map(lambda v: v.mean(), metrics), batch_index
                ),
                (),
                dataclasses.asdict(metrics) if dataclasses.is_dataclass(metrics) else metrics,
                batch_idx,
            )

        self._callback_hook("on_train_batch_end", self, algo, metrics, batch, batch_idx, ts=ts)

        return ts

    ### Hooks to mimic those of lightning.Trainer

    def _callback_hook(
        self,
        hook_name: str,
        /,
        *hook_args,
        ts: Ts,
        partial_kwargs: dict | None = None,
        sharding: jax.sharding.SingleDeviceSharding | None = None,
        ordered: bool = True,
        **hook_kwargs,
    ):
        """Call the hook named `hook_name` on every callback, from inside (jitted) code.

        Each hook is invoked through `jax.experimental.io_callback`. `JaxCallback`s also receive
        the training state as `ts=...` and are called unordered. All other callbacks (e.g.
        lightning callbacks) are called with `ordered=ordered`.

        Args:
            hook_name: Name of the method to call on each callback (e.g. "on_train_start").
            *hook_args: Positional arguments forwarded to the hook.
            ts: The current training state (only forwarded to `JaxCallback`s).
            partial_kwargs: Keyword arguments bound to the hook with `functools.partial`
                (e.g. `{"stage": "fit"}`).
            sharding: Forwarded to `io_callback`.
            ordered: Whether the hooks of non-`JaxCallback` callbacks are called in order.
            **hook_kwargs: Extra keyword arguments forwarded to the hook.
        """
        for i, callback in enumerate(self.callbacks):
            method = getattr(callback, hook_name)
            if partial_kwargs:
                method = functools.partial(method, **partial_kwargs)
            if self.verbose:
                # The callback's string is embedded in a format string: escape its braces so that
                # `jax.debug.print` doesn't interpret them as replacement fields.
                callback_str = str(callback).replace("{", "{{").replace("}", "}}")
                jax.debug.print(
                    "Epoch {current_epoch}/{max_epochs}: "
                    + f"Calling hook {hook_name} on callback #{i} ({callback_str})",
                    current_epoch=self.current_epoch,
                    ordered=True,
                    max_epochs=self.max_epochs,
                )
            jax.experimental.io_callback(
                method,
                (),
                *hook_args,
                **({"ts": ts} if isinstance(callback, JaxCallback) else {}),
                **hook_kwargs,
                sharding=sharding,
                ordered=ordered if not isinstance(callback, JaxCallback) else False,
            )

    # Compat for RichProgressBar
    @property
    def is_global_zero(self) -> bool:
        return True

    @property
    def num_training_batches(self) -> int:
        return self.training_steps_per_epoch

    @property
    def loggers(self) -> list[lightning.pytorch.loggers.Logger]:
        """The logger(s) as a list (empty when there is no logger)."""
        if isinstance(self.logger, list | tuple):
            return list(self.logger)
        if self.logger is not None:
            return [self.logger]
        return []

    @property
    def progress_bar_callback(self) -> lightning.pytorch.callbacks.ProgressBar | None:
        """The first `ProgressBar` among the callbacks, if any."""
        for c in self.callbacks:
            if isinstance(c, lightning.pytorch.callbacks.ProgressBar):
                return c
        return None

    @property
    def state(self):
        """A fixed lightning `TrainerState` (always fitting / running / training)."""
        from lightning.pytorch.trainer.states import (
            RunningStage,
            TrainerFn,
            TrainerState,
            TrainerStatus,
        )

        return TrainerState(
            fn=TrainerFn.FITTING,
            status=TrainerStatus.RUNNING,
            stage=RunningStage.TRAINING,
        )

    @property
    def sanity_checking(self) -> bool:
        from lightning.pytorch.trainer.states import RunningStage

        return self.state.stage == RunningStage.SANITY_CHECKING

    @property
    def training(self) -> bool:
        from lightning.pytorch.trainer.states import RunningStage

        return self.state.stage == RunningStage.TRAINING

    @property
    def log_dir(self) -> Path | None:
        """Directory of the first logger, or `default_root_dir` when there is no logger."""
        # copied from lightning.Trainer
        if len(self.loggers) > 0:
            if not isinstance(
                self.loggers[0],
                lightning.pytorch.loggers.TensorBoardLogger | lightning.pytorch.loggers.CSVLogger,
            ):
                dirpath = self.loggers[0].save_dir
            else:
                dirpath = self.loggers[0].log_dir
        else:
            dirpath = self.default_root_dir
        if dirpath:
            return Path(dirpath)
        return None