Cifar10

CIFAR10DataModule

Bases: ImageClassificationDataModule

.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
    :width: 400
    :alt: CIFAR-10

Specs
  • 10 classes (6,000 images per class)
  • Each image is (3 x 32 x 32)

Standard CIFAR-10 train, val, and test splits with the default transforms.

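Transforms::

    import torch
    # ToImage and ToDtype are part of the torchvision v2 transforms API.
    from torchvision.transforms import v2 as transform_lib

    transforms = transform_lib.Compose([
        transform_lib.ToImage(),
        transform_lib.ToDtype(torch.float32, scale=True),
        # Per-channel statistics, rescaled from 0-255 to the [0, 1] range.
        transform_lib.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
        )
    ])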
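The normalization step above subtracts a per-channel mean and divides by a per-channel standard deviation, both rescaled from the 0-255 pixel range to [0, 1]. The following is a minimal pure-Python sketch of that arithmetic for a single RGB pixel. The helper name ``normalize_pixel`` is hypothetical and is not part of the library; it only illustrates the per-channel formula ``(x - mean) / std``.

```python
# Channel statistics quoted in the transform above, on the 0-255 pixel scale.
MEAN_255 = [125.3, 123.0, 113.9]
STD_255 = [63.0, 62.1, 66.7]

# Rescale to [0, 1] so they match images already scaled to that range.
mean = [m / 255.0 for m in MEAN_255]
std = [s / 255.0 for s in STD_255]


def normalize_pixel(rgb_255):
    """Hypothetical helper: normalize one RGB pixel given as 0-255 values."""
    scaled = [v / 255.0 for v in rgb_255]  # what scale=True does for uint8 input
    return [(v - m) / s for v, m, s in zip(scaled, mean, std)]


# A pixel equal to the dataset mean maps to (approximately) zero in every channel.
mean_pixel = normalize_pixel(MEAN_255)

# A pure white pixel ends up a couple of standard deviations above zero.
white_pixel = normalize_pixel([255, 255, 255])
```

Because the mean and standard deviation are divided by the same factor as the pixel values, the result is the same as normalizing directly on the 0-255 scale.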

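Example::

    from pl_bolts.datamodules import CIFAR10DataModule
    from pytorch_lightning import Trainer

    # PATH is the dataset directory; LitModel is your own LightningModule.
    dm = CIFAR10DataModule(PATH)
    model = LitModel()

    Trainer().fit(model, datamodule=dm)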

Alternatively, set your own transforms on the datamodule:

Example::

    dm.train_transforms = ...
    dm.test_transforms = ...
    dm.val_transforms = ...