0. Introduction to the functions of Dataset and DataLoader
PyTorch usually uses two utility classes, Dataset and DataLoader, to build data pipelines.
Dataset defines the contents of the dataset. It behaves like a list-like data structure with a fixed length whose elements can be retrieved by index.
DataLoader defines how to load a dataset in batches. It is an iterable object that implements the __iter__ method, and each iteration yields a batch of data.
DataLoader controls the batch size, how the elements in a batch are sampled, and how the batch is collated into the input form required by the model; it can also read data with multiple processes.
In most cases, users only need to implement the __len__ and __getitem__ methods of Dataset to build their own dataset and load it with the default data pipeline.
For more complex datasets, users may also pass a custom collate_fn to the DataLoader to organize the batch of samples into the input form required by the model.
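As a concrete illustration of that last point, here is a minimal sketch (my own toy example, not from the original text) of a dataset whose samples are variable-length sequences; the default collation cannot stack them, so a custom collate_fn pads them to a common length.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

class VarLenDataset(Dataset):
    """Toy dataset whose features are sequences of different lengths."""
    def __init__(self, seqs, labels):
        self.seqs, self.labels = seqs, labels
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, i):
        return self.seqs[i], self.labels[i]

def pad_collate(batch):
    # pad the sequences in the batch to the length of the longest one
    seqs = [sample[0] for sample in batch]
    labels = torch.stack([sample[1] for sample in batch])
    return pad_sequence(seqs, batch_first=True), labels

seqs = [torch.randn(n) for n in (3, 5, 2, 4)]
labels = [torch.tensor(0.0), torch.tensor(1.0), torch.tensor(1.0), torch.tensor(0.0)]
dl = DataLoader(VarLenDataset(seqs, labels), batch_size=2, collate_fn=pad_collate)
features, labels = next(iter(dl))
print(features.shape, labels.shape)   # e.g. torch.Size([2, 5]) torch.Size([2])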
First, in-depth understanding of the principles of Dataset and DataLoader
1. Steps to obtain a batch of data
Let's consider what steps are required to get a batch of data from a dataset.
(Assume the features and labels of the dataset are represented by tensors X and Y respectively, so the dataset can be written as (X,Y), and assume the batch size is m.)
1. First we determine the length n of the dataset. For example: n = 1000.
2. Then we sample m numbers (the batch size) from the range 0 to n-1. Assuming m=4, the result is a list, for example: indices = [1,4,8,9]
3. Then we take the elements corresponding to those m indices from the dataset. The result is a list of tuples, for example: samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4. Finally we organize the result into two tensors as output.
The result is a pair of tensors, batch = (features,labels),
where features = torch.stack([X[1],X[4],X[8],X[9]])
and labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])
2. Division of labor between Dataset and DataLoader
The first step above, determining the length of the dataset, is implemented by the __len__ method of Dataset.
The second step, sampling m numbers from the range 0 to n-1, is specified by the sampler and batch_sampler parameters of the DataLoader.
The sampler parameter specifies how individual elements are sampled and generally does not need to be set by the user. By default, random sampling is used when the DataLoader parameter shuffle=True and sequential sampling when shuffle=False.
The batch_sampler parameter organizes the sampled indices into batches (lists) and generally does not need to be set by the user. By default, when the DataLoader parameter drop_last=True, the last incomplete batch (left over when the dataset length is not divisible by the batch size) is discarded; when drop_last=False it is kept (the small sketch below illustrates the difference).
The core logic of the third step, fetching the elements of the dataset by index, is implemented by the __getitem__ method of Dataset.
The logic of the fourth step is specified by the collate_fn parameter of DataLoader and generally does not need to be set by the user.
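To make the drop_last behavior from step 2 concrete, here is a tiny sketch (my own example) of how it changes the number of batches:

import torch
from torch.utils.data import TensorDataset, DataLoader

ds = TensorDataset(torch.randn(10, 3))   # 10 samples with batch_size=4 -> 2 full batches + 1 partial one
print(len(DataLoader(ds, batch_size=4, drop_last=False)))  # 3: the last, incomplete batch is kept
print(len(DataLoader(ds, batch_size=4, drop_last=True)))   # 2: the incomplete batch is discarded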
The general usage of Dataset and DataLoader is as follows:
import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader
from torch.utils.data import RandomSampler, BatchSampler

ds = TensorDataset(torch.randn(1000, 3),
                   torch.randint(low=0, high=2, size=(1000,)).float())
dl = DataLoader(ds, batch_size=4, drop_last=False)
features, labels = next(iter(dl))
print("features = ", features)
print("labels = ", labels)
The internal steps performed by the DataLoader can be broken down as follows:
# step1: determine the length of the dataset (implemented by the __len__ method of Dataset)
ds = TensorDataset(torch.randn(1000, 3),
                   torch.randint(low=0, high=2, size=(1000,)).float())
print("n = ", len(ds))  # len(ds) is equivalent to ds.__len__()

# step2: determine the sampling indices (implemented by the Sampler and BatchSampler of DataLoader)
sampler = RandomSampler(data_source=ds)
batch_sampler = BatchSampler(sampler=sampler, batch_size=4, drop_last=False)
for idxs in batch_sampler:
    indices = idxs
    break
print("indices = ", indices)

# step3: take out a batch of samples (implemented by the __getitem__ method of Dataset)
batch = [ds[i] for i in indices]  # ds[i] is equivalent to ds.__getitem__(i)
print("batch = ", batch)

# step4: organize the samples into features and labels (implemented by the collate_fn of DataLoader)
def collate_fn(batch):
    features = torch.stack([sample[0] for sample in batch])
    labels = torch.stack([sample[1] for sample in batch])
    return features, labels

features, labels = collate_fn(batch)
print("features = ", features)
print("labels = ", labels)
3. The core source code of Dataset and DataLoader
The following is the core source code of Dataset and DataLoader, omitting the multi-process data-loading code that was introduced to improve performance.
import torch

class Dataset(object):
    def __init__(self):
        pass

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError


class DataLoader(object):
    def __init__(self, dataset, batch_size, collate_fn=None, shuffle=True, drop_last=False):
        self.dataset = dataset
        self.sampler = torch.utils.data.RandomSampler if shuffle else \
            torch.utils.data.SequentialSampler
        self.batch_sampler = torch.utils.data.BatchSampler
        self.sample_iter = self.batch_sampler(
            self.sampler(self.dataset),
            batch_size=batch_size, drop_last=drop_last)
        self.collate_fn = collate_fn if collate_fn is not None else \
            torch.utils.data._utils.collate.default_collate

    def __iter__(self):
        # start a fresh pass over the batch indices
        self.batch_iter = iter(self.sample_iter)
        return self

    def __next__(self):
        indices = next(self.batch_iter)  # raises StopIteration at the end of the epoch
        batch = self.collate_fn([self.dataset[i] for i in indices])
        return batch
Let's test it:
class ToyDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index], self.Y[index]


X, Y = torch.randn(1000, 3), torch.randint(low=0, high=2, size=(1000,)).float()
ds = ToyDataset(X, Y)

dl = DataLoader(ds, batch_size=4, drop_last=False)
features, labels = next(iter(dl))
print("features = ", features)
print("labels = ", labels)
Perfect, as expected!
Second, use Dataset to create a dataset
The commonly used ways to create a Dataset are:

- Use torch.utils.data.TensorDataset to create a dataset from Tensors (numpy arrays and Pandas DataFrames need to be converted to Tensors first; a small DataFrame example is shown at the end of subsection 1 below).
- Use torchvision.datasets.ImageFolder to create an image dataset from an image directory.
- Inherit torch.utils.data.Dataset to create a custom dataset.
In addition:

- torch.utils.data.random_split splits a dataset into multiple parts, often used to split it into training, validation and test sets.
- The addition operator (+) of Dataset combines multiple datasets into one.
1. Create a dataset based on Tensor
import numpy as np
import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split
# Create a dataset from Tensors
from sklearn import datasets

iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data), torch.tensor(iris.target))

# Split into a training set and a validation set
n_train = int(len(ds_iris) * 0.8)
n_val = len(ds_iris) - n_train
ds_train, ds_val = random_split(ds_iris, [n_train, n_val])

print(type(ds_iris))
print(type(ds_train))
# Load the datasets with DataLoader
dl_train, dl_val = DataLoader(ds_train, batch_size=8), DataLoader(ds_val, batch_size=8)

for features, labels in dl_train:
    print(features, labels)
    break
# Demonstrate the effect of the addition operator (`+`)
ds_data = ds_train + ds_val

print('len(ds_train) = ', len(ds_train))
print('len(ds_valid) = ', len(ds_val))
print('len(ds_train+ds_valid) = ', len(ds_data))
print(type(ds_data))
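As noted in the bullet list earlier, a Pandas DataFrame has to be converted to Tensors before it can be wrapped in a TensorDataset; here is a minimal sketch of that path (the toy DataFrame is my own illustration):

import pandas as pd
import torch
from torch.utils.data import TensorDataset

df = pd.DataFrame({"x1": [1.0, 2.0, 3.0], "x2": [4.0, 5.0, 6.0], "y": [0, 1, 0]})
ds_df = TensorDataset(torch.tensor(df[["x1", "x2"]].values).float(),
                      torch.tensor(df["y"].values).float())
print(ds_df[0])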
2. Create an image dataset from an image directory
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
# Demonstrate some common image augmentation operations
from PIL import Image

img = Image.open('./data/cat.jpeg')
img
# Random vertical flip
transforms.RandomVerticalFlip()(img)
# Random rotation
transforms.RandomRotation(45)(img)
# Define the image augmentation operations
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomVerticalFlip(),    # random vertical flip
    transforms.RandomRotation(45),      # random rotation within 45 degrees
    transforms.ToTensor()               # convert to Tensor
])

transform_valid = transforms.Compose([
    transforms.ToTensor()
])
# Create datasets from image directories
def transform_label(x):
    return torch.tensor([x]).float()

ds_train = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/train/",
                                transform=transform_train, target_transform=transform_label)
ds_val = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/test/",
                              transform=transform_valid, target_transform=transform_label)
print(ds_train.class_to_idx)

# Load the datasets with DataLoader
dl_train = DataLoader(ds_train, batch_size=50, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=50, shuffle=True)

for features, labels in dl_train:
    print(features.shape)
    print(labels.shape)
    break
3. Create a custom dataset
Next, we build a data pipeline for cifar2 in another way, that is, by inheriting torch.utils.data.Dataset to create a custom dataset.
from pathlib import Path
from PIL import Image


class Cifar2Dataset(Dataset):
    def __init__(self, imgs_dir, img_transform):
        self.files = list(Path(imgs_dir).rglob("*.jpg"))
        self.transform = img_transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        file_i = str(self.files[i])
        img = Image.open(file_i)
        tensor = self.transform(img)
        label = torch.tensor([1.0]) if "1_automobile" in file_i else torch.tensor([0.0])
        return tensor, label


train_dir = "./eat_pytorch_datasets/cifar2/train/"
test_dir = "./eat_pytorch_datasets/cifar2/test/"
# Define the image augmentation
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomVerticalFlip(),    # random vertical flip
    transforms.RandomRotation(45),      # random rotation within 45 degrees
    transforms.ToTensor()               # convert to Tensor
])

transform_val = transforms.Compose([
    transforms.ToTensor()
])
ds_train = Cifar2Dataset(train_dir, transform_train)
ds_val = Cifar2Dataset(test_dir, transform_val)

dl_train = DataLoader(ds_train, batch_size=50, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=50, shuffle=True)

for features, labels in dl_train:
    print(features.shape)
    print(labels.shape)
    break
Third, use DataLoader to load datasets
DataLoader controls the batch size, how the elements in a batch are sampled, and how the batch is collated into the input form required by the model; it can also read data with multiple processes.
The function signature of DataLoader is as follows.
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)
Under normal circumstances, we only configure six parameters: dataset, batch_size, shuffle, num_workers, pin_memory and drop_last.
Sometimes, for complex datasets, you also need to pass a custom collate_fn; the other parameters generally keep their default values.
In addition to the torch.utils.data.Dataset mentioned earlier, DataLoader can also load another kind of dataset, torch.utils.data.IterableDataset.
Unlike Dataset, which is equivalent to a list structure, IterableDataset is equivalent to an iterator structure. It is more complex and generally less used.
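The text only mentions IterableDataset in passing; the following is a minimal sketch (my own toy example, not from the original text). Instead of indexing samples by position, it yields them from its own __iter__ method:

import torch
from torch.utils.data import IterableDataset, DataLoader

class RangeStream(IterableDataset):
    """Toy stream that yields the integers 0..n-1 one by one."""
    def __init__(self, n):
        self.n = n
    def __iter__(self):
        for i in range(self.n):   # in practice this might read from a file or a socket
            yield torch.tensor(i)

dl = DataLoader(RangeStream(10), batch_size=4)  # batching still works; shuffle/sampler do not apply
for batch in dl:
    print(batch)
# tensor([0, 1, 2, 3])  tensor([4, 5, 6, 7])  tensor([8, 9])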
- dataset: the dataset
- batch_size: the batch size
- shuffle: whether to shuffle the data
- sampler: sampling function for individual samples; generally does not need to be set (a sketch of a custom sampler follows after this list).
- batch_sampler: batch sampling function; generally does not need to be set.
- num_workers: number of processes used to read the data; enables multi-process loading.
- collate_fn: a function that collates a batch of samples.
- pin_memory: whether to use pinned (page-locked) memory. The default is False; pinned memory is never swapped out to virtual memory (disk), and copying from pinned memory to the GPU is faster.
- drop_last: whether to drop the last batch when it contains fewer than batch_size samples.
- timeout: the maximum time to wait for loading a batch of data; generally does not need to be set.
- worker_init_fn: an initialization function for the dataset in each worker, commonly used with IterableDataset; generally not used.
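As promised in the sampler bullet above, here is a hedged sketch (my own example, not from the original text) of passing a custom sampler: a WeightedRandomSampler that over-samples the minority class of an imbalanced dataset. Note that sampler and shuffle=True must not be used together.

import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

labels = torch.cat([torch.zeros(90), torch.ones(10)])   # imbalanced labels: 90 zeros, 10 ones
ds = TensorDataset(torch.randn(100, 3), labels)

weights = labels * 8.0 + 1.0                             # weight 9 for the rare class, 1 otherwise
sampler = WeightedRandomSampler(weights, num_samples=len(ds), replacement=True)

dl = DataLoader(ds, batch_size=20, sampler=sampler)      # leave shuffle at its default (False)
for features, batch_labels in dl:
    print(batch_labels.mean())                           # roughly balanced within each batch
    break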
# Build an input data pipeline
ds = TensorDataset(torch.arange(1, 50))
dl = DataLoader(ds,
                batch_size=10,
                shuffle=True,
                num_workers=2,
                drop_last=True)

# Iterate over the data; the trailing comma unpacks the one-element batch tuple
for batch, in dl:
    print(batch)
tensor([43, 44, 21, 36,  9,  5, 28, 16, 20, 14])
tensor([23, 49, 35, 38,  2, 34, 45, 18, 15, 40])
tensor([26,  6, 27, 39,  8,  4, 24, 19, 32, 17])
tensor([ 1, 29, 11, 47, 12, 22, 48, 42, 10,  7])