Skip to content

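NLP (PyTorch)

Overview

HFExample is a LightningModule that fine-tunes a Hugging Face model on a simple text classification task (or, when the dataset has a single label, a regression task). The default configuration uses the CoLA task from the GLUE benchmark.

It takes an HFDataModule as input, along with a network (a Hugging Face `PreTrainedModel`).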

Click to show the code for HFExample
class HFExample(LightningModule):
    """Example of a lightning module used to train a huggingface model.

    The module wraps a Hugging Face ``PreTrainedModel`` and trains it on the task
    exposed by an ``HFDataModule``. If the datamodule reports more than one label,
    predictions are the argmax of the logits (classification). Otherwise the single
    logit per example is used as the prediction (regression).

    Args:
        datamodule: Source of ``num_labels`` and ``task_name``.
        network: The Hugging Face model to train. It is called with ``labels`` and
            its output must expose ``loss`` and ``logits`` attributes.
        hf_metric_name: Name passed to ``load_metric`` (e.g. ``"glue"``), together
            with the datamodule's ``task_name``.
        learning_rate: Learning rate given to AdamW (scaled by the LR schedule).
        adam_epsilon: ``eps`` given to AdamW.
        warmup_steps: Number of warmup steps of the linear LR schedule.
        weight_decay: Weight decay for every parameter except biases and
            LayerNorm weights, which get no decay.
        **kwargs: Accepted but not used directly by this class.
    """

    def __init__(
        self,
        datamodule: HFDataModule,
        network: PreTrainedModel,
        hf_metric_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        super().__init__()

        # Records the constructor arguments in `self.hparams`; the optimizer
        # settings are read back from there in `configure_optimizers`.
        self.save_hyperparameters()
        self.num_labels = datamodule.num_labels
        self.task_name = datamodule.task_name
        self.network = network
        self.hf_metric_name = hf_metric_name
        # NOTE(review): `self.metric` is never used in this class. If this is
        # `datasets.load_metric`, it is deprecated upstream in favour of the
        # `evaluate` library; confirm which `load_metric` is imported here.
        # The timestamped experiment_id presumably keeps different runs' metric
        # files apart.
        self.metric = load_metric(
            self.hf_metric_name,
            self.task_name,
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
        )

        # Small fix for the `device` property in LightningModule, which is CPU by
        # default: use the device of the first parameter (CPU if there are none).
        self._device = next((p.device for p in self.parameters()), torch.device("cpu"))

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ):
        """Run the wrapped network on one batch.

        Returns the network's output object; ``model_step`` reads its ``loss``
        and ``logits`` attributes.
        """
        return self.network(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels
        )

    def model_step(self, batch: dict[str, torch.Tensor]):
        """Run the network on a batch and return ``(loss, preds, labels)``.

        ``batch`` must contain the keys ``input_ids``, ``token_type_ids``,
        ``attention_mask`` and ``labels``.
        """
        input_ids = batch["input_ids"]
        token_type_ids = batch["token_type_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = self.forward(input_ids, token_type_ids, attention_mask, labels)
        loss = outputs.loss
        logits = outputs.logits

        if self.num_labels > 1:
            # Classification: pick the highest-scoring class for each example.
            preds = torch.argmax(logits, dim=1)
        else:
            # Regression: one output per example (assumes logits are (batch, 1);
            # TODO confirm). Squeeze only the last dim so that a batch of size 1
            # still yields a 1-D tensor instead of a 0-d scalar.
            preds = logits.squeeze(-1)

        return loss, preds, labels

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        """Compute the loss on a training batch and log it (epoch-averaged) as ``train/loss``."""
        loss, preds, labels = self.model_step(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": preds, "labels": labels}

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
    ):
        """Compute the loss on a validation batch and log it (epoch-averaged) as ``val/loss``."""
        val_loss, preds, labels = self.model_step(batch)
        self.log("val/loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)
        return {"val/loss": val_loss, "preds": preds, "labels": labels}

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"`` get no
        weight decay; all others use ``hparams.weight_decay``. The scheduler is
        stepped every optimizer step, over ``trainer.estimated_stepping_batches``
        total steps.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params: list[torch.nn.Parameter] = []
        no_decay_params: list[torch.nn.Parameter] = []
        # One pass over the parameters: each lands in exactly one group, keeping
        # the `named_parameters()` order within each group.
        for name, param in self.network.named_parameters():
            if any(nd_param in name for nd_param in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.hparams.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

Config files

Algorithm config

Click to show the Algorithm config

Source: project/configs/algorithm/hf_example.yaml

# Config for the HFExample algorithm
defaults:
  - network: albert-base-v2.yaml
  # - /datamodule@_global_.datamodule: hf_text.yaml

_target_: project.algorithms.hf_example.HFExample
# NOTE: Why _partial_? Because the config doesn't create the algo directly, it creates a function
# that will accept the datamodule and network and return the algo.
_partial_: true
# Passed to HFExample as `hf_metric_name`; the metric is loaded with the datamodule's task_name.
hf_metric_name: glue

Datamodule config

Click to show the Datamodule config

Source: project/configs/datamodule/hf_text.yaml

_target_: project.datamodules.HFDataModule
# Tokenizer name; matches the albert-base-v2 network config used by the HFExample algorithm.
tokenizer: albert-base-v2
# Dataset path and task name. HFExample reads `task_name` from the datamodule and
# passes it to `load_metric` along with its `hf_metric_name`.
hf_dataset_path: glue
task_name: cola
# Presumably the maximum tokenized sequence length (truncation/padding); confirm in HFDataModule.
max_seq_length: 128
train_batch_size: 32
eval_batch_size: 32

Running the example

Here is a configuration file that you can use to launch a simple experiment:

Click to show the yaml config file

Source: project/configs/experiment/hf_example.yaml

# @package _global_

# Experiment: HF text datamodule + HFExample algorithm + albert-base-v2 network,
# with the trainer callbacks override set to `none`.
defaults:
  - override /datamodule: hf_text
  - override /algorithm: hf_example
  - override /algorithm/network: albert-base-v2
  - override /trainer/callbacks: none

# Keep the run very short: at most 2 epochs, 2 training batches and 1 validation
# batch per epoch, no sanity-check validation steps and no checkpointing.
trainer:
  min_epochs: 1
  max_epochs: 2
  limit_train_batches: 2
  limit_val_batches: 1
  num_sanity_val_steps: 0
  enable_checkpointing: False

You can use it like so:

python project/main.py experiment=hf_example