PyTorch's Dataset and DataLoader

(Reposted for self-study only. Thank you for your understanding.)

0. Overview of Dataset and DataLoader

PyTorch usually builds data pipelines with two utility classes, Dataset and DataLoader.

A Dataset defines the contents of a dataset. It behaves like a list: it has a fixed length, and its elements can be retrieved by index.

A DataLoader defines how a dataset is loaded in batches. It is an iterable object that implements the __iter__ method, and each iteration yields one batch of data.

The DataLoader controls the batch size, how the elements of a batch are sampled, and how a batch is assembled into the input form the model requires; it can also read data using multiple processes.

In most cases, users only need to implement the __len__ and __getitem__ methods of Dataset to build their own dataset and load it with the default data pipeline.
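As a minimal sketch of this point, the toy dataset below implements only __len__ and __getitem__ and is loaded with DataLoader's default sampler and collate_fn. The class name SquaresDataset and its contents are hypothetical, chosen only for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (tensor([i]), tensor([i**2]))."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # DataLoader asks for the length to know which indices it may sample
        return self.n

    def __getitem__(self, index):
        # Return one (feature, label) pair for an integer index
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y


ds = SquaresDataset(10)
dl = DataLoader(ds, batch_size=4)  # default: sequential sampling, default collate_fn
features, labels = next(iter(dl))
print(features.shape, labels.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
print(len(dl))                       # 3 batches: 4 + 4 + 2 samples
```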

For some complex datasets, the user may also need to write a custom collate_fn for the DataLoader to organize each fetched batch into the input form the model requires.
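A typical case where the default collate_fn is not enough is variable-length sequences: the default batching relies on torch.stack, which fails when samples differ in shape. The sketch below assumes hypothetical token-id samples and pads each batch with torch.nn.utils.rnn.pad_sequence; the function name pad_collate is illustrative. A plain Python list works as the dataset here because it provides __len__ and __getitem__.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical variable-length samples: (token_ids, label)
samples = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
    (torch.tensor([7, 8, 9, 10]), 1),
]


def pad_collate(batch):
    """Pad the sequences of one batch to equal length and collect the labels."""
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(list(seqs), batch_first=True, padding_value=0)
    return padded, lengths, torch.tensor(labels)


dl = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
for padded, lengths, labels in dl:
    print(padded.shape, lengths.tolist(), labels.tolist())
# torch.Size([2, 3]) [3, 2] [0, 1]
# torch.Size([2, 4]) [1, 4] [0, 1]
```

Returning the original lengths alongside the padded tensor lets a model ignore the padding positions later.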

First, an in-depth understanding of the principles of Dataset and DataLoader

1. Steps to obtain a batch of data

Let's consider what steps are required to get a batch of data from a dataset.

(Assume the features and labels of the dataset are represented as tensors X and Y, so the dataset can be written as (X,Y), and assume a batch size of m.)

1. First, determine the length n of the dataset, for example n = 1000.

2. Then sample m indices (m is the batch size) from the range 0 to n-1. With m = 4, the result is a list such as indices = [1,4,8,9].

3. Next, fetch the elements at those m indices from the dataset. The result is a list of tuples such as samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])].

4. Finally, organize the samples into two tensors as the output.

The result is a pair of tensors, batch = (features,labels),

where features = torch.stack([X[1],X[4],X[8],X[9]])

and labels = torch.stack([Y[1],Y[4],Y[8],Y[9]]).
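The four steps can be sketched directly on in-memory tensors. This is an illustration, not DataLoader's internal code: torch.randperm stands in for random sampling, and the final assertions show that for tensors already in memory, fetching and stacking is equivalent to one advanced-indexing call X[indices].

```python
import torch

torch.manual_seed(0)
X = torch.randn(1000, 3)                                # features
Y = torch.randint(low=0, high=2, size=(1000,)).float()  # labels

n = len(X)                                  # step 1: length of the dataset
indices = torch.randperm(n)[:4]             # step 2: sample m = 4 indices without replacement
samples = [(X[i], Y[i]) for i in indices]   # step 3: fetch (feature, label) pairs

# step 4: stack the per-sample tensors into batch tensors
features = torch.stack([s[0] for s in samples])
labels = torch.stack([s[1] for s in samples])

# For in-memory tensors, steps 3 and 4 collapse into a single indexing call
assert torch.equal(features, X[indices])
assert torch.equal(labels, Y[indices])
print(features.shape, labels.shape)  # torch.Size([4, 3]) torch.Size([4])
```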

2. Division of labor between Dataset and DataLoader

Step 1 above, determining the length of the dataset, is implemented by the __len__ method of Dataset.

Step 2, sampling m numbers from the range 0 to n-1, is specified by the sampler and batch_sampler parameters of DataLoader.

The sampler parameter specifies how individual elements are sampled and generally does not need to be set by the user. By default, DataLoader uses random sampling when shuffle=True and sequential sampling when shuffle=False.
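The two default single-element samplers can be inspected directly. In this minimal sketch, the seeded torch.Generator is an addition used only to make the random order reproducible.

```python
import torch
from torch.utils.data import SequentialSampler, RandomSampler

data = list(range(6))  # any object with __len__ can serve as the data_source

seq = SequentialSampler(data)
print(list(seq))  # [0, 1, 2, 3, 4, 5]

# A seeded generator makes the random permutation reproducible
rnd = RandomSampler(data, generator=torch.Generator().manual_seed(0))
perm = list(rnd)
print(perm)  # a permutation of 0..5
```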

The batch_sampler parameter groups the sampled indices into lists of batch_size elements and generally does not need to be set by the user. When the dataset length is not divisible by the batch size, the default behavior drops the incomplete last batch if drop_last=True and keeps it if drop_last=False.
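The effect of drop_last can be checked with BatchSampler on 10 indices and a batch size of 4, which does not divide 10 evenly:

```python
from torch.utils.data import BatchSampler, SequentialSampler

sampler = SequentialSampler(range(10))  # indices 0..9 in order

keep = list(BatchSampler(sampler, batch_size=4, drop_last=False))
drop = list(BatchSampler(sampler, batch_size=4, drop_last=True))
print(keep)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(drop)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```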

The core logic of step 3, fetching elements of the dataset by index, is implemented by the __getitem__ method of Dataset.

Step 4, organizing the samples into batch tensors, is specified by the collate_fn parameter of DataLoader, which generally does not need to be set.
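The default collate function is exposed as torch.utils.data.default_collate (a public name in recent PyTorch releases, assumed here to be 1.11 or later; older versions keep it under torch.utils.data._utils.collate). The sketch below, with made-up samples, shows it stacking the features and turning plain Python int labels into a tensor.

```python
import torch
from torch.utils.data import default_collate  # assumes a PyTorch version that exports it

# Made-up samples: (feature tensor, plain Python int label)
batch = [(torch.tensor([1.0, 2.0]), 0),
         (torch.tensor([3.0, 4.0]), 1)]

features, labels = default_collate(batch)
print(features.tolist())  # [[1.0, 2.0], [3.0, 4.0]]
print(labels.tolist())    # [0, 1]
print(labels.dtype)       # torch.int64
```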

The general usage of Dataset and DataLoader is as follows:

import torch
from torch.utils.data import TensorDataset,Dataset,DataLoader
from torch.utils.data import RandomSampler,BatchSampler
ds = TensorDataset(torch.randn(1000,3),
                   torch.randint(low=0,high=2,size=(1000,)).float())
dl = DataLoader(ds,batch_size=4,drop_last = False)
features,labels = next(iter(dl))
print("features = ",features )
print("labels = ",labels )

The following breaks down, step by step, the methods DataLoader calls internally:

#step1: Determine the length of the dataset (implemented by the __len__ method of Dataset)
ds = TensorDataset(torch.randn(1000,3),
                   torch.randint(low=0,high=2,size=(1000,)).float())
print("n = ", len(ds)) #len(ds) is equivalent to ds.__len__()
#step2: Determine the sampling indices (implemented by the Sampler and BatchSampler of DataLoader)
sampler = RandomSampler(data_source = ds)
batch_sampler = BatchSampler(sampler = sampler,
                             batch_size = 4, drop_last = False)
indices = next(iter(batch_sampler)) #indices of the first batch
print("indices = ",indices)
#step3: Fetch one batch of samples (implemented by the __getitem__ method of Dataset)
batch = [ds[i] for i in indices]  #ds[i] is equivalent to ds.__getitem__(i)
print("batch = ", batch)
#step4: Organize into features and labels (implemented by the collate_fn method of DataLoader)
def collate_fn(batch):  
    features = torch.stack([sample[0] for sample in batch])  
    labels = torch.stack([sample[1] for sample in batch])  
    return features,labels   
features,labels = collate_fn(batch)  
print("features = ",features)  
print("labels = ",labels) 
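As a consistency check of this breakdown, the sketch below runs the same four steps with sequential sampling (instead of random) and compares the result with the first batch of a built-in DataLoader with shuffle=False. The 10-sample dataset is made up for illustration.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler, BatchSampler

ds = TensorDataset(torch.randn(10, 3), torch.arange(10).float())

# Manual pipeline (steps 1-4) with sequential sampling
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)
indices = next(iter(batch_sampler))                      # [0, 1, 2, 3]
samples = [ds[i] for i in indices]
manual = (torch.stack([s[0] for s in samples]),
          torch.stack([s[1] for s in samples]))

# The built-in DataLoader with shuffle=False should yield the same first batch
features, labels = next(iter(DataLoader(ds, batch_size=4, shuffle=False)))
print(torch.equal(manual[0], features), torch.equal(manual[1], labels))  # True True
```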

3. The core source code of Dataset and DataLoader

The following is a simplified sketch of the core logic of Dataset and DataLoader rather than PyTorch's actual source; the multi-process data loading code that the real DataLoader uses to improve performance is omitted.

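import torch
from torch.utils.data import RandomSampler,SequentialSampler,BatchSampler,default_collate

class Dataset(object):
    def __init__(self):
        pass
    def __len__(self):
        raise NotImplementedError
    def __getitem__(self,index):
        raise NotImplementedError

class DataLoader(object):
    def __init__(self,dataset,batch_size,collate_fn = None,shuffle = True,drop_last = False):
        self.dataset = dataset
        #random sampling when shuffle=True, sequential sampling otherwise
        self.sampler = RandomSampler if shuffle else SequentialSampler
        #group the sampled indices into lists of length batch_size
        self.batch_sampler = BatchSampler(self.sampler(self.dataset),
                                          batch_size = batch_size,drop_last = drop_last)
        #fall back to PyTorch's default collate function
        self.collate_fn = collate_fn if collate_fn is not None else default_collate
        self.sample_iter = None
    def __iter__(self):
        #each call starts a new epoch with a fresh iterator over index batches
        self.sample_iter = iter(self.batch_sampler)
        return self
    def __next__(self):
        #raises StopIteration at the end of the epoch, which ends a for loop
        indices = next(self.sample_iter)
        batch = self.collate_fn([self.dataset[i] for i in indices])
        return batch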

Let's test it:

class ToyDataset(Dataset):  
    def __init__(self,X,Y):  
        self.X = X  
        self.Y = Y   
    def __len__(self):  
        return len(self.X)  
    def __getitem__(self,index):  
        return self.X[index],self.Y[index]  
X,Y = torch.randn(1000,3),torch.randint(low=0,high=2,size=(1000,)).float()  
ds = ToyDataset(X,Y)  
dl = DataLoader(ds,batch_size=4,drop_last = False)  
features,labels = next(iter(dl))   
print("features = ",features )  
print("labels = ",labels )

Perfect, as expected!

Second, use Dataset to create a dataset

Common ways to create a dataset are:

  • Use torch.utils.data.TensorDataset to create a dataset from Tensors (a numpy array or Pandas DataFrame must be converted to a Tensor first).

  • Use torchvision.datasets.ImageFolder to create an image dataset from an image directory.

  • Inherit from torch.utils.data.Dataset to create a custom dataset.

In addition:

  • torch.utils.data.random_split splits a dataset into multiple parts, often to produce training, validation and test sets.

  • The addition operator (+) of Dataset combines multiple datasets into one dataset.

1. Create a dataset based on Tensor

import numpy as np   
import torch   
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split
#Create a dataset from Tensor
from sklearn import datasets   
iris = datasets.load_iris()  
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))
#Split into a training set and a validation set
n_train = int(len(ds_iris)*0.8)  
n_val = len(ds_iris) - n_train  
ds_train,ds_val = random_split(ds_iris,[n_train,n_val])  
#Loading datasets with DataLoader
dl_train,dl_val = DataLoader(ds_train,batch_size = 8),DataLoader(ds_val,batch_size = 8)  
for features,labels in dl_train:
    print(features,labels)
    break
#Demonstrate the effect of combining datasets with the addition operator (`+`)
ds_data = ds_train + ds_val  
print('len(ds_train) = ',len(ds_train))  
print('len(ds_valid) = ',len(ds_val))  
print('len(ds_train+ds_valid) = ',len(ds_data))  
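random_split also accepts a generator argument, which makes the split reproducible, and the + operator returns a ConcatDataset. A short sketch; the toy dataset of 100 numbers and the seed 42 are arbitrary choices.

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset, random_split

ds = TensorDataset(torch.arange(100).float())

# A seeded generator gives the same split on every run
g = torch.Generator().manual_seed(42)
ds_train, ds_val = random_split(ds, [80, 20], generator=g)

ds_all = ds_train + ds_val  # Dataset.__add__ builds a ConcatDataset
print(type(ds_all).__name__, len(ds_all))  # ConcatDataset 100
```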

2. Create an image dataset from an image directory

import numpy as np   
import torch   
from torch.utils.data import DataLoader
from torchvision import transforms,datasets   
#Demonstrate some common image augmentation operations
from PIL import Image
img = Image.open('./data/cat.jpeg')
#Random vertical flip
img_vflip = transforms.RandomVerticalFlip()(img)
#Random rotation
img_rot = transforms.RandomRotation(45)(img)
#Define the image augmentation operations
transform_train = transforms.Compose([
   transforms.RandomHorizontalFlip(), #Random horizontal flip
   transforms.RandomVerticalFlip(), #Random vertical flip
   transforms.RandomRotation(45),  #Rotate randomly within 45 degrees
   transforms.ToTensor() #Convert to Tensor
])
transform_valid = transforms.Compose([
   transforms.ToTensor()
])
#Create a dataset from an image directory
def transform_label(x):  
    return torch.tensor([x]).float()  
ds_train = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/train/",  
            transform = transform_train,target_transform= transform_label)  
ds_val = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/test/",  
                              transform = transform_valid,  
                              target_transform= transform_label)  
#Loading datasets with DataLoader
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)  
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)  
for features,labels in dl_train:
    print(features.shape,labels.shape)
    break

3. Create a custom dataset

Next, we build a data pipeline for cifar2 in another way: by inheriting from torch.utils.data.Dataset to create a custom dataset.

from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
class Cifar2Dataset(Dataset):  
    def __init__(self,imgs_dir,img_transform):  
        self.files = list(Path(imgs_dir).rglob("*.jpg"))  
        self.transform = img_transform  
    def __len__(self):
        return len(self.files)  
    def __getitem__(self,i):  
        file_i = str(self.files[i])  
        img = Image.open(file_i)
        tensor = self.transform(img)  
        label = torch.tensor([1.0]) if  "1_automobile" in file_i else torch.tensor([0.0])  
        return tensor,label   
train_dir = "./eat_pytorch_datasets/cifar2/train/"  
test_dir = "./eat_pytorch_datasets/cifar2/test/"  
#Define image augmentation
transform_train = transforms.Compose([
   transforms.RandomHorizontalFlip(), #Random horizontal flip
   transforms.RandomVerticalFlip(), #Random vertical flip
   transforms.RandomRotation(45),  #Rotate randomly within 45 degrees
   transforms.ToTensor() #Convert to Tensor
])
transform_val = transforms.Compose([
   transforms.ToTensor()
])
ds_train = Cifar2Dataset(train_dir,transform_train)  
ds_val = Cifar2Dataset(test_dir,transform_val)  
dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)  
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)  
for features,labels in dl_train:
    print(features.shape,labels.shape)
    break
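The pattern used by Cifar2Dataset (keep only file paths in __init__, load one file per __getitem__ call, derive the label from the path) can be tried without any images. The sketch below is hypothetical: it writes small NumPy arrays into temporary folders named 0_airplane and 1_automobile to match the label check above, and uses np.load in place of Image.open.

```python
import tempfile
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class NpyFolderDataset(Dataset):
    """Hypothetical dataset: one .npy array per file, label taken from the path."""

    def __init__(self, root):
        # Only the file paths are kept in memory; arrays are loaded lazily
        self.files = sorted(Path(root).rglob("*.npy"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = str(self.files[i])
        x = torch.from_numpy(np.load(path)).float()
        # Same labelling rule as Cifar2Dataset: 1.0 if "1_automobile" is in the path
        y = torch.tensor([1.0]) if "1_automobile" in path else torch.tensor([0.0])
        return x, y


# Build a tiny fake dataset on disk: 3 arrays per class folder
root = Path(tempfile.mkdtemp())
for cls, value in (("0_airplane", 0.0), ("1_automobile", 1.0)):
    (root / cls).mkdir()
    for k in range(3):
        np.save(root / cls / f"{k}.npy", np.full((2, 2), value))

ds = NpyFolderDataset(root)
dl = DataLoader(ds, batch_size=6, shuffle=False)
features, labels = next(iter(dl))
print(len(ds), features.shape)   # 6 torch.Size([6, 2, 2])
print(labels.view(-1).tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```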

Third, use DataLoader to load datasets

The DataLoader controls the batch size, how the elements of a batch are sampled, and how a batch is assembled into the input form the model requires; it can also read data using multiple processes.

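The function signature of DataLoader is as follows (abridged; defaults may differ slightly between PyTorch versions, and newer versions add parameters such as prefetch_factor and persistent_workers):

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)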
In most cases we configure only six parameters: dataset, batch_size, shuffle, num_workers, pin_memory, and drop_last.

For some complex datasets, a custom collate_fn is also needed; the other parameters are usually left at their default values.

In addition to the torch.utils.data.Dataset discussed earlier, DataLoader can also load another kind of dataset, torch.utils.data.IterableDataset.

Unlike Dataset, which behaves like a list, IterableDataset behaves like an iterator. It is more complex and less commonly used.
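For contrast with the list-like Dataset, a minimal IterableDataset implements only __iter__, and DataLoader groups the streamed items into batches. The class RangeStream is hypothetical; shuffle must not be set to True for an IterableDataset, because there are no indices to sample.

```python
import torch
from torch.utils.data import IterableDataset, DataLoader


class RangeStream(IterableDataset):
    """Hypothetical stream that yields start, start + 1, ..., end - 1."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        # No __len__/__getitem__: samples are produced in order, not fetched by index
        for i in range(self.start, self.end):
            yield torch.tensor(i)


dl = DataLoader(RangeStream(0, 10), batch_size=4)  # shuffle left unset
batches = [b.tolist() for b in dl]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```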

  • dataset : dataset

  • batch_size: batch size

  • shuffle: whether to shuffle the data.

  • sampler: the sampling function for individual elements; generally does not need to be set.

  • batch_sampler: the batch sampling function; generally does not need to be set.

  • num_workers: the number of worker processes used to read data; the default 0 reads data in the main process.

  • collate_fn: a function that collates a list of samples into a batch.

  • pin_memory: whether to place batches in page-locked (pinned) memory. The default is False. Pinned memory is never swapped out to virtual memory on disk, so copying from it to the GPU is faster.

  • drop_last: whether to drop the last batch when it has fewer than batch_size samples.

  • timeout: the maximum time to wait for a batch to be loaded; generally does not need to be set.

  • worker_init_fn: an initialization function called in each worker process, commonly used with IterableDataset; rarely needed.
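When a custom batch_sampler is passed, it takes over the options it controls: batch_size, shuffle, sampler and drop_last must then be left at their defaults, otherwise DataLoader raises a ValueError. A minimal sketch with a made-up 10-element dataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, BatchSampler, RandomSampler

ds = TensorDataset(torch.arange(10))

# A custom batch sampler: random order, batches of 3, incomplete last batch dropped
bs = BatchSampler(RandomSampler(ds), batch_size=3, drop_last=True)
dl = DataLoader(ds, batch_sampler=bs)  # batch_size/shuffle/sampler/drop_last left at defaults
sizes = [len(batch) for (batch,) in dl]
print(sizes)  # [3, 3, 3]

# Combining batch_sampler with batch_size is rejected
try:
    DataLoader(ds, batch_size=4, batch_sampler=bs)
except ValueError as err:
    print("ValueError:", err)
```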

#Build an input data pipeline
ds = TensorDataset(torch.arange(1,50))  
dl = DataLoader(ds,  
                batch_size = 10,  
                shuffle= True,  
                drop_last = True)  
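#Iterate over the data
for batch, in dl:
    print(batch)
#Output of one run (the order differs between runs because shuffle=True; drop_last=True drops the 9 samples left over from 49):
tensor([43, 44, 21, 36,  9,  5, 28, 16, 20, 14])
tensor([23, 49, 35, 38,  2, 34, 45, 18, 15, 40])
tensor([26,  6, 27, 39,  8,  4, 24, 19, 32, 17])
tensor([ 1, 29, 11, 47, 12, 22, 48, 42, 10,  7])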
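The number of batches can be checked without iterating: len(dl) is 49 // 10 = 4 with drop_last=True and ceil(49 / 10) = 5 with drop_last=False. A short sketch using the same dataset:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.arange(1, 50))  # 49 samples: 1, 2, ..., 49

dl_drop = DataLoader(ds, batch_size=10, shuffle=True, drop_last=True)
dl_keep = DataLoader(ds, batch_size=10, shuffle=True, drop_last=False)
print(len(dl_drop), len(dl_keep))  # 4 5

# With drop_last=True, one epoch covers only 4 * 10 = 40 distinct samples
seen = torch.cat([b for (b,) in dl_drop])
print(seen.numel())  # 40
```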

Tags: Deep Learning Pytorch AI

Posted by moe on Fri, 02 Sep 2022 03:36:34 +0930