Skip to content

Text Classification (⚡ + 🤗)

Overview

The TextClassifier is a LightningModule for a simple text classification task.

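It takes a TextClassificationDataModule as input, along with a Hydra config for the network (a Hugging Face PreTrainedModel), which it instantiates in its configure_model hook.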

Click to show the code of the LightningModule
class TextClassifier(LightningModule):
    """Example of a lightning module used to train a huggingface model for text classification.

    The network is not built in ``__init__``: only its Hydra config is stored, and it is
    instantiated in :meth:`configure_model` under a fixed seed, so that weight initialization
    is deterministic.

    Args:
        datamodule: Datamodule providing the data. Its ``num_classes`` and ``task_name``
            attributes are read here; ``task_name`` is also used as the metric config name.
        network: Hydra config that instantiates the ``PreTrainedModel`` to train.
        hf_metric_name: Name of the metric passed to ``evaluate.load`` (e.g. ``"glue"``).
        learning_rate: Learning rate of the AdamW optimizer.
        adam_epsilon: ``eps`` parameter of the AdamW optimizer.
        warmup_steps: Number of linear warmup steps of the learning-rate schedule.
        weight_decay: Weight decay applied to all parameters except biases and LayerNorm weights.
        init_seed: Seed used while instantiating (initializing) the network.
    """

    def __init__(
        self,
        datamodule: TextClassificationDataModule,
        network: HydraConfigFor[PreTrainedModel],
        hf_metric_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        init_seed: int = 42,
    ):
        super().__init__()
        self.datamodule = datamodule
        # Only the config is stored here; the network itself is created in `configure_model`.
        self.network_config = network
        self.num_labels = datamodule.num_classes
        self.task_name = datamodule.task_name
        self.init_seed = init_seed
        self.hf_metric_name = hf_metric_name
        self.learning_rate = learning_rate
        self.adam_epsilon = adam_epsilon
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay

        # The task name selects the metric configuration, e.g. `evaluate.load("glue", "cola")`.
        self.metric = evaluate.load(
            self.hf_metric_name,
            self.task_name,
            # TODO: replace with hydra job id perhaps?
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
        )

        # The datamodule object is excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["datamodule"])

    def configure_model(self) -> None:
        """Instantiate the network from its config, with a fixed seed.

        Lightning calls this hook at the start of each fit/validate/test/predict stage, so it is
        idempotent: once the network exists it is kept. Re-creating it would reload the
        pretrained weights and discard anything trained during a previous stage.
        """
        if getattr(self, "network", None) is None:
            # `fork_rng` saves the RNG state(s) and restores them after the block, so seeding
            # here does not change the random state seen by the rest of the run.
            # Forking a CUDA RNG requires CUDA: `torch.cuda.get_rng_state` initializes CUDA and
            # fails on CPU-only machines. Without CUDA, only the CPU RNG is forked.
            fork_devices = [self.device] if torch.cuda.is_available() else []
            with torch.random.fork_rng(devices=fork_devices):
                # deterministic weight initialization
                torch.manual_seed(self.init_seed)
                self.network = hydra_zen.instantiate(self.network_config)

        return super().configure_model()

    def forward(self, inputs: dict[str, torch.Tensor]) -> BaseModelOutput:
        """Run the network, passing the entries of ``inputs`` as keyword arguments."""
        return self.network(**inputs)

    def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, stage: str):
        """Compute and log the loss, plus the metric for classification outputs, on one batch.

        Args:
            batch: Network inputs, passed as keyword arguments. ``batch["labels"]`` is also
                used as the metric references.
            batch_idx: Index of the batch (unused).
            stage: Prefix of the logged keys (``"train"`` or ``"val"``).

        Returns:
            The loss tensor returned by the network.
        """
        outputs: CausalLMOutput | SequenceClassifierOutput = self(batch)
        loss = outputs.loss
        # Presumably the network only returns a loss when the batch contains labels.
        assert isinstance(loss, torch.Tensor), loss
        self.log(f"{stage}/loss", loss, prog_bar=True)
        if isinstance(outputs, SequenceClassifierOutput):
            # NOTE(review): the metric is computed on each batch separately. When logged per
            # epoch, Lightning averages these per-batch values. For non-additive metrics (e.g.
            # Matthews correlation, used for CoLA), that average differs from the metric
            # computed over the whole epoch.
            metric_value = self.metric.compute(
                predictions=outputs.logits.argmax(-1),
                references=batch["labels"],
            )
            assert isinstance(metric_value, dict)
            for k, v in metric_value.items():
                self.log(f"{stage}/{k}", v, prog_bar=True)
        return loss

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        """Training step: loss and metrics logged under the ``train/`` prefix."""
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
    ):
        """Validation step: loss and metrics logged under the ``val/`` prefix."""
        return self.shared_step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay).

        Returns:
            An AdamW optimizer and a linear warmup/decay learning-rate scheduler that is
            stepped after every optimization step, in Lightning's
            ``[optimizers], [schedulers]`` format.
        """
        # Parameters whose name contains one of these substrings get no weight decay.
        # NOTE(review): matching is by substring, so layer-norm modules that are not named
        # `LayerNorm` are still decayed; confirm for the chosen model.
        no_decay = ["bias", "LayerNorm.weight"]
        decay_params: list[torch.nn.Parameter] = []
        no_decay_params: list[torch.nn.Parameter] = []
        for name, param in self.network.named_parameters():
            if any(nd_param in name for nd_param in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": self.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.learning_rate,
            eps=self.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            # Total number of optimizer steps in the run, as estimated by the Trainer.
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler_config = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler_config]

Config files

Algorithm config

Click to show the Algorithm config

Source: project/configs/algorithm/text_classifier.yaml

# Config for the Text classification example algorithm
_target_: project.algorithms.text_classifier.TextClassifier
# Don't instantiate nested configs. `network` reaches TextClassifier as a config and is only
# instantiated later, in `TextClassifier.configure_model`.
_recursive_: false
network:
  # Hugging Face sequence-classification model, loaded with pretrained weights.
  # NOTE(review): `num_labels` is not set here, so the model's default classification head size
  # is used; confirm it matches the datamodule's `num_classes` (2 for CoLA).
  _target_: transformers.models.auto.modeling_auto.AutoModelForSequenceClassification.from_pretrained
  pretrained_model_name_or_path: albert-base-v2

# NOTE: Why _partial_? Because the config doesn't create the algo directly. It creates a partial
# function that must still be called with the datamodule (the required TextClassifier argument
# that is not set in this file) to create the algo.
_partial_: true
# Metric name passed to `evaluate.load`, together with the datamodule's `task_name`.
hf_metric_name: glue

Datamodule config

Click to show the Datamodule config

Source: project/configs/datamodule/glue_cola.yaml

_target_: project.datamodules.text.text_classification.TextClassificationDataModule
# `$SCRATCH/data` when the SCRATCH environment variable is set, `./data` otherwise.
data_dir: ${oc.env:SCRATCH,.}/data
hf_dataset_path: glue
# GLUE subset to use. TextClassifier also passes it to `evaluate.load` as the metric config name.
task_name: cola
text_fields:
  - "sentence"
tokenizer:
  _target_: transformers.models.auto.tokenization_auto.AutoTokenizer.from_pretrained
  use_fast: true
  # Note: We could interpolate this value with `${/algorithm/network/pretrained_model_name_or_path}`
  # to avoid duplicating a value, but this also makes it harder to use this by itself or with
  # another algorithm. Keep it in sync with the algorithm's network so the tokenizer matches the
  # model.
  pretrained_model_name_or_path: albert-base-v2
  # Relative interpolation: `..` is the datamodule config, so this resolves to `data_dir` above.
  cache_dir: ${..data_dir}
  # NOTE(review): this allows running code downloaded from the Hugging Face Hub; confirm it is
  # needed for this tokenizer.
  trust_remote_code: true
# Read by TextClassifier as `num_labels`; presumably must match the network's classification head.
num_classes: 2
max_seq_length: 128
train_batch_size: 32
eval_batch_size: 32

Running the example

Here is a configuration file that you can use to launch a simple experiment:

Click to show the yaml config file

Source: project/configs/experiment/text_classification_example.yaml

# @package _global_
# Placed at the root of the config (`_global_`), so the `trainer` key below overrides the
# top-level trainer config.
defaults:
  - override /algorithm: text_classifier
  - override /datamodule: glue_cola
  # Presumably disables the trainer callbacks (selects the `none` option of that config group).
  - override /trainer/callbacks: none

# Very short run: at most 2 epochs of 2 training batches each, validating on a single batch.
trainer:
  min_epochs: 1
  max_epochs: 2
  limit_train_batches: 2
  limit_val_batches: 1
  # Skip the sanity-check validation run before training.
  num_sanity_val_steps: 0
  # Don't save checkpoints.
  enable_checkpointing: False

You can use it like so:

python project/main.py experiment=text_classification_example