Layers

Branch

Bases: Module, Module[P, dict[str, T]]

Module that executes each branch and returns a dictionary containing the result of each branch.
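
The fan-out behavior can be illustrated with a small, dependency-free sketch. `BranchSketch` is a hypothetical plain-Python stand-in, not the library's `Branch`. It assumes that branches are supplied as keyword arguments (name to callable), that every branch receives the same inputs, and that the result dictionary is keyed by branch name. The real class is an `nn.Module`, which this sketch omits so that it runs without PyTorch.

```python
from __future__ import annotations

from typing import Any, Callable


class BranchSketch:
    """Hypothetical stand-in for Branch: runs every branch on the same inputs
    and collects the results in a dict keyed by branch name."""

    def __init__(self, **branches: Callable[..., Any]) -> None:
        # Assumption: branches are passed as name=callable keyword arguments.
        self.branches = branches

    def __call__(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        # Each branch receives identical positional and keyword arguments.
        return {name: fn(*args, **kwargs) for name, fn in self.branches.items()}


branch = BranchSketch(double=lambda x: 2 * x, square=lambda x: x * x)
print(branch(3))  # {'double': 6, 'square': 9}
```

Keying the output by name means a downstream block can pick out individual results without relying on branch order.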

Lambda

Bases: Module, Module[..., OutT]

A simple nn.Module wrapping a function.

Any positional or keyword arguments passed to the constructor are saved into the args and kwargs attributes. During the forward pass, these arguments are bound to the function f using functools.partial, and any additional arguments passed to the forward method are then passed to the partial.
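
A minimal sketch of the binding just described is shown below. `LambdaSketch` and `scale_and_shift` are hypothetical names used only for illustration, and the sketch is a plain callable rather than an `nn.Module`. It shows a consequence of using `functools.partial`: saved positional arguments come *before* the call-time positional arguments, and call-time keyword arguments override saved ones.

```python
from __future__ import annotations

import functools
from typing import Any, Callable


class LambdaSketch:
    """Hypothetical stand-in for Lambda: stores constructor arguments and
    binds them to f with functools.partial at call time."""

    def __init__(self, f: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self.f = f
        self.args = args  # saved positional arguments
        self.kwargs = kwargs  # saved keyword arguments

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Saved positional args are placed first; call-time keyword args
        # are merged over the saved ones by functools.partial.
        bound = functools.partial(self.f, *self.args, **self.kwargs)
        return bound(*args, **kwargs)


def scale_and_shift(x: float, scale: float, shift: float = 0.0) -> float:
    return x * scale + shift


layer = LambdaSketch(scale_and_shift, shift=1.0)
print(layer(2.0, 3.0))             # 7.0 -> scale_and_shift(2.0, 3.0, shift=1.0)
print(layer(2.0, 3.0, shift=0.5))  # 6.5 -> call-time shift overrides the saved one
print(LambdaSketch(pow, 2)(10))    # 1024 -> pow(2, 10): saved args come first
```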

Merge

Bases: Module, Module[[tuple[Tensor, ...] | dict[str, Tensor]], OutT]

Unpacks the output of a previous block (such as Branch) before feeding it to the wrapped module.

__init__

__init__(f: Module[..., OutT]) -> None

Unpacks the output of a previous block before it is fed to f.
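
The unpacking step can be sketched as follows. `MergeSketch` and `combine` are hypothetical, and the sketch assumes that a tuple is splatted as positional arguments and a dict as keyword arguments. This is a reading of "unpacks" and is not confirmed by the signature above. The dict input mirrors the name-keyed output a Branch-style block would produce.

```python
from __future__ import annotations

from typing import Any, Callable


class MergeSketch:
    """Hypothetical stand-in for Merge: unpacks a tuple or dict produced by an
    earlier block before passing it to the wrapped callable f."""

    def __init__(self, f: Callable[..., Any]) -> None:
        self.f = f

    def __call__(self, x: tuple[Any, ...] | dict[str, Any]) -> Any:
        # Assumption: dicts become keyword args, tuples become positional args.
        if isinstance(x, dict):
            return self.f(**x)
        if isinstance(x, tuple):
            return self.f(*x)
        raise TypeError(f"expected tuple or dict, got {type(x).__name__}")


def combine(double: int, square: int) -> int:
    return double + square


merge = MergeSketch(combine)
print(merge({"double": 6, "square": 9}))  # 15
print(merge((6, 9)))                      # 15
```

Dispatching on the container type lets one wrapper accept either ordered (tuple) or named (dict) outputs, as the `tuple[Tensor, ...] | dict[str, Tensor]` input type suggests.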