Cifar10
CIFAR10DataModule
Bases: ImageClassificationDataModule
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
    :width: 400
    :alt: CIFAR-10

Specs
- 10 classes (6,000 images per class)
- Each image is (3 x 32 x 32)
Standard CIFAR-10 train, val, and test splits and transforms.
Transforms::

    transforms = transform_lib.Compose([
        transform_lib.ToImage(),
        transform_lib.ToDtype(torch.float32, scale=True),
        transform_lib.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
        )
    ])

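To make the normalization constants concrete, the following plain-Python sketch reproduces the per-pixel arithmetic of the default transforms: values are first scaled from ``[0, 255]`` to ``[0, 1]`` and then standardized per channel. The helper ``normalize_pixel`` is hypothetical and exists only for illustration; it is not part of the library.

```python
# Per-channel statistics copied from the default transforms (0-255 scale).
MEAN_255 = [125.3, 123.0, 113.9]
STD_255 = [63.0, 62.1, 66.7]

# The same division by 255 that the Normalize arguments perform.
mean = [m / 255.0 for m in MEAN_255]
std = [s / 255.0 for s in STD_255]


def normalize_pixel(rgb):
    """Hypothetical helper: scale one (R, G, B) pixel to [0, 1],
    then subtract the channel mean and divide by the channel std."""
    scaled = [v / 255.0 for v in rgb]
    return [(v - m) / s for v, m, s in zip(scaled, mean, std)]


print([round(m, 4) for m in mean])  # approximately [0.4914, 0.4824, 0.4467]
print([round(s, 4) for s in std])   # approximately [0.2471, 0.2435, 0.2616]

# A pixel equal to the dataset mean maps to zero in every channel.
print(normalize_pixel(MEAN_255))
```

Because the input, mean, and std are all divided by 255, the result equals ``(value - mean_255) / std_255`` computed directly on the raw 0-255 values.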
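Example::

    # Trainer import assumed to come from pytorch_lightning.
    from pytorch_lightning import Trainer
    from pl_bolts.datamodules import CIFAR10DataModule

    # PATH is the dataset directory; LitModel is your own LightningModule.
    dm = CIFAR10DataModule(PATH)
    model = LitModel()

    Trainer().fit(model, datamodule=dm)
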
Alternatively, you can set your own transforms.
Example::

    dm.train_transforms = ...
    dm.test_transforms = ...
    dm.val_transforms = ...
