Skip to content

JAX typing utils

Small typing helpers for JAX.

This module makes `jax.jit` preserve the signature of the wrapped callable.

field

field(
    *,
    default: _T,
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    kw_only: bool = ...,
    pytree_node: bool = True
) -> _T
field(
    *,
    default_factory: Callable[[], _T],
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    kw_only: bool = ...,
    pytree_node: bool = True
) -> _T
field(
    *,
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    kw_only: bool = ...,
    pytree_node: bool = True
) -> Any
field(
    *,
    default=MISSING,
    default_factory=MISSING,
    init=True,
    repr=True,
    hash=None,
    compare=True,
    metadata: Mapping[Any, Any] | None = None,
    kw_only=MISSING,
    pytree_node: bool | None = None
)

Small typing fix for `flax.struct.field`.

  • Adds type annotations so that the signature of `dataclasses.field` is preserved rather than dropped.
  • Makes `pytree_node` default to `False` for ints and bools, and to `True` for everything else.