JAX typing utils
Small typing helpers for JAX, including one that makes jax.jit preserve the signature of the wrapped callable.
field
field(
*,
default: _T,
init: bool = True,
repr: bool = True,
hash: bool | None = None,
compare: bool = True,
metadata: Mapping[Any, Any] | None = None,
kw_only: bool = ...,
pytree_node: bool = True
) -> _T
field(
*,
default=MISSING,
default_factory=MISSING,
init=True,
repr=True,
hash=None,
compare=True,
metadata: Mapping[Any, Any] | None = None,
kw_only=MISSING,
pytree_node: bool | None = None
)
Small typing fix for flax.struct.field:
- Add type annotations so that it doesn't drop the signature of the dataclasses.field function.
- When pytree_node is left as None, make it default to False for ints and bools, and to True for everything else.
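Both behaviours described above can be sketched with only the standard library. typed_field is a hypothetical stand-in, not flax.struct.field itself. It covers only a subset of dataclasses.field's parameters and uses @overload so that a field given default: T or default_factory: Callable[[], T] is typed as T. It stores the flag under a "pytree_node" metadata key, assumed here to mirror how flax.struct.field records it. As a further assumption, only the default value is inspected for the int/bool rule; since bool is a subclass of int, one isinstance check covers both.

```python
from __future__ import annotations

import dataclasses
from typing import Any, Callable, Mapping, TypeVar, overload

_T = TypeVar("_T")


@overload
def typed_field(
    *,
    default: _T,
    init: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    pytree_node: bool | None = None,
) -> _T: ...


@overload
def typed_field(
    *,
    default_factory: Callable[[], _T],
    init: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    pytree_node: bool | None = None,
) -> _T: ...


@overload
def typed_field(
    *,
    init: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    pytree_node: bool | None = None,
) -> Any: ...


def typed_field(
    *,
    default: Any = dataclasses.MISSING,
    default_factory: Any = dataclasses.MISSING,
    init: bool = True,
    metadata: Mapping[Any, Any] | None = None,
    pytree_node: bool | None = None,
) -> Any:
    """Hypothetical stand-in for a typed flax.struct.field (subset of params)."""
    if pytree_node is None:
        # Only `default` is inspected (assumption). bool subclasses int, so
        # ints and bools become non-pytree fields; everything else gets True.
        pytree_node = not isinstance(default, int)
    return dataclasses.field(
        default=default,
        default_factory=default_factory,
        init=init,
        # Assumed "pytree_node" metadata key, merged over any user metadata.
        metadata={**(metadata or {}), "pytree_node": pytree_node},
    )


@dataclasses.dataclass(frozen=True)
class Config:
    num_layers: int = typed_field(default=4)
    scale: float = typed_field(default=1.0)
    weights: list = typed_field(default_factory=list)


print({f.name: f.metadata["pytree_node"] for f in dataclasses.fields(Config)})
# {'num_layers': False, 'scale': True, 'weights': True}
```

The sketch only records the flag in field metadata; it does not register the class as a pytree. An explicit pytree_node argument always wins over the inferred default.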