# Reinforcement Learning in Jax
This example follows the same structure as the other examples:

- An "algorithm" (in this case `JaxRLExample`) is trained with a "trainer" (`JaxTrainer`).

However, there are some very important differences:

- There is no "datamodule". The algorithm accepts an environment (`gymnax.Environment`) as input.
- The "Trainer" is a `JaxTrainer`, instead of a `lightning.Trainer`.
- The full training loop is written in Jax.
- Some (but not all) PyTorch-Lightning callbacks can still be used with the `JaxTrainer`.
- The `JaxRLExample` class is an algorithm based on `rejax.PPO`.
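To give a rough idea of how these pieces fit together, here is a minimal usage sketch. It is not taken from the project itself: the environment id, hyper-parameter values and output directory are placeholder assumptions, and the actual signatures are shown in the code listings below.

```python
import jax

# Create the algorithm (a `JaxModule`) for a gymnax environment.
algo = JaxRLExample.create(env_id="CartPole-v1")

# Create the trainer. `default_root_dir` is passed explicitly here so that this
# sketch does not rely on the Hydra-provided default output directory.
trainer = JaxTrainer(
    max_epochs=10,
    training_steps_per_epoch=100,
    callbacks=(),
    default_root_dir="logs/jax_rl_example",
)

# Run the fully-jitted training loop.
train_state, eval_metrics = trainer.fit(algo, rng=jax.random.PRNGKey(0))
```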
## JaxRLExample

The `JaxRLExample` is based on `rejax.PPO`. It follows the structure of a `JaxModule` and is trained with a `JaxTrainer`.
Click to show the code for JaxRLExample
class JaxRLExample(
flax.struct.PyTreeNode,
JaxModule[PPOState[TEnvState], TrajectoryWithLastObs, EvalMetrics],
Generic[TEnvState, TEnvParams],
):
"""Example of an RL algorithm written in Jax: PPO, based on `rejax.PPO`.
## Differences w.r.t. rejax.PPO:
- The state / hparams are split into different, fully-typed structs:
- The algorithm state is in a typed `PPOState` struct (vs an untyped,
dynamically-generated struct in rejax).
- The hyper-parameters are in a typed `PPOHParams` struct.
- The state variables related to collecting data from the environment are grouped into a
`TrajectoryCollectionState`, instead of everything being bunched up together.
- This makes it easier to call the `collect_episodes` function with just what it needs.
- The seeds for the networks and the environment data collection are separated.
The logic is exactly the same: The losses / updates are computed in the exact same way.
"""
env: Environment[TEnvState, TEnvParams] = flax.struct.field(pytree_node=False)
env_params: TEnvParams
actor: flax.linen.Module = flax.struct.field(pytree_node=False)
critic: flax.linen.Module = flax.struct.field(pytree_node=False)
hp: PPOHParams
@classmethod
def create(
cls,
env_id: str | None = None,
env: Environment[TEnvState, TEnvParams] | None = None,
env_params: TEnvParams | None = None,
hp: PPOHParams | None = None,
) -> JaxRLExample[TEnvState, TEnvParams]:
from brax.envs import _envs as brax_envs
from rejax.compat.brax2gymnax import create_brax
# env_params: gymnax.EnvParams
if env_id is None:
assert env is not None
env_params = env_params or env.default_params # type: ignore
elif env_id in brax_envs:
env, env_params = create_brax( # type: ignore
env_id,
episode_length=1000,
action_repeat=1,
auto_reset=True,
batch_size=None,
backend="generalized",
)
elif isinstance(env_id, str):
env, env_params = gymnax.make(env_id=env_id) # type: ignore
else:
raise NotImplementedError(env_id)
assert env is not None
assert env_params is not None
return cls(
env=env,
env_params=env_params,
actor=cls.create_actor(env, env_params),
critic=cls.create_critic(),
hp=hp or PPOHParams(),
)
@classmethod
def create_networks(
cls,
env: Environment[gymnax.EnvState, TEnvParams],
env_params: TEnvParams,
config: _NetworkConfig,
):
# Equivalent to:
# return rejax.PPO.create_agent(config, env, env_params)
return {
"actor": cls.create_actor(env, env_params, **config["agent_kwargs"]),
"critic": cls.create_actor(env, env_params, **config["agent_kwargs"]),
}
_TEnvParams = TypeVar("_TEnvParams", bound=gymnax.EnvParams, covariant=True)
_TEnvState = TypeVar("_TEnvState", bound=gymnax.EnvState, covariant=True)
@classmethod
def create_actor(
cls,
env: Environment[_TEnvState, _TEnvParams],
env_params: _TEnvParams,
activation: str | Callable[[jax.Array], jax.Array] = "swish",
hidden_layer_sizes: Sequence[int] = (64, 64),
**actor_kwargs,
) -> DiscretePolicy | GaussianPolicy:
activation_fn: Callable[[jax.Array], jax.Array] = (
getattr(flax.linen, activation) if not callable(activation) else activation
)
hidden_layer_sizes = tuple(hidden_layer_sizes)
action_space = env.action_space(env_params)
if isinstance(action_space, gymnax.environments.spaces.Discrete):
return DiscretePolicy(
action_space.n,
activation=activation_fn,
hidden_layer_sizes=hidden_layer_sizes,
**actor_kwargs,
)
assert isinstance(action_space, gymnax.environments.spaces.Box)
return GaussianPolicy(
np.prod(action_space.shape),
(action_space.low, action_space.high), # type: ignore
activation=activation_fn,
hidden_layer_sizes=hidden_layer_sizes,
**actor_kwargs,
)
@classmethod
def create_critic(
cls,
activation: str | Callable[[jax.Array], jax.Array] = "swish",
hidden_layer_sizes: Sequence[int] = (64, 64),
**critic_kwargs,
) -> VNetwork:
activation_fn: Callable[[jax.Array], jax.Array] = (
getattr(flax.linen, activation) if isinstance(activation, str) else activation
)
hidden_layer_sizes = tuple(hidden_layer_sizes)
return VNetwork(
hidden_layer_sizes=hidden_layer_sizes, activation=activation_fn, **critic_kwargs
)
def init_train_state(self, rng: chex.PRNGKey) -> PPOState[TEnvState]:
rng, networks_rng, env_rng = jax.random.split(rng, 3)
rng_actor, rng_critic = jax.random.split(networks_rng, 2)
obs_ph = jnp.empty([1, *self.env.observation_space(self.env_params).shape])
actor_params = self.actor.init(rng_actor, obs_ph, rng_actor)
critic_params = self.critic.init(rng_critic, obs_ph)
tx = optax.adam(learning_rate=self.hp.learning_rate)
# TODO: Why isn't the `apply_fn` set in rejax?
actor_ts = TrainState.create(apply_fn=self.actor.apply, params=actor_params, tx=tx)
critic_ts = TrainState.create(apply_fn=self.critic.apply, params=critic_params, tx=tx)
env_rng, reset_rng = jax.random.split(env_rng)
obs, env_state = jax.vmap(self.env.reset, in_axes=(0, None))(
jax.random.split(reset_rng, self.hp.num_envs), self.env_params
)
collection_state = TrajectoryCollectionState(
last_obs=obs,
rms_state=RMSState.create(shape=obs_ph.shape),
global_step=0,
env_state=env_state,
last_done=jnp.zeros(self.hp.num_envs, dtype=bool),
rng=env_rng,
)
return PPOState(
actor_ts=actor_ts,
critic_ts=critic_ts,
rng=rng,
data_collection_state=collection_state,
)
# @jit
def training_step(self, batch_idx: int, ts: PPOState[TEnvState], batch: TrajectoryWithLastObs):
"""Training step in pure jax."""
trajectories = batch
ts, (actor_losses, critic_losses) = jax.lax.scan(
functools.partial(self.ppo_update_epoch, trajectories=trajectories),
init=ts,
xs=jnp.arange(self.hp.num_epochs), # type: ignore
length=self.hp.num_epochs,
)
# todo: perhaps we could have a callback that updates a progress bar?
# jax.debug.print("actor_losses {}: {}", iteration, actor_losses.mean())
# jax.debug.print("critic_losses {}: {}", iteration, critic_losses.mean())
return ts, TrainStepMetrics(actor_losses=actor_losses, critic_losses=critic_losses)
# @jit
def ppo_update_epoch(
self, ts: PPOState[TEnvState], epoch_index: int, trajectories: TrajectoryWithLastObs
):
minibatch_rng = jax.random.fold_in(ts.rng, epoch_index)
last_val = self.critic.apply(ts.critic_ts.params, ts.data_collection_state.last_obs)
assert isinstance(last_val, jax.Array)
last_val = jnp.where(ts.data_collection_state.last_done, 0, last_val)
advantages, targets = calculate_gae(
trajectories, last_val, gamma=self.hp.gamma, gae_lambda=self.hp.gae_lambda
)
batch = AdvantageMinibatch(trajectories.trajectories, advantages, targets)
# Shuffle the data and split it into minibatches.
num_steps = self.hp.num_steps
num_envs = self.hp.num_envs
num_minibatches = self.hp.num_minibatches
assert (num_envs * num_steps) % num_minibatches == 0
minibatches = shuffle_and_split(
batch,
minibatch_rng,
num_minibatches=num_minibatches,
)
return jax.lax.scan(self.ppo_update, ts, minibatches, length=self.hp.num_minibatches)
# @jit
def ppo_update(self, ts: PPOState[TEnvState], batch: AdvantageMinibatch):
actor_loss, actor_grads = jax.value_and_grad(actor_loss_fn)(
ts.actor_ts.params,
actor=self.actor,
batch=batch,
clip_eps=self.hp.clip_eps,
ent_coef=self.hp.ent_coef,
)
assert isinstance(actor_loss, jax.Array)
critic_loss, critic_grads = jax.value_and_grad(critic_loss_fn)(
ts.critic_ts.params,
critic=self.critic,
batch=batch,
clip_eps=self.hp.clip_eps,
vf_coef=self.hp.vf_coef,
)
assert isinstance(critic_loss, jax.Array)
# TODO: log the losses here?
actor_ts = ts.actor_ts.apply_gradients(grads=actor_grads)
critic_ts = ts.critic_ts.apply_gradients(grads=critic_grads)
return ts.replace(actor_ts=actor_ts, critic_ts=critic_ts), (actor_loss, critic_loss)
def eval_callback(
self, ts: PPOState[TEnvState], rng: chex.PRNGKey | None = None
) -> EvalMetrics:
if rng is None:
rng = ts.rng
actor = make_actor(ts=ts, hp=self.hp)
ep_lengths, cum_rewards = evaluate(
actor,
rng,
self.env,
self.env_params,
num_seeds=self.hp.num_seeds_per_eval,
max_steps_in_episode=self.env_params.max_steps_in_episode,
)
return EvalMetrics(episode_length=ep_lengths, cumulative_reward=cum_rewards)
def get_batch(
self, ts: PPOState[TEnvState], batch_idx: int
) -> tuple[PPOState[TEnvState], TrajectoryWithLastObs]:
data_collection_state, trajectories = self.collect_trajectories(
ts.data_collection_state,
actor_params=ts.actor_ts.params,
critic_params=ts.critic_ts.params,
)
ts = ts.replace(data_collection_state=data_collection_state)
return ts, trajectories
# @jit
def collect_trajectories(
self,
collection_state: TrajectoryCollectionState[TEnvState],
actor_params: FrozenVariableDict,
critic_params: FrozenVariableDict,
):
env_step_fn = functools.partial(
self.env_step,
# env=self.env,
# env_params=self.env_params,
# actor=self.actor,
# critic=self.critic,
# num_envs=self.hp.num_envs,
actor_params=actor_params,
critic_params=critic_params,
# discrete=self.discrete,
# normalize_observations=self.hp.normalize_observations,
)
collection_state, trajectories = jax.lax.scan(
env_step_fn,
collection_state,
xs=jnp.arange(self.hp.num_steps),
length=self.hp.num_steps,
)
trajectories_with_last = TrajectoryWithLastObs(
trajectories=trajectories,
last_done=collection_state.last_done,
last_obs=collection_state.last_obs,
)
return collection_state, trajectories_with_last
# @jit
def env_step(
self,
collection_state: TrajectoryCollectionState[TEnvState],
step_index: jax.Array,
actor_params: FrozenVariableDict,
critic_params: FrozenVariableDict,
):
# Get keys for sampling action and stepping environment
# doing it this way to try to get *exactly* the same rngs as in rejax.PPO.
rng, new_rngs = jax.random.split(collection_state.rng, 2)
rng_steps, rng_action = jax.random.split(new_rngs, 2)
rng_steps = jax.random.split(rng_steps, self.hp.num_envs)
# Sample action
unclipped_action, log_prob = self.actor.apply(
actor_params, collection_state.last_obs, rng_action, method="action_log_prob"
)
assert isinstance(log_prob, jax.Array)
value = self.critic.apply(critic_params, collection_state.last_obs)
assert isinstance(value, jax.Array)
# Clip action
if self.discrete:
action = unclipped_action
else:
low = self.env.action_space(self.env_params).low
high = self.env.action_space(self.env_params).high
action = jnp.clip(unclipped_action, low, high)
# Step environment
next_obs, env_state, reward, done, _ = jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
rng_steps,
collection_state.env_state,
action,
self.env_params,
)
if self.hp.normalize_observations:
# rms_state, next_obs = learner.update_and_normalize(collection_state.rms_state, next_obs)
rms_state = _update_rms(collection_state.rms_state, obs=next_obs, batched=True)
next_obs = _normalize_obs(rms_state, obs=next_obs)
collection_state = collection_state.replace(rms_state=rms_state)
# Return updated runner state and transition
transition = Trajectory(
collection_state.last_obs, unclipped_action, log_prob, reward, value, done
)
collection_state = collection_state.replace(
env_state=env_state,
last_obs=next_obs,
last_done=done,
global_step=collection_state.global_step + self.hp.num_envs,
rng=rng,
)
return collection_state, transition
@property
def discrete(self) -> bool:
return isinstance(
self.env.action_space(self.env_params), gymnax.environments.spaces.Discrete
)
def visualize(self, ts: PPOState, gif_path: str | Path, eval_rng: chex.PRNGKey | None = None):
actor = make_actor(ts=ts, hp=self.hp)
render_episode(
actor=actor,
env=self.env,
env_params=self.env_params,
gif_path=Path(gif_path),
rng=eval_rng if eval_rng is not None else ts.rng,
)
## The methods below aren't currently used. They are here to mirror rejax.PPO, where the
# training loop is part of the algorithm.
@functools.partial(jit, static_argnames=["skip_initial_evaluation"])
def train(
self,
rng: jax.Array,
train_state: PPOState[TEnvState] | None = None,
skip_initial_evaluation: bool = False,
) -> tuple[PPOState[TEnvState], EvalMetrics]:
"""Full training loop in jax.
This is only here to match the API of `rejax.PPO.train`. This doesn't get called when using
the `JaxTrainer`, since `JaxTrainer.fit` already does the same thing, but also with support
for some `JaxCallback`s (as well as some `lightning.Callback`s!).
Unfolded version of `rejax.PPO.train`.
"""
if train_state is None and rng is None:
raise ValueError("Either train_state or rng must be provided")
ts = train_state if train_state is not None else self.init_train_state(rng)
initial_evaluation: EvalMetrics | None = None
if not skip_initial_evaluation:
initial_evaluation = self.eval_callback(ts)
num_evals = np.ceil(self.hp.total_timesteps / self.hp.eval_freq).astype(int)
ts, evaluation = jax.lax.scan(
self._training_epoch,
init=ts,
xs=None,
length=num_evals,
)
if not skip_initial_evaluation:
assert initial_evaluation is not None
evaluation = jax.tree.map(
lambda i, ev: jnp.concatenate((jnp.expand_dims(i, 0), ev)),
initial_evaluation,
evaluation,
)
assert isinstance(evaluation, EvalMetrics)
return ts, evaluation
# @jit
def _training_epoch(
self, ts: PPOState[TEnvState], epoch: int
) -> tuple[PPOState[TEnvState], EvalMetrics]:
# Run a few training iterations
iteration_steps = self.hp.num_envs * self.hp.num_steps
num_iterations = np.ceil(self.hp.eval_freq / iteration_steps).astype(int)
ts = jax.lax.fori_loop(
0,
num_iterations,
# drop metrics for now
lambda i, train_state_i: self._fused_training_step(i, train_state_i)[0],
ts,
)
# Run evaluation
return ts, self.eval_callback(ts)
# @jit
def _fused_training_step(self, iteration: int, ts: PPOState[TEnvState]):
"""Fused training step in jax (joined data collection + training).
This is the equivalent of the training step from rejax.PPO. It is only used in tests to
verify the correctness of the training step.
"""
data_collection_state, trajectories = self.collect_trajectories(
# env=self.env,
# env_params=self.env_params,
# actor=self.actor,
# critic=self.critic,
collection_state=ts.data_collection_state,
actor_params=ts.actor_ts.params,
critic_params=ts.critic_ts.params,
# num_envs=self.hp.num_envs,
# num_steps=self.hp.num_steps,
# discrete=discrete,
# normalize_observations=self.hp.normalize_observations,
)
ts = ts.replace(data_collection_state=data_collection_state)
return self.training_step(iteration, ts, trajectories)
## JaxModule

The `JaxModule` class is made to look a bit like the `lightning.LightningModule` class:
@runtime_checkable
class JaxModule(Protocol[Ts, _B, _MetricsT]):
"""A protocol for algorithms that can be trained by the `JaxTrainer`.
The `JaxRLExample` is an example that follows this structure and can be trained with a
`JaxTrainer`.
"""
def init_train_state(self, rng: chex.PRNGKey) -> Ts:
"""Create the initial training state."""
raise NotImplementedError
def get_batch(self, ts: Ts, batch_idx: int) -> tuple[Ts, _B]:
"""Produces a batch of data."""
raise NotImplementedError
def training_step(
self, batch_idx: int, ts: Ts, batch: _B
) -> tuple[Ts, flax.struct.PyTreeNode]:
"""Update the training state using a "batch" of data."""
raise NotImplementedError
def eval_callback(self, ts: Ts) -> _MetricsT:
"""Perform evaluation and return metrics."""
raise NotImplementedError
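To make the protocol concrete, here is a minimal, hypothetical `JaxModule` implementation (not part of the project): a toy linear-regression "algorithm" fitted on synthetic data. All class and field names below are made up for illustration; only the four protocol methods come from the `JaxModule` definition above.

```python
import chex
import flax.struct
import jax
import jax.numpy as jnp


class LinearRegressionState(flax.struct.PyTreeNode):
    w: jax.Array
    rng: chex.PRNGKey


class LinearRegressionMetrics(flax.struct.PyTreeNode):
    loss: jax.Array


class LinearRegressionExample(flax.struct.PyTreeNode):
    """Toy algorithm following the `JaxModule` protocol: learns y = 3x by gradient descent."""

    learning_rate: float = 0.1

    def init_train_state(self, rng: chex.PRNGKey) -> LinearRegressionState:
        # Start from a zero weight and keep the rng around for data generation.
        return LinearRegressionState(w=jnp.zeros(()), rng=rng)

    def get_batch(self, ts: LinearRegressionState, batch_idx: int):
        # Produce a synthetic "batch" and return the updated training state with it.
        rng, batch_rng = jax.random.split(ts.rng)
        x = jax.random.normal(batch_rng, (32,))
        return ts.replace(rng=rng), (x, 3.0 * x)

    def training_step(self, batch_idx: int, ts: LinearRegressionState, batch):
        x, y = batch
        loss, grad = jax.value_and_grad(lambda w: jnp.mean((w * x - y) ** 2))(ts.w)
        return ts.replace(w=ts.w - self.learning_rate * grad), LinearRegressionMetrics(loss=loss)

    def eval_callback(self, ts: LinearRegressionState) -> LinearRegressionMetrics:
        # Evaluate on a fixed grid of points.
        x = jnp.linspace(-1.0, 1.0, 100)
        return LinearRegressionMetrics(loss=jnp.mean((ts.w * x - 3.0 * x) ** 2))
```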
## JaxTrainer

The `JaxTrainer` follows a roughly similar structure as the `lightning.Trainer`:

- `JaxTrainer.fit` is called with a `JaxModule` to train the algorithm.
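Since some pytorch-lightning callbacks can be used with this trainer, here is a sketch of what passing one in could look like, reusing the placeholder values from the earlier usage sketch:

```python
from lightning.pytorch.callbacks import RichProgressBar

# Sketch: a JaxTrainer with a pytorch-lightning progress bar callback.
# Placeholder values; remove the callback if you intend to `jax.vmap` over `fit`.
trainer = JaxTrainer(
    max_epochs=10,
    training_steps_per_epoch=100,
    callbacks=(RichProgressBar(),),
    default_root_dir="logs/jax_rl_example",
)
```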
Click to show the code for JaxTrainer
class JaxTrainer(flax.struct.PyTreeNode):
"""A simplified version of the `lightning.Trainer` with a fully jitted training loop.
## Assumptions:
- The algo object must match the `JaxModule` protocol (in other words, it should implement its
methods).
## Training loop
This is the training loop, which is fully jitted:
```python
ts = algo.init_train_state(rng)
setup("fit")
on_fit_start()
on_train_start()
eval_metrics = []
for epoch in range(self.max_epochs):
on_train_epoch_start()
for step in range(self.training_steps_per_epoch):
batch = algo.get_batch(ts, step)
on_train_batch_start()
ts, metrics = algo.training_step(step, ts, batch)
on_train_batch_end()
on_train_epoch_end()
# Evaluation "loop"
on_validation_epoch_start()
epoch_eval_metrics = self.eval_epoch(ts, epoch, algo)
on_validation_epoch_end()
eval_metrics.append(epoch_eval_metrics)
return ts, eval_metrics
```
## Caveats
- Some lightning callbacks can be used with this trainer and work well, but not all of them.
- You can either use regular pytorch-lightning callbacks, or use `jax.vmap` on the `fit` method,
but not both.
- If you want to use [jax.vmap][] on the `fit` method, just remove the callbacks on the
Trainer for now.
## TODOs / ideas
- Add a checkpoint callback with orbax-checkpoint?
"""
max_epochs: int = flax.struct.field(pytree_node=False)
training_steps_per_epoch: int = flax.struct.field(pytree_node=False)
limit_val_batches: int = 0
limit_test_batches: int = 0
# TODO: Getting some errors with the schema generation for lightning.Callback and
# lightning.pytorch.loggers.logger.Logger here if we keep the type annotation.
callbacks: Sequence = dataclasses.field(metadata={"pytree_node": False}, default_factory=tuple)
logger: Any | None = flax.struct.field(pytree_node=False, default=None)
# accelerator: str = flax.struct.field(pytree_node=False, default="auto")
# strategy: str = flax.struct.field(pytree_node=False, default="auto")
# devices: int | str = flax.struct.field(pytree_node=False, default="auto")
# path to output directory, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like checkpoints and metrics
default_root_dir: str | Path | None = flax.struct.field(
pytree_node=False,
default_factory=lambda: HydraConfig.get().runtime.output_dir,
)
# State variables:
# TODO: figure out how to cleanly store / update these.
current_epoch: int = flax.struct.field(pytree_node=True, default=0)
global_step: int = flax.struct.field(pytree_node=True, default=0)
logged_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
callback_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
# todo: get the metrics from the callbacks?
# lightning.pytorch.loggers.CSVLogger.log_metrics
# TODO: Take a look at this method:
# lightning.pytorch.callbacks.progress.rich_progress.RichProgressBar.get_metrics
# return lightning.Trainer._logger_connector.progress_bar_metrics
progress_bar_metrics: dict = flax.struct.field(pytree_node=True, default_factory=dict)
verbose: bool = flax.struct.field(pytree_node=False, default=False)
@functools.partial(jit, static_argnames=["skip_initial_evaluation"])
def fit(
self,
algo: JaxModule[Ts, _B, _MetricsT],
rng: chex.PRNGKey,
train_state: Ts | None = None,
skip_initial_evaluation: bool = False,
) -> tuple[Ts, _MetricsT]:
"""Full training loop in pure jax (a lot faster than when using pytorch-lightning).
Unfolded version of `rejax.PPO.train`.
Training loop in pure jax (a lot faster than when using pytorch-lightning).
"""
if train_state is None and rng is None:
raise ValueError("Either train_state or rng must be provided")
train_state = train_state if train_state is not None else algo.init_train_state(rng)
if self.progress_bar_callback is not None:
if self.verbose:
jax.debug.print("Enabling the progress bar callback.")
jax.experimental.io_callback(self.progress_bar_callback.enable, ())
self._callback_hook("setup", self, algo, ts=train_state, partial_kwargs=dict(stage="fit"))
self._callback_hook("on_fit_start", self, algo, ts=train_state)
self._callback_hook("on_train_start", self, algo, ts=train_state)
if self.logger:
jax.experimental.io_callback(
lambda algo: self.logger and self.logger.log_hyperparams(hparams_to_dict(algo)),
(),
algo,
ordered=True,
)
initial_evaluation: _MetricsT | None = None
if not skip_initial_evaluation:
initial_evaluation = algo.eval_callback(train_state)
# Run the epoch loop `self.max_epochs` times.
train_state, evaluations = jax.lax.scan(
functools.partial(self.epoch_loop, algo=algo),
init=train_state,
xs=jnp.arange(self.max_epochs), # type: ignore
length=self.max_epochs,
)
if not skip_initial_evaluation:
assert initial_evaluation is not None
evaluations: _MetricsT = jax.tree.map(
lambda i, ev: jnp.concatenate((jnp.expand_dims(i, 0), ev)),
initial_evaluation,
evaluations,
)
if self.logger is not None:
jax.block_until_ready((train_state, evaluations))
# jax.debug.print("Saving...")
jax.experimental.io_callback(
functools.partial(self.logger.finalize, status="success"), ()
)
self._callback_hook("on_fit_end", self, algo, ts=train_state)
self._callback_hook("on_train_end", self, algo, ts=train_state)
self._callback_hook(
"teardown", self, algo, ts=train_state, partial_kwargs={"stage": "fit"}
)
return train_state, evaluations
# @jit
def epoch_loop(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]):
# todo: Some lightning callbacks try to get the "trainer.current_epoch".
# FIXME: Hacky: Present a trainer with a different value of `self.current_epoch` to
# the callbacks.
# chex.assert_scalar_in(epoch, 0, self.max_epochs)
# TODO: Can't just set current_epoch to `epoch` as `epoch` is a Traced value.
# todo: need to have the callback take in the actual int value.
# jax.debug.print("Starting epoch {epoch}", epoch=epoch)
self = self.replace(current_epoch=epoch) # doesn't quite work?
ts = self.training_epoch(ts=ts, epoch=epoch, algo=algo)
eval_metrics = self.eval_epoch(ts=ts, epoch=epoch, algo=algo)
return ts, eval_metrics
# @jit
def training_epoch(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]):
# Run a few training iterations
self._callback_hook("on_train_epoch_start", self, algo, ts=ts)
ts = jax.lax.fori_loop(
0,
self.training_steps_per_epoch,
# drop training metrics for now.
functools.partial(self.training_step, algo=algo),
ts,
)
self._callback_hook("on_train_epoch_end", self, algo, ts=ts)
return ts
# @jit
def eval_epoch(self, ts: Ts, epoch: int, algo: JaxModule[Ts, _B, _MetricsT]):
self._callback_hook("on_validation_epoch_start", self, algo, ts=ts)
# todo: split up into eval batch and eval step?
eval_metrics = algo.eval_callback(ts=ts)
self._callback_hook("on_validation_epoch_end", self, algo, ts=ts)
return eval_metrics
# @jit
def training_step(self, batch_idx: int, ts: Ts, algo: JaxModule[Ts, _B, _MetricsT]):
"""Training step in pure jax (joined data collection + training).
*MUCH* faster than using pytorch-lightning, but you lose the callbacks and such.
"""
# todo: rename to `get_training_batch`?
ts, batch = algo.get_batch(ts, batch_idx=batch_idx)
self._callback_hook("on_train_batch_start", self, algo, batch, batch_idx, ts=ts)
ts, metrics = algo.training_step(batch_idx=batch_idx, ts=ts, batch=batch)
if self.logger is not None:
# todo: Clean this up. logs metrics.
jax.experimental.io_callback(
lambda metrics, batch_index: self.logger
and self.logger.log_metrics(
jax.tree.map(lambda v: v.mean(), metrics), batch_index
),
(),
dataclasses.asdict(metrics) if dataclasses.is_dataclass(metrics) else metrics,
batch_idx,
)
self._callback_hook("on_train_batch_end", self, algo, metrics, batch, batch_idx, ts=ts)
return ts
### Hooks to mimic those of lightning.Trainer
def _callback_hook(
self,
hook_name: str,
/,
*hook_args,
ts: Ts,
partial_kwargs: dict | None = None,
sharding: jax.sharding.SingleDeviceSharding | None = None,
ordered: bool = True,
**hook_kwargs,
):
"""Call a hook on all callbacks."""
# with jax.disable_jit():
for i, callback in enumerate(self.callbacks):
# assert hasattr(callback, hook_name)
method = getattr(callback, hook_name)
if partial_kwargs:
method = functools.partial(method, **partial_kwargs)
if self.verbose:
jax.debug.print(
"Epoch {current_epoch}/{max_epochs}: "
+ f"Calling hook {hook_name} on callback {callback}"
+ "{i}",
i=i,
current_epoch=self.current_epoch,
ordered=True,
max_epochs=self.max_epochs,
)
jax.experimental.io_callback(
method,
(),
*hook_args,
**({"ts": ts} if isinstance(callback, JaxCallback) else {}),
**hook_kwargs,
sharding=sharding,
ordered=ordered if not isinstance(callback, JaxCallback) else False,
)
# Compat for RichProgressBar
@property
def is_global_zero(self) -> bool:
return True
@property
def num_training_batches(self) -> int:
return self.training_steps_per_epoch
@property
def loggers(self) -> list[lightning.pytorch.loggers.Logger]:
if isinstance(self.logger, list | tuple):
return list(self.logger)
if self.logger is not None:
return [self.logger]
return []
# @property
# def progress_bar_metrics(self) -> dict[str, float]:
# return {}
@property
def progress_bar_callback(self) -> lightning.pytorch.callbacks.ProgressBar | None:
for c in self.callbacks:
if isinstance(c, lightning.pytorch.callbacks.ProgressBar):
return c
return None
@property
def state(self):
from lightning.pytorch.trainer.states import (
RunningStage,
TrainerFn,
TrainerState,
TrainerStatus,
)
return TrainerState(
fn=TrainerFn.FITTING,
status=TrainerStatus.RUNNING,
stage=RunningStage.TRAINING,
)
# self._trainer.state.fn != "fit"
# or self._trainer.sanity_checking
# or self._trainer.progress_bar_callback.train_progress_bar_id != task.id
# ):
@property
def sanity_checking(self) -> bool:
from lightning.pytorch.trainer.states import RunningStage
return self.state.stage == RunningStage.SANITY_CHECKING
@property
def training(self) -> bool:
from lightning.pytorch.trainer.states import RunningStage
return self.state.stage == RunningStage.TRAINING
@property
def log_dir(self) -> Path | None:
# copied from lightning.Trainer
if len(self.loggers) > 0:
if not isinstance(
self.loggers[0],
lightning.pytorch.loggers.TensorBoardLogger | lightning.pytorch.loggers.CSVLogger,
):
dirpath = self.loggers[0].save_dir
else:
dirpath = self.loggers[0].log_dir
else:
dirpath = self.default_root_dir
if dirpath:
return Path(dirpath)
return None
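As noted in the caveats of the docstring above, `fit` can be transformed with `jax.vmap` when no lightning callbacks are configured on the trainer. Here is a rough sketch of training several seeds in parallel, assuming `algo` and `trainer` were created as in the earlier usage sketch (with `callbacks=()`):

```python
import jax

# Train 5 seeds in parallel by vmapping `fit` over the rng argument.
# Assumes `trainer` has no callbacks (see the caveats in the docstring above).
rngs = jax.random.split(jax.random.PRNGKey(0), 5)
train_states, evaluations = jax.vmap(lambda rng: trainer.fit(algo, rng=rng))(rngs)
```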