LLM finetuning test

Unit tests for the LLM finetuning example.

TestLLMFinetuningExample

Bases: LearningAlgorithmTests[LLMFinetuningExample]

train_dataloader

train_dataloader(
    algorithm: LLMFinetuningExample,
    request: FixtureRequest,
    trainer: Trainer,
) -> DataLoader

Fixture that creates and returns the training dataloader.

NOTE: This purposefully redefines the project.conftest.train_dataloader fixture, because that fixture assumes the algorithm uses a datamodule. The redefinition also changes the fixture scope.
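The note above says the training dataloader has to come from the algorithm itself rather than from a datamodule. The sketch below illustrates that idea under stated assumptions: it assumes the algorithm exposes Lightning-style `prepare_data()`, `setup(stage)` and `train_dataloader()` hooks. `SelfContainedDataAlgorithm`, `build_train_dataloader` and `ToyAlgorithm` are hypothetical names used only for illustration, not part of the project. The signature above also takes `request` and `trainer`, which this sketch leaves out.

```python
from typing import Any, List, Protocol


class SelfContainedDataAlgorithm(Protocol):
    """Hypothetical protocol: an algorithm that owns its data pipeline (no datamodule).

    The hook names follow the Lightning convention (prepare_data, setup, train_dataloader).
    """

    def prepare_data(self) -> None: ...

    def setup(self, stage: str) -> None: ...

    def train_dataloader(self) -> Any: ...


def build_train_dataloader(algorithm: SelfContainedDataAlgorithm) -> Any:
    """Hypothetical helper: build the training dataloader from the algorithm's own hooks."""
    algorithm.prepare_data()  # e.g. download / tokenize the dataset once
    algorithm.setup("fit")  # e.g. create the training split
    return algorithm.train_dataloader()


class ToyAlgorithm:
    """Hypothetical stand-in that records which hooks were called, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.data: List[List[int]] = []

    def prepare_data(self) -> None:
        self.calls.append("prepare_data")

    def setup(self, stage: str) -> None:
        self.calls.append(f"setup:{stage}")
        self.data = [[1, 2], [3, 4]]

    def train_dataloader(self) -> Any:
        self.calls.append("train_dataloader")
        return iter(self.data)


algorithm = ToyAlgorithm()
loader = build_train_dataloader(algorithm)
print(algorithm.calls)  # ['prepare_data', 'setup:fit', 'train_dataloader']
print(list(loader))  # [[1, 2], [3, 4]]
```

In a pytest test class, wrapping such a helper in a fixture that is also named `train_dataloader` (for example with `@pytest.fixture(scope="class")`) shadows the conftest fixture of the same name for the tests in that class. The `scope="class"` value is only an illustration; the note does not say which scope the redefined fixture uses.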

forward_pass_input

forward_pass_input(
    training_batch: PyTree[Tensor], device: device
)

Extracts the model input from a batch of data coming from the dataloader.

Override this if your batches are not tuples of tensors (i.e. if your algorithm is not a simple supervised learning algorithm like the example).
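A minimal sketch of such an override, assuming LLM batches arrive as dict-like mappings of tokenized fields (for instance a Hugging Face `BatchEncoding`, which subclasses `UserDict`) and that the model is then called as `model(**inputs)`. Both are assumptions, not confirmed by this page. To stay dependency-free, the hypothetical `_to_device` helper duck-types the device move: any leaf with a `.to` method (such as a `torch.Tensor`) is moved, and other leaves are returned unchanged.

```python
from collections.abc import Mapping
from typing import Any


def _to_device(leaf: Any, device: Any) -> Any:
    """Hypothetical helper: move tensor-like leaves (anything with `.to`), leave others as-is."""
    return leaf.to(device) if hasattr(leaf, "to") else leaf


def forward_pass_input(training_batch: Any, device: Any) -> Any:
    """Sketch: extract the model input from a batch.

    Handles two assumed batch layouts:
    - a dict-like mapping of tokenized fields (e.g. input_ids, attention_mask, labels);
    - a supervised (inputs, targets) pair.
    """
    if isinstance(training_batch, Mapping):
        # Assumption: the model consumes the whole mapping via model(**inputs),
        # so every field is kept, including `labels` if present.
        return {key: _to_device(value, device) for key, value in training_batch.items()}
    if isinstance(training_batch, (tuple, list)) and len(training_batch) == 2:
        # Supervised case: only the inputs are fed to the forward pass.
        inputs, _targets = training_batch
        return _to_device(inputs, device)
    raise TypeError(f"Unsupported batch type: {type(training_batch).__name__}")
```

Checking for `Mapping` rather than `dict` keeps the override working for dict subclasses and `UserDict`-based batch containers, while the tuple branch preserves the supervised behaviour described above.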