paper: CCNet: Criss-Cross Attention for Semantic Segmentation
github: https://github.com/speedinghzl/CCNet/tree/pure-python
Context information is very important in the semantic segmentation task. CCNet proposes the Criss-Cross attention module and introduces a recurrent operation so that each pixel in the image can establish a relationship with every other pixel and thereby obtain rich semantic information. In addition, a category consistent loss is proposed to make the Criss-Cross attention module produce more discriminative features (this loss does not appear in the released source code, so this post does not cover it).
1, Network
1. Network structure
The CCNet network structure is shown in the figure below. CNN denotes the feature extractor; Reduction reduces the number of channels of the feature map to cut the cost of the subsequent computation; Criss-Cross attention establishes relationships between pixels at different locations to enrich their semantic information; R denotes the number of recurrences of the Criss-Cross attention module. Note that the repeated Criss-Cross attention modules share parameters.
2. Criss-Cross Attention Module
The structure of the Criss-Cross attention module is shown in the figure below:
Assume the input is X:[N, C, H, W]. To connect a pixel with pixels at other locations, the module first builds connections along the pixel's vertical (column) and horizontal (row) directions. Taking the vertical direction as an example:
① Q_h:[N, Cr, H, W], K_h:[N, Cr, H, W] and V_h:[N, C, H, W] are obtained by 1x1 convolutions, where Cr is the reduced channel number (in_dim//8 in the code); Q_w, K_w, V_w are obtained in the same way;
② Dimension transformation: reshape to get Q_h:[N*W, H, Cr], K_h:[N*W, Cr, H], V_h:[N*W, C, H];
③ Matrix-multiply Q_h and K_h to obtain energy_h:[N*W, H, H]. (In the source code a diagonal -inf matrix of shape [N*W, H, H] is added when computing energy_h but not when computing energy_w; presumably this keeps the pixel itself from being counted twice once the two energies are concatenated for the softmax.)
④ Following the same process, get energy_h:[N*W, H, H] and energy_w:[N*H, W, W]; after reshaping they become energy_h:[N, H, W, H] and energy_w:[N, H, W, W], which are concatenated to obtain energy:[N, H, W, H+W];
⑤ Apply softmax along the last dimension of energy to obtain the attention coefficients;
⑥ Split the attention coefficients into attn_h:[N, H, W, H] and attn_w:[N, H, W, W], and multiply them with the dimension-transformed V_h and V_w to get the outputs out_h and out_w;
⑦ Add out_h and out_w, multiply by a learnable factor γ, and add the residual connection (the input x) to obtain the final output.
The code is as follows:
import torch
import torch.nn as nn
from torch.nn import Softmax


def INF(B, H, W):
    # [B*W, H, H] matrix with -inf on the diagonal, added to the vertical branch's energy
    # (note: hard-codes .cuda(), so a GPU is required)
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        # Q: [N, Cr, H, W] -> vertical branch [N*W, H, Cr], horizontal branch [N*H, W, Cr]
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)
        # K: [N, Cr, H, W] -> [N*W, Cr, H] and [N*H, Cr, W]
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        # V: [N, C, H, W] -> [N*W, C, H] and [N*H, C, W]
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)
        # energy_H: [N, H, W, H] (with -inf on the diagonal), energy_W: [N, H, W, W]
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        # softmax over the concatenated H+W dimension: [N, H, W, H+W]
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
        # split the attention back into the two branches
        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        # aggregate the values; out_H and out_W are both [N, C, H, W]
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        # learnable scale gamma plus residual connection
        return self.gamma * (out_H + out_W) + x
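A quick sanity check of the module above, given as a sketch assuming a GPU is available (the INF helper hard-codes .cuda()); the channel count 512 and the 13x17 feature-map size are just illustrative values:

if torch.cuda.is_available():
    cca = CrissCrossAttention(512).cuda()
    x = torch.randn(2, 512, 13, 17).cuda()   # illustrative [N, C, H, W]
    out = cca(x)
    print(out.shape)                          # torch.Size([2, 512, 13, 17]): same shape as the input
    # The INF helper builds a [N*W, H, H] matrix with -inf on the diagonal, so after the
    # softmax the vertical branch gives zero weight to the pixel itself (which is still
    # attended to once via the horizontal branch).
    print(INF(2, 13, 17)[0, :3, :3])
    # Applying the same module again (R = 2, shared parameters) spreads each pixel's
    # context from its row and column to the whole image.
    out2 = cca(out)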
Returning to the structure diagram above: after one pass through the Criss-Cross attention module, each pixel is connected to all pixels in its own row and column. After a second pass through the module, each pixel is connected to all other pixels, which further enriches the semantic information.
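To show how the pieces fit together, here is a minimal sketch of a CCNet-style head that performs the channel reduction, loops a single (parameter-shared) CrissCrossAttention module R times, and then classifies. This is an illustrative approximation rather than the repository's exact RCCAModule (which, for example, uses synchronized in-place ABN); the channel numbers and num_classes=19 (Cityscapes) are assumptions:

class RCCAHead(nn.Module):
    """Sketch: reduction -> recurrent criss-cross attention (shared weights) -> classifier."""
    def __init__(self, in_channels=2048, inter_channels=512, num_classes=19, recurrence=2):
        super().__init__()
        self.recurrence = recurrence
        # "Reduction": shrink the channel dimension to cut the cost of the attention
        self.reduce = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True))
        # one CrissCrossAttention module (defined above), applied R times -> shared parameters
        self.cca = CrissCrossAttention(inter_channels)
        self.classify = nn.Sequential(
            nn.Conv2d(in_channels + inter_channels, inter_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, num_classes, kernel_size=1))

    def forward(self, x):
        out = self.reduce(x)
        for _ in range(self.recurrence):      # R loops through the same module
            out = self.cca(out)
        # fuse the attended features with the backbone features, then predict per-pixel classes
        return self.classify(torch.cat([x, out], dim=1))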
PS: the authors also extend the module to a 3D Criss-Cross attention module, which is not covered here.
2, Experimental results
1. Training strategy
Optimizer: SGD (momentum 0.9, weight_decay 0.0001)
Learning rate: polynomial ("poly") decay (a sketch follows this list)
Data augmentation: random scaling (0.75-2.0)
Random cropping (769x769 for Cityscapes)
For ADE20K, short-edge resizing is used: a value is chosen from {300, 375, 450, 525, 600} and the short edge of the image is resized to that value.
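A minimal sketch of the poly learning-rate schedule and the short-edge resize mentioned above; the power 0.9 and the example base learning rate are common settings for this schedule, assumed here rather than taken from the text:

def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    """'Poly' decay: the learning rate shrinks from base_lr to 0 over max_iter iterations."""
    return base_lr * (1 - cur_iter / max_iter) ** power

def resize_short_edge(h, w, target_short):
    """Scale an image so that its shorter side equals target_short (ADE20K setting)."""
    scale = target_short / min(h, w)
    return int(round(h * scale)), int(round(w * scale))

# e.g. poly_lr(0.01, 30000, 60000)      -> 0.01 * 0.5**0.9 ≈ 0.0054
# e.g. resize_short_edge(512, 683, 450) -> (450, 600)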
2. RCCA
Comparing different numbers of recurrences R of the CCA module on the Cityscapes validation set, the mIoU gain at R=2 is very noticeable, while going to R=3 adds relatively little. This is because with R=1 each pixel only gathers semantic information from its own row and column, whereas with R=2 it can obtain global semantic information.
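A toy illustration of why two recurrences are enough for global context: replace the attention with a plain row-plus-column sum and track how information from a single pixel spreads. This simplified operator is only a stand-in for illustration, not the actual attention computation:

import torch

def criss_cross_sum(x):
    # simplified stand-in for one CCA pass: every position aggregates its row and its column
    return x.sum(dim=0, keepdim=True) + x.sum(dim=1, keepdim=True)

x = torch.zeros(5, 5)
x[2, 3] = 1.0                          # information exists only at pixel (2, 3)
after_1 = criss_cross_sum(x)           # nonzero only on row 2 and column 3
after_2 = criss_cross_sum(after_1)     # nonzero everywhere: global context with R = 2
print((after_1 != 0).sum().item(), (after_2 != 0).sum().item())   # 9 25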
3. Performance
The performance of CCNet on the Cityscapes validation set and test set is as follows: