[PyTorch] CNN practice - flower species identification

Dataset

A public dataset from Kaggle is used; the download link is:
https://www.kaggle.com/datasets/alxmamaev/flowers-recognition
It contains photos of flowers in five categories, more than 4,000 images in total.

Data processing

The whole dataset is not large, so it can be read into memory (or GPU memory) up front instead of being loaded from disk every time it is needed, which effectively improves running speed.
Because the number of images is small, image augmentation is also needed.

Reading the dataset

The data on Kaggle is already organized into one folder per class, so the folder name determines each image's label when reading.

import pathlib
import cv2 as cv
from torch.utils.data import Dataset

class Flower_Dataset(Dataset):
    def __init__(self, path, is_train, augs):
        data_root = pathlib.Path(path)
        all_image_paths = list(data_root.glob('*/*'))
        self.all_image_paths = [str(p) for p in all_image_paths]
        # Folder names serve as class labels
        label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
        label_to_index = {label: index for index, label in enumerate(label_names)}
        # The dataset is small enough to load every image into memory up front
        self.all_image = [cv.imread(p) for p in self.all_image_paths]
        self.all_image_labels = [label_to_index[p.parent.name] for p in all_image_paths]
        self.is_train = is_train
        self.augs = augs
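
The snippet above only defines __init__, but a PyTorch Dataset also needs __len__ and __getitem__, which the post leaves to the complete code. A minimal sketch of how they might look (my assumption, with torch and torchvision imported; images are resized to the 224x224 input the model summary below expects):

    # Hypothetical completion: these two methods go inside Flower_Dataset.
    def __getitem__(self, index):
        # Pretrained ImageNet weights expect RGB, but cv.imread returns BGR
        img = cv.cvtColor(self.all_image[index], cv.COLOR_BGR2RGB)
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0  # HWC uint8 -> CHW float
        img = torchvision.transforms.functional.resize(img, [224, 224])
        if self.is_train:
            img = self.augs(img)  # augment the training set only
        return img, self.all_image_labels[index]

    def __len__(self):
        return len(self.all_image_labels)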

Image augmentation

A photo of a flower is still a flower after a horizontal flip, so this augmentation method can be used.
Similarly, adjustments to brightness, contrast, and the like can be applied.

# Randomly jitter brightness, contrast, saturation and hue, plus random horizontal flips
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug])
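
Augmentation should only be applied at training time (the results section below relies on this). A quick sketch of building the two datasets, assuming the images have already been split into hypothetical flowers/train and flowers/test directories:

train_set = Flower_Dataset('flowers/train', is_train=True, augs=augs)  # augmented
test_set = Flower_Dataset('flowers/test', is_train=False, augs=None)   # no augmentation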

Iterators

Each step draws one batch from the dataset.
For training, the samples are generally shuffled.

train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
test_iter = DataLoader(test_set, batch_size=batch_size, num_workers=4)
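
A quick shape check on one batch (assuming the Dataset returns 3x224x224 float tensors as sketched earlier):

X, Y = next(iter(train_iter))
print(X.shape, Y.shape)  # expected: torch.Size([batch_size, 3, 224, 224]) and torch.Size([batch_size])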

CNN model

The classic ResNet architecture is used. Because the dataset is small, an overly complex network is not appropriate, so ResNet-18 is chosen here: the summary below lists 68 modules, but the network has only 18 weight layers, so it is not too deep. The specific structure is as follows:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64, 56, 56]               0
           Conv2d-15           [-1, 64, 56, 56]          36,864
      BatchNorm2d-16           [-1, 64, 56, 56]             128
             ReLU-17           [-1, 64, 56, 56]               0
       BasicBlock-18           [-1, 64, 56, 56]               0
           Conv2d-19          [-1, 128, 28, 28]          73,728
      BatchNorm2d-20          [-1, 128, 28, 28]             256
             ReLU-21          [-1, 128, 28, 28]               0
           Conv2d-22          [-1, 128, 28, 28]         147,456
      BatchNorm2d-23          [-1, 128, 28, 28]             256
           Conv2d-24          [-1, 128, 28, 28]           8,192
      BatchNorm2d-25          [-1, 128, 28, 28]             256
             ReLU-26          [-1, 128, 28, 28]               0
       BasicBlock-27          [-1, 128, 28, 28]               0
           Conv2d-28          [-1, 128, 28, 28]         147,456
      BatchNorm2d-29          [-1, 128, 28, 28]             256
             ReLU-30          [-1, 128, 28, 28]               0
           Conv2d-31          [-1, 128, 28, 28]         147,456
      BatchNorm2d-32          [-1, 128, 28, 28]             256
             ReLU-33          [-1, 128, 28, 28]               0
       BasicBlock-34          [-1, 128, 28, 28]               0
           Conv2d-35          [-1, 256, 14, 14]         294,912
      BatchNorm2d-36          [-1, 256, 14, 14]             512
             ReLU-37          [-1, 256, 14, 14]               0
           Conv2d-38          [-1, 256, 14, 14]         589,824
      BatchNorm2d-39          [-1, 256, 14, 14]             512
           Conv2d-40          [-1, 256, 14, 14]          32,768
      BatchNorm2d-41          [-1, 256, 14, 14]             512
             ReLU-42          [-1, 256, 14, 14]               0
       BasicBlock-43          [-1, 256, 14, 14]               0
           Conv2d-44          [-1, 256, 14, 14]         589,824
      BatchNorm2d-45          [-1, 256, 14, 14]             512
             ReLU-46          [-1, 256, 14, 14]               0
           Conv2d-47          [-1, 256, 14, 14]         589,824
      BatchNorm2d-48          [-1, 256, 14, 14]             512
             ReLU-49          [-1, 256, 14, 14]               0
       BasicBlock-50          [-1, 256, 14, 14]               0
           Conv2d-51            [-1, 512, 7, 7]       1,179,648
      BatchNorm2d-52            [-1, 512, 7, 7]           1,024
             ReLU-53            [-1, 512, 7, 7]               0
           Conv2d-54            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-55            [-1, 512, 7, 7]           1,024
           Conv2d-56            [-1, 512, 7, 7]         131,072
      BatchNorm2d-57            [-1, 512, 7, 7]           1,024
             ReLU-58            [-1, 512, 7, 7]               0
       BasicBlock-59            [-1, 512, 7, 7]               0
           Conv2d-60            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-61            [-1, 512, 7, 7]           1,024
             ReLU-62            [-1, 512, 7, 7]               0
           Conv2d-63            [-1, 512, 7, 7]       2,359,296
      BatchNorm2d-64            [-1, 512, 7, 7]           1,024
             ReLU-65            [-1, 512, 7, 7]               0
       BasicBlock-66            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                    [-1, 5]           2,565
================================================================
Total params: 11,179,077
Trainable params: 11,179,077
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 42.64
Estimated Total Size (MB): 106.00
----------------------------------------------------------------

Fine-tuning

Since the images in this dataset are similar to those in ImageNet, a pretrained model can be fine-tuned.
The only structural change needed is to replace the last layer so that it outputs 5 classes instead of the original 1000.
In addition, different parts of the network are given different learning rates.

import torch
import torchvision
from torch import nn
from torchsummary import summary

net = torchvision.models.resnet18(pretrained=True)

net.fc = nn.Linear(net.fc.in_features, 5)  # replace the 1000-class head with 5 outputs
nn.init.xavier_uniform_(net.fc.weight)
summary(net, input_size=(3, 224, 224), device="cpu")

lr = 0.0005
loss = nn.CrossEntropyLoss(reduction="mean")

# Pretrained layers train at the base learning rate; the newly added
# output layer trains from scratch, so it gets a much larger one.
params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.SGD([{'params': params_1x},
                           {'params': net.fc.parameters(), 'lr': lr * 80}],
                          lr=lr, weight_decay=0.001)
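
To confirm that the two parameter groups received the intended learning rates:

for group in trainer.param_groups:
    print(len(group['params']), group['lr'])  # pretrained layers at lr, fc at lr * 80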

Training

This part is much the same as for other neural networks, so I won't go over it in detail.

from tqdm import tqdm

# Training
Accuracies = []
Losses = []
T_Accuracies = []
T_Losses = []
for epoch in range(epochs):
    net.train()
    loop = tqdm(enumerate(train_iter), total=len(train_iter))  # progress bar
    loop.set_description(f'Epoch [{epoch + 1}/{epochs}]')
    T_Accuracies.append(0)
    T_Losses.append(0)
    for index, (X, Y) in loop:
        scores = net(X)
        l = loss(scores, Y)
        trainer.zero_grad()
        l.backward()

        _, predictions = scores.max(1)
        num_correct = (predictions == Y).sum()
        running_train_acc = float(num_correct) / float(X.shape[0])
        # Smooth the displayed metrics with an exponential moving average
        if index == 0:
            T_Accuracies[-1] = running_train_acc
            T_Losses[-1] = l.item()
        else:
            T_Accuracies[-1] = T_Accuracies[-1] * 0.9 + 0.1 * running_train_acc
            T_Losses[-1] = T_Losses[-1] * 0.9 + 0.1 * l.item()

        loop.set_postfix(loss='{:.3f}'.format(T_Losses[-1]),
                         accuracy='{:.3f}'.format(T_Accuracies[-1]))
        trainer.step()
    a, b = testing()  # evaluate on the test set after each epoch
    Accuracies.append(a)
    Losses.append(b)
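
The loop calls testing(), which is only defined in the complete code linked below. A minimal sketch consistent with how its return values are used here (test accuracy first, then test loss):

def testing():
    net.eval()
    correct, total, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for X, Y in test_iter:
            scores = net(X)
            total_loss += loss(scores, Y).item() * X.shape[0]
            correct += (scores.argmax(1) == Y).sum().item()
            total += X.shape[0]
    return correct / total, total_loss / total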

Results

Accuracy and loss curves were plotted from the training-set and test-set records.

Both the training and test accuracies are high, showing that fine-tuning works well here.
Test accuracy passed 90% by the fifth epoch, a strong result for such a short training run.
Training accuracy ends up below test accuracy because image augmentation is applied to the training set but not to the test set.
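
The plotting code is not shown in the post; a simple matplotlib sketch using the lists collected during training could be:

import matplotlib.pyplot as plt

plt.plot(T_Accuracies, label='train accuracy (EMA)')
plt.plot(Accuracies, label='test accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()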

Complete code

Download link

