Datamodules
Datamodules (datasets + preprocessing + dataloading)
See the `lightning.LightningDataModule` class for more information.
ImageClassificationDataModule
Bases: VisionDataModule[BatchType], ClassificationDataModule[BatchType]
Lightning data modules for image classification.
CIFAR10DataModule
Bases: ImageClassificationDataModule, VisionDataModule
Figure: CIFAR-10 (https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png)
Specs
- 10 classes
- Each image is (3 x 32 x 32)
Standard CIFAR10 train, val, and test splits and transforms.
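Transforms::

    import torch
    from torchvision.transforms import v2 as transform_lib  # ToImage/ToDtype are v2 transforms

    transforms = transform_lib.Compose([
        transform_lib.ToImage(),  # PIL image / ndarray -> tensor image
        transform_lib.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
        transform_lib.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
        ),
    ])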
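The `Normalize` constants above are the per-channel CIFAR-10 mean and standard deviation on the 0-255 scale, divided by 255 because `ToDtype(torch.float32, scale=True)` has already rescaled the pixels to [0, 1]. A minimal pure-Python sketch of that arithmetic for a single RGB pixel; the helper `normalize_pixel` is hypothetical and not part of the library:

```python
# Per-channel CIFAR-10 statistics used in the Normalize call above, on the 0-255 scale.
MEAN_255 = [125.3, 123.0, 113.9]
STD_255 = [63.0, 62.1, 66.7]

# ToDtype(..., scale=True) maps uint8 pixels to [0, 1], so the statistics are rescaled too.
MEAN = [m / 255.0 for m in MEAN_255]
STD = [s / 255.0 for s in STD_255]


def normalize_pixel(rgb: tuple[int, int, int]) -> list[float]:
    """Apply uint8 -> [0, 1] scaling, then per-channel (x - mean) / std, to one RGB pixel."""
    scaled = [value / 255.0 for value in rgb]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


print(normalize_pixel((255, 0, 128)))
```

Because the factor of 255 cancels, `(x / 255 - m / 255) / (s / 255)` equals `(x - m) / s`: the output matches normalizing the raw 0-255 values with the raw statistics.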
Example::

    from pl_bolts.datamodules import CIFAR10DataModule

    dm = CIFAR10DataModule(PATH)  # PATH: directory where the data is stored/downloaded
    model = LitModel()  # LitModel: your own LightningModule
    Trainer().fit(model, datamodule=dm)
You can also set your own transforms:
Example::

    dm.train_transforms = ...
    dm.test_transforms = ...
    dm.val_transforms = ...
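The transform arguments are typed as `Callable | None`, so presumably any callable that maps a sample to a transformed sample can be assigned. A toy, framework-free sketch of building such a callable by chaining functions; `compose`, `to_unit_range` and `flip_horizontal` are hypothetical and operate on a single row of pixel values rather than a real image:

```python
from typing import Any, Callable


def compose(*transforms: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Chain callables left to right, in the spirit of transform_lib.Compose."""

    def apply(sample: Any) -> Any:
        for transform in transforms:
            sample = transform(sample)
        return sample

    return apply


def to_unit_range(row: list[int]) -> list[float]:
    """Scale 0-255 pixel values to [0, 1]."""
    return [p / 255.0 for p in row]


def flip_horizontal(row: list[float]) -> list[float]:
    """Mirror a single row of pixels (a toy stand-in for an image)."""
    return row[::-1]


train_transforms = compose(to_unit_range, flip_horizontal)
print(train_transforms([0, 51, 255]))
```

In practice the callable would be a `transform_lib.Compose(...)` pipeline like the one above; the point is only that order matters, since each step receives the previous step's output.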
FashionMNISTDataModule
Bases: MNISTDataModule
Figure: Fashion MNIST (https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png)
Specs
- 10 classes (1 per type)
- Each image is (1 x 28 x 28)
Standard FashionMNIST train, val, and test splits and transforms.
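Transforms::

    from torchvision import transforms as transform_lib

    mnist_transforms = transform_lib.Compose([
        transform_lib.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (1, 28, 28)
    ])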
Example::

    from pl_bolts.datamodules import FashionMNISTDataModule

    dm = FashionMNISTDataModule('.')  # '.': directory where the data is stored/downloaded
    model = LitModel()  # LitModel: your own LightningModule
    Trainer().fit(model, datamodule=dm)
ImageNetDataModule
Bases: VisionDataModule
ImageNet datamodule.
Extracted from https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/datamodules/imagenet_datamodule.py and made into a subclass of VisionDataModule.
Notes:
- train_dataloader uses the train split of imagenet2012 and sets aside a portion of it for the validation split.
- val_dataloader uses the part of the imagenet2012 train split that was not used for training, via num_imgs_per_val_class.
- TODO: needs to pass split='val' to UnlabeledImagenet.
- test_dataloader uses the validation split of imagenet2012 for testing.
- TODO: needs to pass num_imgs_per_class=-1 for the test dataset and split="test".
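The notes describe carving a validation set out of the ImageNet train split, with `num_imgs_per_val_class` images taken from each class. A minimal, framework-free sketch of such a per-class holdout; the helper `per_class_holdout` and its seeded-shuffle strategy are assumptions for illustration, not the datamodule's actual code:

```python
from __future__ import annotations

import random
from collections import defaultdict


def per_class_holdout(
    labels: list[int], num_imgs_per_val_class: int, seed: int = 42
) -> tuple[list[int], list[int]]:
    """Return (train_indices, val_indices) with num_imgs_per_val_class val images per class."""
    # Group sample indices by class label.
    by_class: dict[int, list[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)  # A fixed seed gives the same split on every run.
    train: list[int] = []
    val: list[int] = []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        val.extend(indices[:num_imgs_per_val_class])  # Held out for validation.
        train.extend(indices[num_imgs_per_val_class:])  # Kept for training.
    return sorted(train), sorted(val)


labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
train_idx, val_idx = per_class_holdout(labels, num_imgs_per_val_class=1)
print(train_idx, val_idx)
```

Holding out a fixed count per class keeps the validation set class-balanced, and the official imagenet2012 validation split stays untouched for testing.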
dims (class-attribute, instance-attribute)
dims: tuple[C, H, W] = (
    C(3),
    H(image_size),
    W(image_size),
)
A tuple describing the shape of the data: (channels, height, width).
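The `C(3)`, `H(...)` and `W(...)` calls suggest that C, H and W are thin wrappers around `int` that label each axis. Assuming they are `typing.NewType` aliases (the source does not show their definition), a sketch of how `dims` is built; `imagenet_dims` is a hypothetical helper:

```python
from typing import NewType

# Hypothetical stand-ins for the C, H, W types in the signature above.
# NewType gives distinct names for static type checkers but returns plain ints at runtime.
C = NewType("C", int)
H = NewType("H", int)
W = NewType("W", int)


def imagenet_dims(image_size: int = 224) -> tuple[C, H, W]:
    """Shape of one image: 3 channels and a square image_size x image_size grid."""
    return (C(3), H(image_size), W(image_size))


dims = imagenet_dims()
num_values = dims[0] * dims[1] * dims[2]  # Number of values per image.
print(dims, num_values)
```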
__init__
__init__(
    data_dir: str | Path = DATA_DIR,
    *,
    val_split: int | float = 0.01,
    num_workers: int = NUM_WORKERS,
    normalize: bool = False,
    image_size: int = 224,
    batch_size: int = 32,
    seed: int = 42,
    shuffle: bool = True,
    pin_memory: bool = True,
    drop_last: bool = False,
    train_transforms: Callable | None = None,
    val_transforms: Callable | None = None,
    test_transforms: Callable | None = None,
    **kwargs
)
Creates an ImageNet datamodule (doesn't load or prepare the dataset yet).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_dir | str \| Path | Path to the ImageNet dataset | DATA_DIR |
val_split | int \| float | Fraction (float) or number (int) of training samples to hold out for the validation split | 0.01 |
image_size | int | Final image size (height and width) | 224 |
num_workers | int | How many workers to use for loading data | NUM_WORKERS |
batch_size | int | How many samples per batch to load | 32 |
shuffle | bool | If True, shuffles the data every epoch | True |
pin_memory | bool | If True, the data loader copies Tensors into CUDA pinned memory before returning them | True |
drop_last | bool | If True, drops the last incomplete batch | False |
train_transform
The standard ImageNet training transforms.
.. code-block:: python

    from torchvision import transforms as transform_lib

    # self.image_size is the datamodule's image_size argument (224 by default).
    transform_lib.Compose([
        transform_lib.RandomResizedCrop(self.image_size),
        transform_lib.RandomHorizontalFlip(),
        transform_lib.ToTensor(),
        transform_lib.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
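`RandomResizedCrop` picks a crop with a random area and aspect ratio and then resizes it to `image_size` x `image_size`. A pure-Python sketch of the box sampling, assuming torchvision's documented defaults (`scale=(0.08, 1.0)`, `ratio=(3/4, 4/3)`); the fallback is simplified to a square center crop, whereas torchvision instead clamps the aspect ratio. `random_resized_crop_box` is a hypothetical helper:

```python
from __future__ import annotations

import math
import random


def random_resized_crop_box(
    width: int,
    height: int,
    scale: tuple[float, float] = (0.08, 1.0),
    ratio: tuple[float, float] = (3 / 4, 4 / 3),
    rng: random.Random | None = None,
) -> tuple[int, int, int, int]:
    """Return (top, left, crop_height, crop_width) for a random area/aspect-ratio crop."""
    if rng is None:
        rng = random.Random()
    area = width * height
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(10):  # Retry a few times if the sampled box does not fit.
        target_area = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(*log_ratio))  # Aspect ratio sampled log-uniformly.
        w = round(math.sqrt(target_area * aspect))
        h = round(math.sqrt(target_area / aspect))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Simplified fallback: the largest centered square.
    side = min(width, height)
    return (height - side) // 2, (width - side) // 2, side, side


print(random_resized_crop_box(500, 375, rng=random.Random(0)))
```

Sampling the aspect ratio in log space makes ratios r and 1/r equally likely, so tall and wide crops are treated symmetrically.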
val_transform
val_transform() -> Callable
The standard ImageNet validation transforms.
.. code-block:: python

    from torchvision import transforms as transform_lib

    # With the default image_size of 224: resize the shorter edge to 256, then crop 224 x 224.
    transform_lib.Compose([
        transform_lib.Resize(self.image_size + 32),
        transform_lib.CenterCrop(self.image_size),
        transform_lib.ToTensor(),
        transform_lib.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
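With the default `image_size=224`, validation images are resized so that their shorter edge is 256 pixels, and the central 224 x 224 region is kept. A sketch of that geometry, assuming torchvision's behaviour for an integer `Resize` size (shorter edge matched, aspect ratio kept, longer edge truncated to an int) and its rounding of the `CenterCrop` offsets; `resize_then_center_crop` is a hypothetical helper:

```python
def resize_then_center_crop(
    width: int, height: int, image_size: int = 224
) -> tuple[tuple[int, int], tuple[int, int, int, int]]:
    """Return the resized (width, height) and the (left, top, right, bottom) crop box."""
    short_side = image_size + 32  # 256 for the default image_size of 224.
    # An int size for Resize targets the shorter edge; the longer edge keeps the aspect ratio.
    if width <= height:
        new_w, new_h = short_side, int(short_side * height / width)
    else:
        new_w, new_h = int(short_side * width / height), short_side
    # Center the crop window inside the resized image.
    left = int(round((new_w - image_size) / 2.0))
    top = int(round((new_h - image_size) / 2.0))
    return (new_w, new_h), (left, top, left + image_size, top + image_size)


print(resize_then_center_crop(500, 375))
```

Resizing slightly larger than the crop and then center-cropping discards a thin border while keeping the object scale close to what the training crops see.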
ImageNet32DataModule
INaturalistDataModule
MNISTDataModule
Bases: ImageClassificationDataModule
Figure: MNIST (https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png)
Specs
- 10 classes (1 per digit)
- Each image is (1 x 28 x 28)
Standard MNIST train, val, and test splits and transforms.
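Transforms::

    from torchvision import transforms as transform_lib

    mnist_transforms = transform_lib.Compose([
        transform_lib.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (1, 28, 28)
    ])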
Example::

    from pl_bolts.datamodules import MNISTDataModule

    dm = MNISTDataModule('.')  # '.': directory where the data is stored/downloaded
    model = LitModel()  # LitModel: your own LightningModule
    Trainer().fit(model, datamodule=dm)
__init__
__init__(
    data_dir: str | None = None,
    val_split: int | float = 0.2,
    num_workers: int | None = 0,
    normalize: bool = False,
    batch_size: int = 32,
    seed: int = 42,
    shuffle: bool = True,
    pin_memory: bool = True,
    drop_last: bool = False,
    *args: Any,
    **kwargs: Any
) -> None
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_dir | str \| None | Where to save/load the data | None |
val_split | int \| float | Fraction (float) or number (int) of samples to use for the validation split | 0.2 |
num_workers | int \| None | How many workers to use for loading data | 0 |
normalize | bool | If True, applies image normalization | False |
batch_size | int | How many samples per batch to load | 32 |
seed | int | Random seed to be used for train/val/test splits | 42 |
shuffle | bool | If True, shuffles the train data every epoch | True |
pin_memory | bool | If True, the data loader copies Tensors into CUDA pinned memory before returning them | True |
drop_last | bool | If True, drops the last incomplete batch | False |
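In the pl_bolts code this datamodule family is adapted from, `val_split` is resolved into train and validation lengths roughly as follows: an int is an absolute count, a float a fraction of the dataset. A sketch assuming the same rule applies here; `split_lengths` is a hypothetical helper:

```python
from __future__ import annotations


def split_lengths(len_dataset: int, val_split: int | float) -> tuple[int, int]:
    """Return (train_len, val_len) for a dataset of len_dataset samples."""
    if isinstance(val_split, int):
        val_len = val_split  # An int is an absolute number of validation samples.
    elif isinstance(val_split, float):
        val_len = int(val_split * len_dataset)  # A float is a fraction, truncated to an int.
    else:
        raise TypeError(f"Unsupported val_split type: {type(val_split)}")
    return len_dataset - val_len, val_len


print(split_lengths(60000, 0.2))
print(split_lengths(60000, 5000))
```

With the MNIST default of `val_split=0.2` and the 60,000-image MNIST training set, this gives 48,000 training and 12,000 validation samples.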
HFDataModule
Bases: LightningDataModule
Lightning data module for Hugging Face (HF) datasets.
VisionDataModule
Bases: LightningDataModule, DataModule[BatchType_co]
A LightningDataModule for image datasets.
(Taken from pl_bolts, which is not very well maintained.)
__init__
__init__(
    data_dir: str | Path = DATA_DIR,
    val_split: int | float = 0.2,
    num_workers: int = NUM_WORKERS,
    normalize: bool = False,
    batch_size: int = 32,
    seed: int = 42,
    shuffle: bool = True,
    pin_memory: bool = True,
    drop_last: bool = False,
    train_transforms: Callable | None = None,
    val_transforms: Callable | None = None,
    test_transforms: Callable | None = None,
    **kwargs
) -> None
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_dir | str \| Path | Where to save/load the data | DATA_DIR |
val_split | int \| float | Fraction (float) or number (int) of samples to use for the validation split | 0.2 |
num_workers | int | How many workers to use for loading data | NUM_WORKERS |
normalize | bool | If True, applies image normalization | False |
batch_size | int | How many samples per batch to load | 32 |
seed | int | Random seed to be used for train/val/test splits | 42 |
shuffle | bool | If True, shuffles the train data every epoch | True |
pin_memory | bool | If True, the data loader copies Tensors into CUDA pinned memory before returning them | True |
drop_last | bool | If True, drops the last incomplete batch | False |
train_transforms | Callable \| None | Transformations to apply to the train dataset | None |
val_transforms | Callable \| None | Transformations to apply to the validation dataset | None |
test_transforms | Callable \| None | Transformations to apply to the test dataset | None |
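The `batch_size` and `drop_last` arguments together determine how many batches each epoch contains. A small sketch of that arithmetic, matching the standard `torch.utils.data.DataLoader` behaviour for a map-style dataset with the default batch sampler; `batches_per_epoch` is a hypothetical helper:

```python
def batches_per_epoch(num_samples: int, batch_size: int = 32, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields over num_samples samples."""
    if drop_last:
        return num_samples // batch_size  # The final partial batch is discarded.
    return -(-num_samples // batch_size)  # Integer ceiling: the final partial batch is kept.


print(batches_per_epoch(100, 32), batches_per_epoch(100, 32, drop_last=True))
```

With 100 samples and the default `batch_size=32`, the last batch holds only 4 samples; `drop_last=True` skips it, which keeps every batch the same size.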
default_transforms (abstractmethod)
default_transforms() -> Callable
Default transform for the dataset.
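Because `default_transforms` is abstract, each concrete datamodule must implement it. In pl_bolts, the datamodule falls back to it whenever the corresponding `*_transforms` argument is `None`. A minimal, framework-free sketch of that pattern, assuming the same behaviour here; `VisionDataModuleSketch`, `resolve_train_transforms` and `ScaleTo01` are hypothetical names:

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable


class VisionDataModuleSketch(ABC):
    """Hypothetical stand-in for the transform handling of VisionDataModule."""

    def __init__(self, train_transforms: Callable | None = None) -> None:
        self.train_transforms = train_transforms

    @abstractmethod
    def default_transforms(self) -> Callable:
        """Subclasses must provide the dataset-specific default transform."""

    def resolve_train_transforms(self) -> Callable:
        # Fall back to the subclass default when no transform was passed in.
        if self.train_transforms is None:
            return self.default_transforms()
        return self.train_transforms


class ScaleTo01(VisionDataModuleSketch):
    def default_transforms(self) -> Callable:
        return lambda pixel: pixel / 255.0  # Toy default: scale a 0-255 value to [0, 1].


print(ScaleTo01().resolve_train_transforms()(255))
```

Keeping the default in an abstract method means the base class never needs to know dataset statistics; each subclass (CIFAR-10, MNIST, ImageNet) supplies its own.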
train_dataloader
train_dataloader(
    _dataloader_fn: Callable[
        Concatenate[Dataset, P], DataLoader
    ] = DataLoader,
    *args: P.args,
    **kwargs: P.kwargs
) -> DataLoader
The train dataloader.
val_dataloader
val_dataloader(
    _dataloader_fn: Callable[
        Concatenate[Dataset, P], DataLoader
    ] = DataLoader,
    *args: P.args,
    **kwargs: P.kwargs
) -> DataLoader
The val dataloader.
test_dataloader
test_dataloader(
    _dataloader_fn: Callable[
        Concatenate[Dataset, P], DataLoader
    ] = DataLoader,
    *args: P.args,
    **kwargs: P.kwargs
) -> DataLoader
The test dataloader.