Second week of CV transformer

Keywords: Python Machine Learning Computer Vision Deep Learning CV

6.2.1 Dataset introduction

The data for this OCR experiment is based on Task 4.3: Word Recognition of the ICDAR 2015 Incidental Scene Text challenge. This is a word recognition task; some images were removed to reduce the difficulty of the experiment.
The dataset consists of text regions taken from natural scene images. In the original data, the training set contains 4468 images and the test set contains 2077 images. They were cropped from the original large images according to the bounding boxes of the text regions, so the text sits roughly in the center of each image.
The dataset provides image folders and the corresponding label files, split into train and test. Each label entry has the form word_1.png, "join", meaning that the image word_1.png shows the English word "join".
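As a small illustration of the label format described above, the sketch below parses one label line. The helper name parse_label_line is hypothetical and not part of the tutorial code. Unlike a plain split(','), it splits only at the first comma, so a label that itself contains a comma is kept whole.

```python
def parse_label_line(line):
    """Split one label line into (image_name, label_text)."""
    # Split only at the first comma: the label itself may contain commas.
    name, raw_label = line.rstrip("\n").split(",", 1)
    label = raw_label.strip()
    # Remove the surrounding double quotes if both are present.
    if len(label) >= 2 and label[0] == '"' and label[-1] == '"':
        label = label[1:-1]
    return name.strip(), label


example = parse_label_line('word_1.png, "join"\n')
print(example)
```

The function returns the image name and the unquoted label as a tuple, which is the pair every later analysis step works with.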

To reduce the recognition difficulty of the later experiments, the provided dataset has been roughly filtered: images whose characters are arranged vertically (aspect ratio > 1.5) were removed, so it differs slightly from the original ICDAR 2015 dataset.

6.2.2 Data analysis and construction of the character mapping

Before the experiment, we briefly analyze the data to understand its characteristics.
Specifically, we collect label character statistics (which characters occur and how often each one appears), find the length of the longest label, analyze the image sizes, and build the character-to-id mapping file lbl2id_map.txt.

Preparation

import os
import cv2

#Data set root directory, please download the data to this location
base_data_dir='./ICDAR_2015'

#Training data set and validation data
train_img_dir=os.path.join(base_data_dir,'train')
valid_img_dir=os.path.join(base_data_dir,'valid')
#Training set and validation set label file paths
train_lbl_path=os.path.join(base_data_dir,'train_gt.txt')
valid_lbl_path=os.path.join(base_data_dir,'valid_gt.txt')
#The intermediate file storage path stores the mapping relationship between label characters and their IDs
lbl2id_map_path=os.path.join(base_data_dir,'lbl2id_map.txt')

1. Length of the longest label

First, count the number of characters in the longest label of the dataset. Both the training set and the validation set are scanned, and the larger of the two results is kept.

def statistics_max_len_label(lbl_path):
    '''
    Return the number of characters in the longest label of a label file
    lbl_path: path to the txt label file
    '''
    max_len=-1
    with open(lbl_path,'r') as reader:
        for line in reader:
            items=line.rstrip().split(',')
            # img_name=items[0] #Image name (not needed here)
            lbl_str=items[1].strip()[1:-1] #Extract the label and strip the surrounding double quotes
            lbl_len=len(lbl_str)
            max_len=max(max_len,lbl_len)
    return max_len

train_max_label_len=statistics_max_len_label(train_lbl_path) #Longest label in the training set
valid_max_label_len=statistics_max_len_label(valid_lbl_path) #Longest label in the validation set
max_label_len=max(train_max_label_len,valid_max_label_len) #Longest label in the whole dataset
print(f"The longest label in the dataset contains {max_label_len} characters")

The longest label in the dataset contains 21 characters. This value serves as a reference for the number of time steps when the transformer model is built later.

2. Characters contained in the labels

List all characters that occur in the dataset and count how often each one appears.

def statistics_label_cnt(lbl_path,lbl_cnt_map):
    '''
    Count which characters occur in the labels of a label file and how often
    lbl_path: path to the label file
    lbl_cnt_map: dictionary that accumulates the number of occurrences of each character
    '''
    with open(lbl_path,'r') as reader:
        for line in reader:
            items=line.rstrip().split(',')
            #img_name=items[0]
            lbl_str=items[1].strip()[1:-1] #Extract the label and strip the surrounding double quotes
            for lbl in lbl_str:
                if lbl not in lbl_cnt_map:
                    lbl_cnt_map[lbl]=1
                else:
                    lbl_cnt_map[lbl]+=1

lbl_cnt_map=dict() #Dictionary that stores the number of occurrences of each character
statistics_label_cnt(train_lbl_path,lbl_cnt_map) #Count characters in the training set
print("Characters appearing in the training set labels:")
print(lbl_cnt_map)
statistics_label_cnt(valid_lbl_path,lbl_cnt_map) #Add the characters of the validation set
print("Characters appearing in the training set + validation set labels:")
print(lbl_cnt_map)

lbl_cnt_map is a dictionary holding the number of occurrences of every character. It is also used below to build the mapping between characters and their ids.
The statistics show that the test (validation) set contains a few characters that never appear in the training set. Such cases are rare and should not cause problems, so no extra processing is applied here. It is still good practice to check deliberately whether the training set and the test set differ.
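The check mentioned above can be sketched as follows. The helper names count_chars and unseen_chars and the toy label lists are made up for illustration; in the tutorial the counts would come from the two label files.

```python
from collections import Counter


def count_chars(labels):
    """Count how often each character occurs in a list of label strings."""
    counts = Counter()
    for label in labels:
        counts.update(label)
    return counts


def unseen_chars(train_counts, valid_counts):
    """Characters (with counts) that occur in validation labels but never in training labels."""
    return {ch: n for ch, n in valid_counts.items() if ch not in train_counts}


# Toy labels, invented for this example only.
train_labels = ["join", "exit"]
valid_labels = ["joint!", "taxi"]

missing = unseen_chars(count_chars(train_labels), count_chars(valid_labels))
print(missing)
```

If the resulting dictionary is small compared with the total number of characters, leaving the data untouched is a defensible choice; a large result would suggest that the split should be revisited.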

3. Building the mapping dictionary between characters and ids

This OCR task has to predict every character in the image. To do so, we first establish a mapping between each character and an id, which converts the text into a numeric form that the model can read. This step is similar to building a vocabulary in NLP.
Besides the characters found in the label files, three special tokens are added to the mapping: a start-of-sequence token, an end-of-sequence token and a padding token.
The mapping between characters and ids is built first and finally saved to the file lbl2id_map.txt.

#Build the character-to-id mapping for the labels
print("Building the character-to-id mapping for the labels:")
lbl2id_map=dict()
#Initialize the three special tokens
lbl2id_map['☯']=0 #padding token
lbl2id_map['■']=1 #start-of-sequence token
lbl2id_map['□']=2 #end-of-sequence token
#Assign ids to the remaining characters
cur_id=3
for lbl in lbl_cnt_map.keys():
    lbl2id_map[lbl]=cur_id
    cur_id+=1
#Save the character-to-id mapping to a txt file
with open(lbl2id_map_path,'w',encoding='utf_8_sig') as writer: #encoding is optional, but not every platform defaults to UTF-8
    for lbl in lbl2id_map.keys():
        cur_id=lbl2id_map[lbl]
        print(lbl,cur_id)
        line=lbl+'\t'+str(cur_id)+'\n'
        writer.write(line)

The txt file with the mapping can then be read back to build both a character-to-id and an id-to-character dictionary. They are used during transformer training to convert quickly between characters and ids.

def load_lbl2id_map(lbl2id_map_path):
    '''
    Read the txt file that records the character-to-id mapping and return the lbl->id and id->lbl dictionaries
    lbl2id_map_path: path to the txt file that records the character-to-id mapping
    '''
    lbl2id_map=dict()
    id2lbl_map=dict()
    #Read with the encoding used for writing, so that the byte order mark is stripped
    with open(lbl2id_map_path,'r',encoding='utf_8_sig') as reader:
        for line in reader:
            items=line.rstrip().split('\t')
            label=items[0]
            cur_id=int(items[1])
            lbl2id_map[label]=cur_id
            id2lbl_map[cur_id]=label
    return lbl2id_map,id2lbl_map
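Why the read encoding matters can be shown with a small round trip. The sketch below is independent of the tutorial code: save_map and load_map are hypothetical helpers, and the file is written to a temporary directory. A file written with utf_8_sig starts with a byte order mark, and reading it as plain utf_8 leaves that mark attached to the first key.

```python
import os
import tempfile


def save_map(path, lbl2id):
    """Write one 'character<TAB>id' line per entry, with a UTF-8 byte order mark."""
    with open(path, "w", encoding="utf_8_sig") as writer:
        for ch, idx in lbl2id.items():
            writer.write(f"{ch}\t{idx}\n")


def load_map(path, encoding):
    """Read the mapping file back with the given text encoding."""
    lbl2id = {}
    with open(path, "r", encoding=encoding) as reader:
        for line in reader:
            ch, idx = line.rstrip("\n").split("\t")
            lbl2id[ch] = int(idx)
    return lbl2id


original = {"☯": 0, "■": 1, "□": 2, "a": 3}
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "lbl2id_map.txt")
    save_map(path, original)
    with_sig = load_map(path, "utf_8_sig")   # byte order mark is stripped
    plain_utf8 = load_map(path, "utf_8")     # byte order mark stays in the first key

print(with_sig == original, list(plain_utf8)[0] == "\ufeff☯")
```

Using the same encoding for writing and reading avoids the mismatch and also keeps the script independent of the platform default encoding.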

4. Image size analysis

For tasks such as image classification and detection, we usually inspect the size distribution of the images before choosing a preprocessing method. In object detection, for example, we gather statistics on the image sizes and the bounding box sizes, analyze the aspect ratios, and then choose a suitable cropping strategy and suitable initial anchors.

Here we analyze the image width, height and aspect ratio to understand the characteristics of the data.

#Analyze dataset picture size
print("Analysis dataset picture size:")

#Initialization parameters
min_h=1e10
min_w=1e10
max_h=-1
max_w=-1
min_ratio=1e10
max_ratio=0
#Traverse the dataset to calculate size information
for img_name in os.listdir(train_img_dir):
    img_path=os.path.join(train_img_dir,img_name)
    img=cv2.imread(img_path) #Read picture
    h,w=img.shape[:2] #Extract image height and width information
    ratio=w/h #Aspect ratio
    min_h=min_h if min_h <=h else h #Minimum picture height
    max_h=max_h if max_h >=h else h #Maximum picture height
    min_w=min_w if min_w<=w else w #Minimum picture width
    max_w=max_w if max_w >=w else w #Maximum picture width
    min_ratio=min_ratio if min_ratio <= ratio else ratio #Minimum aspect ratio
    max_ratio=max_ratio if max_ratio >= ratio else ratio #Maximum aspect ratio
#Output information
print('min_h',min_h)
print('max_h',max_h)
print('min_w',min_w)
print('max_w',max_w)
print('min_ratio',min_ratio)
print('max_ratio',max_ratio)

Most of the images are wide horizontal strips, and the maximum aspect ratio exceeds 8, so some images are extremely elongated.

6.2.3 How to introduce the transformer into OCR

Why can the transformer solve OCR problems?

  • The transformer is widely used in NLP and solves sequence-to-sequence problems such as machine translation.
  • An OCR recognition task, which aims to read the text in an image, is in essence also a sequence-to-sequence task; the input sequence is simply represented as an image.
  • Once OCR is treated as a sequence-to-sequence prediction problem, using a transformer is a natural idea. The remaining question is how to turn the image information into an input similar to word embeddings, so that it meets the input requirements of the transformer.
  • Since the images to be predicted are long strips and the characters are arranged roughly horizontally, we organize the feature map into a sequence along the horizontal direction; each embedding obtained can be regarded as the feature of one vertical slice of the image. This feature sequence is passed to the transformer, which uses its attention mechanism to complete the prediction.
    The model framework pipeline is shown in the figure:

    As the figure shows, the whole pipeline is basically the same as training a machine translation model with a transformer. The main difference is that a CNN backbone extracts the image features to produce the input embeddings.

6.2.4 Detailed explanation of the training framework code

The training framework is implemented in the file ocr_by_transformer.py.
The code is explained step by step below and consists of the following parts:

  • Dataset construction -> image preprocessing, label processing, etc.
  • Model construction -> backbone + transformer
  • Model training
  • Inference -> greedy decoding

1. Preparation

First, import the libraries used later.

import os
import time
import copy
import numpy as np
from PIL import Image

#torch related package
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as transforms

#Import tool class package
from analysis_recognition_dataset import load_lbl2id_map,statistics_max_len_label
from transformer import *
from train_utils import *

Then set some basic parameters

base_data_dir='../../../dataset/ICDAR_2015' #Data set root directory, please download the data to this location
device=torch.device("cuda") #'cpu' or 'cuda'
nrof_epochs=1500 #Number of training epochs (1500 here), adjust as needed
batch_size=64  #Batch size (64 here), adjust as needed
model_save_path='./log/ex1_ocr_model.pth'

Read the mapping dictionary between the characters in the image labels and their ids; it is needed to create the Dataset later.

#Read the character-to-id mapping file
lbl2id_map_path=os.path.join(base_data_dir,'lbl2id_map.txt')
lbl2id_map,id2lbl_map=load_lbl2id_map(lbl2id_map_path)

#Find the number of characters in the longest label of the dataset; it is needed to build the gt (ground truth) information
train_lbl_path=os.path.join(base_data_dir,'train_gt.txt')
valid_lbl_path=os.path.join(base_data_dir,'valid_gt.txt')
train_max_label_len=statistics_max_len_label(train_lbl_path)
valid_max_label_len=statistics_max_len_label(valid_lbl_path)
#The length of the longest label in the dataset is used as the sequence_len of gt
sequence_len=max(train_max_label_len,valid_max_label_len)

2. Dataset construction

This part covers the construction of the Dataset. First, let us think about how the images should be preprocessed.

Image preprocessing scheme

Suppose the input image has size [batch_size, 3, $H_i$, $W_i$].
The feature map produced by the network then has size [batch_size, $C_f$, $H_f$, $W_f$].
According to the earlier analysis of the dataset, the images are mostly long horizontal strips whose content is a word made of horizontally arranged characters. A vertical slice of the image therefore contains essentially one character, so the vertical resolution does not need to be large and we take $H_f = 1$. The horizontal resolution has to be larger, because different embeddings are needed to encode the features of the different characters along the horizontal direction.

Here we use the classic resnet18 network as the backbone. Its downsampling factor is 32 and its last feature map has 512 channels, so:
$H_i = H_f \times 32 = 32$
$C_f = 512$
How should the width of the input image be determined? There are two options, as shown in the figure below:

Method 1: set a fixed size, resize the image while keeping its aspect ratio, and pad the empty area on the right;
Method 2: resize the original image directly to the preset fixed size, ignoring its aspect ratio.
Note: which scheme is better?
The author chose method 1, because the aspect ratio of an image is roughly proportional to the number of characters in the word it shows. If the aspect ratio is kept during preprocessing, the region of the original image covered by each feature map pixel stays roughly constant relative to the characters, which may give better predictions.
There is one more detail. In the figure above, each region with width : height = 1 : 1 contains roughly 2-3 characters. In practice we therefore do not keep the aspect ratio strictly unchanged but increase it by a factor of 3: the width of the original image is first stretched to 3 times its value, and then the image is resized to a height of 32 while keeping this new aspect ratio.
Note: it is worth pausing here to think about why this detail is needed.
The purpose is to give every character in the image at least one corresponding pixel on the feature map, so that no single pixel along the width of the feature map has to encode several characters of the original image. In the author's view, encoding several characters in one pixel would make prediction unnecessarily hard for the transformer.
With the resize scheme fixed, what should $W_i$ be? Recall two indicators from the dataset analysis: the longest label has 21 characters and the largest aspect ratio is 8.6. We set the final aspect ratio to 24:1. The parameter settings are summarized as follows:
$H_i = H_f \times 32 = 32$
$W_i = 24 \times H_i = 768$
$C_f = 512,\ H_f = 1,\ W_f = 24$
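The arithmetic behind these settings can be checked with a few lines of Python. The function name input_and_feature_shape and its parameter names are hypothetical; the numbers are the ones used in this section (maximum aspect ratio 8, width stretched 3 times, downsampling factor 32, 512 channels).

```python
def input_and_feature_shape(max_ratio=8, stretch=3, downsample=32, feat_h=1, channels=512):
    """Derive the input image shape and the backbone feature map shape."""
    width_units = max_ratio * stretch      # final width : height ratio, 8 * 3 = 24
    in_h = feat_h * downsample             # H_i = H_f * 32
    in_w = width_units * in_h              # W_i = 24 * H_i
    feat_w = in_w // downsample            # W_f = W_i / 32
    return (3, in_h, in_w), (channels, feat_h, feat_w)


image_shape, feature_shape = input_and_feature_shape()
print(image_shape, feature_shape)
```

With 24 positions along the width and at most 21 characters per label, every character can be covered by at least one feature map column.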
Image preprocessing

#Image preprocessing
#load img
img=Image.open(img_path).convert('RGB')
#Resize the image while roughly keeping its aspect ratio:
#the height becomes 32 and the width is scaled accordingly, rounded to a multiple of 32
w,h=img.size
ratio=round((w/h)*3) #Stretch the width by a factor of 3, then round
if ratio==0:
    ratio=1
if ratio>self.max_ratio:
    ratio=self.max_ratio
h_new=32
w_new=h_new*ratio
img_resize=img.resize((w_new,h_new),Image.BILINEAR)

#Pad the right side of the image so that the width/height ratio is fixed at self.max_ratio
img_padd=Image.new('RGB',(32*self.max_ratio,32),(0,0,0))
img_padd.paste(img_resize,(0,0))
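The size computation in the snippet above can be tried out without loading any image. This is a minimal sketch: resized_width is a hypothetical helper, and its max_ratio argument corresponds to self.max_ratio, that is, the value after the multiplication by 3 (24 in this tutorial).

```python
def resized_width(w, h, max_ratio=24, stretch=3, target_h=32):
    """Return (ratio, new_width) for an image of size w x h."""
    ratio = round((w / h) * stretch)        # stretch the width, then round
    ratio = max(1, min(ratio, max_ratio))   # clamp to the range [1, max_ratio]
    return ratio, target_h * ratio


for size in [(100, 32), (10, 100), (1000, 32)]:
    print(size, resized_width(*size))
```

A typical word image keeps its stretched proportions, a very narrow one is widened to one 32-pixel unit, and an extremely wide one is compressed to the full 768 pixels.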

Image augmentation

Image augmentation is not the focus here. Besides the resize scheme above, we only apply a conventional random color transformation and normalization to the image.
Complete code for Dataset construction

class Recognition_Dataset(object):
    def __init__(self,dataset_root_dir,lbl2id_map,sequence_len,max_ratio,phase='train',pad=0):
        if phase=='train':
            self.img_dir=os.path.join(dataset_root_dir,'train')
            self.lbl_path=os.path.join(dataset_root_dir,'train_gt.txt')
        else:
            self.img_dir=os.path.join(dataset_root_dir,'valid')
            self.lbl_path=os.path.join(dataset_root_dir,'valid_gt.txt')
        self.lbl2id_map=lbl2id_map
        self.pad=pad #id of the padding token, 0 by default
        self.sequence_len=sequence_len #Sequence length
        self.max_ratio=max_ratio*3 #Stretch the width by a factor of 3

        #Collect the image names and their labels
        self.imgs_list=[]
        self.lbls_list=[]
        with open(self.lbl_path,'r') as reader:
            for line in reader:
                items=line.rstrip().split(',')
                img_name=items[0]
                lbl_str=items[1].strip()[1:-1]
                self.imgs_list.append(img_name)
                self.lbls_list.append(lbl_str)

        #Define the random color transformation
        self.color_trans=transforms.ColorJitter(0.1,0.1,0.1)
        #Define Normalize
        self.trans_Normalize=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])

    def __getitem__(self,index):
        '''
        Return the image and ground truth label for the given index, with data augmentation as appropriate
        '''
        img_name=self.imgs_list[index]
        img_path=os.path.join(self.img_dir,img_name)
        lbl_str=self.lbls_list[index]

        #------------------------
        #Image preprocessing
        #------------------------
        #load image
        img=Image.open(img_path).convert('RGB')

        #Resize the image while roughly keeping its aspect ratio:
        #the height becomes 32 and the width is scaled accordingly, rounded to a multiple of 32
        w,h=img.size
        ratio=round((w/h)*3) #Stretch the width by a factor of 3, then round
        if ratio==0:
            ratio=1
        if ratio>self.max_ratio:
            ratio=self.max_ratio
        h_new=32
        w_new=h_new*ratio
        img_resize=img.resize((w_new,h_new),Image.BILINEAR)

        #Pad the right side of the image so that the width/height ratio is fixed at self.max_ratio
        img_padd=Image.new("RGB",(32*self.max_ratio,32),(0,0,0))
        img_padd.paste(img_resize,(0,0))

        #Random color transformation
        img_input=self.color_trans(img_padd)
        #Normalize
        img_input=self.trans_Normalize(img_input)

        #---------------------
        #label processing
        #---------------------

        #Construct the mask of the encoder
        encode_mask=[1]*ratio+[0]*(self.max_ratio-ratio)
        encode_mask=torch.tensor(encode_mask)
        encode_mask=(encode_mask!=0).unsqueeze(0)
        #Construct the ground truth label
        gt=[]
        gt.append(1) #Add the start token first
        for lbl in lbl_str:
            gt.append(self.lbl2id_map[lbl])
        gt.append(2) #Add the end token
        #Apart from the start and end tokens the label takes sequence_len positions; pad the rest
        for i in range(len(lbl_str),self.sequence_len):
            gt.append(0)
        #Truncate to the preset maximum sequence length
        gt=gt[:self.sequence_len+2]

        #Input of the decoder
        decode_in=gt[:-1]
        decode_in=torch.tensor(decode_in)
        #Output of the decoder
        decode_out=gt[1:]
        decode_out=torch.tensor(decode_out)
        #Mask of the decoder
        decode_mask=self.make_std_mask(decode_in,self.pad)
        #Number of valid tokens
        ntokens=(decode_out!=self.pad).data.sum()
        return img_input,encode_mask,decode_in,decode_out,decode_mask,ntokens

    @staticmethod
    def make_std_mask(tgt,pad):
        """
        Create a mask to hide padding and future words.
        Padding and future words are both marked with 0 in the mask
        """
        tgt_mask=(tgt!=pad)
        tgt_mask=tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        tgt_mask=tgt_mask.squeeze(0) #subsequent_mask returns shape (1,N,N), so drop the leading dimension
        return tgt_mask

    def __len__(self):
        return len(self.imgs_list)

encode_mask

Because the image is resized and then padded as needed, and the padded region carries no useful information, we build the corresponding encode_mask from the padding proportion so that the transformer ignores this meaningless region during computation.
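The construction of encode_mask can be sketched in plain Python. The helper build_encode_mask is hypothetical; ratio is the number of 32-pixel units occupied by the resized image and max_ratio is the total number of units after padding (24 in this tutorial).

```python
def build_encode_mask(ratio, max_ratio):
    """True marks feature columns that come from real image content, False marks padding."""
    ratio = max(1, min(ratio, max_ratio))   # same clamping as in the preprocessing step
    return [True] * ratio + [False] * (max_ratio - ratio)


# An image that occupies 9 of the 24 width units.
mask = build_encode_mask(9, 24)
print(sum(mask), len(mask))
```

Each entry lines up with one column of the 1 × 24 feature map, because one 32-pixel unit of image width is reduced to one feature column by the backbone.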

label processing

The prediction labels in this experiment are essentially the same as those used to train a machine translation model, so they are processed in much the same way.
During label processing, each character of the label is converted to its id, the start token is added at the beginning and the end token at the end. If the label is shorter than sequence_len, the remaining positions are padded with 0.
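A minimal sketch of this label processing, without torch, is given below. encode_label and the toy mapping are made up for illustration; the ids 0, 1 and 2 for padding, start and end follow the mapping built earlier.

```python
PAD_ID, START_ID, END_ID = 0, 1, 2


def encode_label(label, lbl2id, sequence_len):
    """Turn a label string into the decoder input and the decoder target."""
    gt = [START_ID] + [lbl2id[ch] for ch in label] + [END_ID]
    gt += [PAD_ID] * (sequence_len - len(label))   # pad short labels with 0
    gt = gt[:sequence_len + 2]                     # truncate to the maximum length
    decode_in = gt[:-1]    # what the decoder sees
    decode_out = gt[1:]    # what the decoder should predict, shifted by one position
    return decode_in, decode_out


toy_map = {"j": 3, "o": 4, "i": 5, "n": 6}  # hypothetical ids for this example
decode_in, decode_out = encode_label("join", toy_map, sequence_len=6)
print(decode_in)
print(decode_out)
```

Both sequences have sequence_len + 1 entries, and decode_out is decode_in shifted left by one position, which is the usual teacher forcing arrangement.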

decode_mask

In the decoder, a mask derived from an upper triangular matrix is usually generated according to the sequence length of the label. Each row of the mask controls one time step: the decoder may only access the characters up to the current step and is prevented from accessing characters at future steps, which prevents cheating during training.
The decode_mask is generated by the dedicated function make_std_mask().
The padded part of the decoder label must be masked as well, so decode_mask is also set to False at the positions where the label is padding.
The generated decode_mask is shown in the following figure:
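The same mask logic can be reproduced with nested lists. This sketch is not the tutorial's make_std_mask but a plain Python equivalent under the assumption that subsequent_mask allows position i to attend to positions j <= i; make_decode_mask is a hypothetical name.

```python
def make_decode_mask(decode_in, pad=0):
    """mask[i][j] is True if step i may attend to position j."""
    n = len(decode_in)
    return [
        [(decode_in[j] != pad) and (j <= i) for j in range(n)]
        for i in range(n)
    ]


# Start token, two characters, end token, one padding position.
mask = make_decode_mask([1, 3, 4, 2, 0])
for row in mask:
    print("".join("1" if allowed else "0" for allowed in row))
```

Each printed row corresponds to one time step: the ones form a lower triangle, and the last column stays zero because it belongs to padding.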

Build a DataLoader for training

#Construct the dataloaders
max_ratio=8 #Maximum width/height ratio kept during preprocessing; images within it keep their aspect ratio, wider ones are compressed to it
train_dataset=Recognition_Dataset(base_data_dir,lbl2id_map,sequence_len,max_ratio,'train',pad=0)
valid_dataset=Recognition_Dataset(base_data_dir,lbl2id_map,sequence_len,max_ratio,'valid',pad=0)
#loader size info (with sequence_len=21):
#--> img_input:[batch_size,c,h,w]-->[64,3,32,32*8*3]
#--> encode_mask:[batch_size,h/32,w/32]-->[64,1,24] #The backbone downsamples by a factor of 32, hence the division by 32
#--> decode_in:[bs,sequence_len+1]-->[64,22]
#--> decode_out:[bs,sequence_len+1]-->[64,22]
#--> decode_mask:[bs,sequence_len+1,sequence_len+1]-->[64,22,22]
#--> ntokens:[bs]-->[64]
train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=4)
valid_loader=torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size,shuffle=False,num_workers=4)

3. Model construction

The code builds the model structure through the make_ocr_model function and the OCR_EncoderDecoder class.
The make_ocr_model function first loads the pretrained ResNet-18 available in PyTorch as the backbone for image feature extraction. The backbone can be replaced by another network if needed, but pay attention to its downsampling factor and to the number of channels of its last feature map; the parameters of the related modules must be adjusted accordingly. After that, OCR_EncoderDecoder is called to assemble the transformer. Finally, the model parameters are initialized.
The OCR_EncoderDecoder class acts as an assembly line for the basic transformer components, including the encoder and the decoder. It receives the already constructed components as initialization arguments; the code of these components is in the file transformer.py.
Let us review how the image becomes the input of the Transformer after passing through the backbone:
The backbone outputs a feature map of dimension [batch_size, 512, 1, 24]. Ignoring batch_size, each image yields a 1 × 24 feature map with 512 channels. As the red box in the figure indicates, the feature values at the same position across the channels are concatenated into one vector, which serves as the input of one time step. This yields an input of dimension [batch_size, 24, 512], which meets the input requirements of the Transformer.
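The rearrangement from a feature map to a sequence can be illustrated with nested lists instead of tensors. The sketch below uses a toy feature map with 3 channels and 4 columns; feature_map_to_sequence is a hypothetical helper that mirrors what squeeze(-2) followed by permute(0,2,1) does for a single image.

```python
def feature_map_to_sequence(feature_map):
    """feature_map: nested lists of shape [C][1][W] -> sequence of shape [W][C]."""
    channels = len(feature_map)
    width = len(feature_map[0][0])
    squeezed = [feature_map[c][0] for c in range(channels)]   # drop the height axis: [C][W]
    # Collect, for every column, the values of all channels at that column.
    return [[squeezed[c][x] for c in range(channels)] for x in range(width)]


# Toy feature map: the value 10 * c + x encodes channel c and column x.
toy = [[[10 * c + x for x in range(4)]] for c in range(3)]
sequence = feature_map_to_sequence(toy)
print(sequence)
```

Every element of the resulting sequence is the vector of one image column across all channels, which plays the role of one word embedding in the transformer input.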

model structure

class OCR_EncoderDecoder(nn.Module):
    '''
    A standard Encoder-Decoder architecture.
    Base for this and many other models.
    '''
    def __init__(self,encoder,decoder,src_embed,src_position,tgt_embed,generator):
        super(OCR_EncoderDecoder,self).__init__()
        self.encoder=encoder
        self.decoder=decoder
        self.src_embed=src_embed #input embedding module (the CNN backbone)
        self.src_position=src_position #positional encoding for the image features
        self.tgt_embed=tgt_embed #output embedding module
        self.generator=generator #output generation module

    def forward(self,src,tgt,src_mask,tgt_mask):
        "Take in and process masked src and target sequences."
        #src-->[bs,3,32,768] [bs,c,h,w]
        #src_mask-->[bs,1,24] [bs,h/32,w/32]
        memory=self.encode(src,src_mask)
        #memory-->[bs,24,512]
        #tgt-->decode_in [bs,22] [bs,sequence_len+1]
        #tgt_mask-->decode_mask [bs,22,22] [bs,sequence_len+1,sequence_len+1]
        res=self.decode(memory,src_mask,tgt,tgt_mask) #[bs,22,512]
        return res

    def encode(self,src,src_mask):
        #feature extract
        #src-->[bs,3,32,768]
        src_embedds=self.src_embed(src)
        #resnet18 is used as the backbone here, output-->[batchsize,c,h,w]-->[bs,512,1,24]
        #Reshape src_embedds from (bs,model_dim,1,max_ratio) to the shape (bs,time_step,model_dim) expected by the transformer
        #[bs,512,1,24]-->[bs,24,512]
        src_embedds=src_embedds.squeeze(-2)
        src_embedds=src_embedds.permute(0,2,1)

        #position encode
        src_embedds=self.src_position(src_embedds) #[bs,24,512]
        return self.encoder(src_embedds,src_mask) #[bs,24,512]

    def decode(self,memory,src_mask,tgt,tgt_mask):
        target_embedds=self.tgt_embed(tgt) #[bs,22,512]
        return self.decoder(target_embedds,memory,src_mask,tgt_mask)

def make_ocr_model(tgt_vocab,N=6,d_model=512,d_ff=2048,h=8,dropout=0.1):
    '''
    Build the model
    params:
      tgt_vocab: size of the output vocabulary
      N: number of stacked base modules in the encoder and in the decoder
      d_model: embedding size inside the model, 512 by default
      d_ff: embedding size inside the FeedForward layer, 2048 by default
      h: number of heads in MultiHeadedAttention; d_model must be divisible by it
      dropout: dropout rate
    '''
    c=copy.deepcopy

    #Use the resnet18 pretrained in torchvision as the feature extraction backbone
    #(torchvision>=0.13 prefers weights=models.ResNet18_Weights.IMAGENET1K_V1 over pretrained=True)
    backbone=models.resnet18(pretrained=True)
    backbone=nn.Sequential(*list(backbone.children())[:-2]) #Remove the last two layers (global average pooling and fc layer)

    attn=MultiHeadedAttention(h,d_model)
    ff=PositionwiseFeedForward(d_model,d_ff,dropout)
    position=PositionalEncoding(d_model,dropout)
    #Build the model
    model=OCR_EncoderDecoder(
        Encoder(EncoderLayer(d_model,c(attn),c(ff),dropout),N),
        Decoder(DecoderLayer(d_model,c(attn),c(attn),c(ff),dropout),N),
        backbone,
        c(position),
        nn.Sequential(Embeddings(d_model,tgt_vocab),c(position)),
        Generator(d_model,tgt_vocab)) #The generator is stored here but not called inside the class

    #Initialize parameters with Glorot/fan_avg.
    for child in model.children():
        if child is backbone:
            #Freeze the backbone: its weights do not receive gradients
            for param in child.parameters():
                param.requires_grad=False
            #The pretrained backbone keeps its weights; only the other modules are initialized randomly
            continue
        for p in child.parameters():
            if p.dim()>1:
                nn.init.xavier_uniform_(p)
    return model

The transformer model can then be built easily with the class and the function above:
transformer model

#build model
#use transformer as ocr recognize model
#Note: forward() of the ocr_model built here does not call the Generator
tgt_vocab=len(lbl2id_map.keys())
d_model=512
ocr_model=make_ocr_model(tgt_vocab,N=5,d_model=d_model,d_ff=2048,h=8,dropout=0.1)
ocr_model.to(device)

4. Model training

Before training, we also need to define the loss criterion, the optimizer, and so on. This experiment uses label smoothing and learning rate warmup during training; the corresponding code is in the file train_utils.py.
Label smoothing converts the original hard labels into soft labels, which makes the model more tolerant of label errors and improves its generalization. LabelSmoothing() in the code implements label smoothing and internally uses the relative entropy (KL divergence) to compute the loss between the prediction and the ground truth.
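What a smoothed target looks like can be shown for a single token. The implementation in train_utils.py is not reproduced in this text, so the sketch below is an assumption: it follows the scheme of The Annotated Transformer, where the padding index receives no probability mass and the remaining mass is spread over size - 2 other classes. smoothed_target is a hypothetical name.

```python
def smoothed_target(target_id, size, padding_idx=0, smoothing=0.1):
    """Soft target distribution over `size` classes for one token."""
    if target_id == padding_idx:
        return [0.0] * size                  # padded positions contribute no loss
    dist = [smoothing / (size - 2)] * size   # spread the smoothing mass over the other classes
    dist[target_id] = 1.0 - smoothing        # most of the mass stays on the true class
    dist[padding_idx] = 0.0                  # never assign mass to the padding token
    return dist


soft = smoothed_target(3, size=6, smoothing=0.1)
hard = smoothed_target(3, size=6, smoothing=0.0)
print(soft)
print(hard)
```

With smoothing=0.0 the distribution is the ordinary one-hot target, so under this assumption label smoothing is effectively switched off for that setting.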
The warmup strategy controls the learning rate of the optimizer during training: the learning rate first rises from a small value and then decreases gradually, which makes training more stable and helps the loss converge quickly. NoamOpt() in the code implements the warmup control; it wraps the Adam optimizer and adjusts the learning rate automatically according to the number of iterations.
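The shape of such a schedule can be sketched with a small function. NoamOpt itself lives in train_utils.py and is not shown in this text, so this is an assumption: the sketch uses the schedule from The Annotated Transformer, lr = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), with the values passed to NoamOpt in this tutorial (model size 512, factor 1, warmup 400).

```python
def noam_lr(step, model_size=512, factor=1.0, warmup=400):
    """Learning rate at a given optimizer step (step >= 1)."""
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 100, 400, 1600):
    print(step, noam_lr(step))
```

Under this assumption the learning rate grows linearly during the first 400 steps, peaks at step 400, and then decays proportionally to the inverse square root of the step number.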

Set up label smoothing and learning rate warmup before training:

#train prepare
criterion=LabelSmoothing(size=tgt_vocab,padding_idx=0,smoothing=0.0) #label smoothing
optimizer=torch.optim.Adam(ocr_model.parameters(),lr=0,betas=(0.9,0.98),eps=1e-9)
model_opt=NoamOpt(d_model,1,400,optimizer) #warmup

The training loop is shown below. Validation is run once every 10 epochs. The computation of a single epoch is encapsulated in the run_epoch() function.

#train&valid...
for epoch in range(nrof_epochs):
    print(f"\nepoch {epoch}")

    print("train...") #train
    ocr_model.train()
    loss_compute=SimpleLossCompute(ocr_model.generator,criterion,model_opt)
    train_mean_loss=run_epoch(train_loader,ocr_model,loss_compute,device)
    if epoch%10==0:
        print("valid...") #validation
        ocr_model.eval()
        valid_loss_compute=SimpleLossCompute(ocr_model.generator,criterion,None)
        valid_mean_loss=run_epoch(valid_loader,ocr_model,valid_loss_compute,device)
        print(f"valid loss:{valid_mean_loss}")

    #save model
    torch.save(ocr_model.state_dict(),'./trained_model/ocr_model.pt')

The SimpleLossCompute() class computes the loss of the transformer output. An instance is called with three arguments (x, y, norm): x is the output of the decoder, y is the label data, and norm is the normalization coefficient of the loss, for which the number of valid tokens in the batch can be used. At this point all parts of the transformer network are connected and the data flows through the whole computation.
The run_epoch() function carries out all the work of one epoch, including data loading, the forward pass of the model, loss calculation and backpropagation, and it prints information about the training progress.

import time

def run_epoch(data_loader,model,loss_compute,device=None):
    "Standard Training and Logging Function"
    start=time.time()
    total_tokens=0
    total_loss=0
    tokens=0

    for i,batch in enumerate(data_loader):
        img_input,encode_mask,decode_in,decode_out,decode_mask,ntokens=batch
        img_input=img_input.to(device)
        encode_mask=encode_mask.to(device)
        decode_in=decode_in.to(device)
        decode_out=decode_out.to(device)
        decode_mask=decode_mask.to(device)
        ntokens=torch.sum(ntokens).to(device)

        out=model.forward(img_input,decode_in,encode_mask,decode_mask)
        #out --> [bs, 20, 512] prediction (decoder output)
        #decode_out --> [bs, 20] ground-truth labels
        #ntokens --> number of valid characters in the labels

        loss=loss_compute(out,decode_out,ntokens) #loss calculation
        total_loss+=loss
        total_tokens+=ntokens
        tokens+=ntokens
        if i%50==1:
            elapsed=time.time()-start
            print("Epoch Step: %d Loss:%f Tokens per Sec:%f"%(i,loss/ntokens,tokens/elapsed))
            start=time.time()
            tokens=0
    return total_loss/total_tokens
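
The value returned by run_epoch() is a mean loss per valid token, not a mean of per-batch losses. The following minimal sketch shows the difference in plain Python; the function name and the numbers are hypothetical and only serve as an illustration.

```python
def mean_loss_per_token(batches):
    """batches: iterable of (summed_loss, ntokens) pairs, one pair per batch."""
    total_loss = 0.0
    total_tokens = 0
    for summed_loss, ntokens in batches:
        total_loss += summed_loss
        total_tokens += ntokens
    return total_loss / total_tokens

# Two hypothetical batches with different numbers of valid tokens
batches = [(30.0, 10), (10.0, 30)]
print(mean_loss_per_token(batches))  # 40.0 loss over 40 tokens -> 1.0
```

Averaging the two per-batch means (3.0 and about 0.33) would weight the short batch too heavily; dividing the accumulated loss by the accumulated token count avoids this.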

class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self,generator,criterion,opt=None):
        self.generator=generator
        self.criterion=criterion
        self.opt=opt

    def __call__(self,x,y,norm):
        '''
        norm: normalization coefficient of the loss; the number of valid tokens in the batch is used
        '''
        #x --> out --> [bs, 20, 512] prediction (decoder output)
        #y --> decode_out --> [bs, 20] ground-truth labels
        #norm --> ntokens --> number of valid characters in the labels
        x=self.generator(x)
        #flatten so that the shapes match what the label smoothing criterion expects
        x_=x.contiguous().view(-1,x.size(-1)) #[bs*20, tgt_vocab]
        y_=y.contiguous().view(-1) #[bs*20]
        loss=self.criterion(x_,y_)
        loss/=norm
        if self.opt is not None:
            #backpropagate and update only in training (opt is None during validation),
            #so that validation does not accumulate gradients on the model parameters
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item()*norm
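
The criterion called in this class is the LabelSmoothing instance created before training. As a minimal sketch of what such a criterion compares the prediction against, the function below builds the target distribution for one token. It assumes the scheme from "The Annotated Transformer" (the smoothing mass is spread over all classes except the true class and the padding class); the function name smoothed_target is hypothetical.

```python
def smoothed_target(target, size, padding_idx=0, smoothing=0.0):
    """Return the target distribution for one token as a list of `size` probabilities."""
    if target == padding_idx:
        return [0.0] * size  # padded positions contribute nothing to the loss
    # spread the smoothing mass over all classes except the true one and padding
    dist = [smoothing / (size - 2)] * size
    dist[target] = 1.0 - smoothing
    dist[padding_idx] = 0.0
    return dist

print(smoothed_target(2, size=5, smoothing=0.0))  # one-hot: [0.0, 0.0, 1.0, 0.0, 0.0]
print(smoothed_target(2, size=5, smoothing=0.3))  # most mass on class 2, the rest shared
```

With smoothing=0.0, as used in this experiment, the target is an ordinary one-hot vector.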

5. Greedy decoding

We use the simplest approach, greedy decoding, to predict the OCR result directly. The model produces only one output per step, so at each step we take the character with the highest probability in the output distribution as the prediction, append it to the sequence, and then predict the next character.
In the experiment, each image is fed to the model and decoded greedily one at a time. Finally, the prediction accuracy on the training set and the validation set is reported.
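
The decoding loop can be sketched without a real model. In the sketch below, toy_next_token_probs is a hypothetical stand-in for the decoder plus generator, and the token ids 5, 6, 7 are made up; only the start id 1 and end id 2 match the values used in this experiment.

```python
START, END = 1, 2  # start/end symbol ids as used in this experiment

def toy_next_token_probs(prefix):
    """Hypothetical stand-in for the decoder: spells the ids 5, 6, 7, then END."""
    script = [5, 6, 7, END]
    probs = [0.01] * 10
    probs[script[min(len(prefix) - 1, len(script) - 1)]] = 0.9
    return probs

def greedy_decode_toy(next_token_probs, max_len):
    ys = [START]  # sequence generated so far
    for _ in range(max_len - 1):
        probs = next_token_probs(ys)
        next_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
        ys.append(next_id)
        if next_id == END:
            break
    return ys[1:]  # drop the start symbol

print(greedy_decode_toy(toy_next_token_probs, max_len=20))  # [5, 6, 7, 2]
```

Decoding stops either when the end symbol is produced or when max_len is reached, whichever comes first.
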
Greedy decoding

#After training, run greedy decoding on the training set and the validation set and compute the accuracy
ocr_model.eval()

print('\n-----------------------------------------------')
print("greedy decode trainset")
total_img_num=0
total_correct_num=0
for batch_idx,batch in enumerate(train_loader):
    img_input,encode_mask,decode_in,decode_out,decode_mask,ntokens=batch
    img_input=img_input.to(device)
    encode_mask=encode_mask.to(device)

    #Process the images in the batch one at a time
    bs=img_input.shape[0]
    for i in range(bs):
        cur_img_input=img_input[i].unsqueeze(0)
        cur_encode_mask=encode_mask[i].unsqueeze(0)
        cur_decode_out=decode_out[i]
        #Greedy decoding
        pred_result=greedy_decode(ocr_model,cur_img_input,cur_encode_mask,max_len=sequence_len,start_symbol=1,end_symbol=2)
        pred_result=pred_result.cpu()
        #Check whether the prediction is correct
        is_correct=judge_is_correct(pred_result,cur_decode_out)
        total_correct_num+=is_correct
        total_img_num+=1
        if not is_correct:
            #Print the cases that were predicted wrongly
            print("----")
            print(cur_decode_out)
            print(pred_result)
#Report the accuracy once, after all images have been decoded
total_correct_rate=total_correct_num/total_img_num*100
print(f"total correct rate of trainset:{total_correct_rate}%")

#Same as the training set decoding code
print("\n----------------------------")
print("greedy decode validset")
total_img_num=0
total_correct_num=0
for batch_idx,batch in enumerate(valid_loader):
    img_input,encode_mask,decode_in,decode_out,decode_mask,ntokens=batch
    img_input=img_input.to(device)
    encode_mask=encode_mask.to(device)

    bs=img_input.shape[0]
    for i in range(bs):
        cur_img_input=img_input[i].unsqueeze(0)
        cur_encode_mask=encode_mask[i].unsqueeze(0)
        cur_decode_out=decode_out[i]

        pred_result=greedy_decode(ocr_model,cur_img_input,cur_encode_mask,max_len=sequence_len,start_symbol=1,end_symbol=2)
        pred_result=pred_result.cpu()

        is_correct=judge_is_correct(pred_result,cur_decode_out)
        total_correct_num+=is_correct
        total_img_num+=1
        if not is_correct:
            #Print the cases that were predicted wrongly
            print("----")
            print(cur_decode_out)
            print(pred_result)
#Report the accuracy once, after all images have been decoded
total_correct_rate=total_correct_num/total_img_num*100
print(f"total correct rate of validset:{total_correct_rate}%")

greedy decode

#greedy decode
def greedy_decode(model,src,src_mask,max_len,start_symbol,end_symbol):
    memory=model.encode(src,src_mask)
    #ys is the sequence generated so far. It starts with only the start symbol, and each prediction is appended to its end
    ys=torch.ones(1,1).fill_(start_symbol).type_as(src.data).long()
    for i in range(max_len-1):
        #torch.autograd.Variable is no longer needed; tensors are passed directly
        out=model.decode(memory,src_mask,ys,subsequent_mask(ys.size(1)).type_as(src.data))
        prob=model.generator(out[:,-1]) #use only the output of the last time step
        _,next_word=torch.max(prob,dim=1)
        next_word=next_word.data[0]
        next_word=torch.ones(1,1).type_as(src.data).fill_(next_word).long()
        ys=torch.cat([ys,next_word],dim=1)
        next_word=int(next_word)
        if next_word==end_symbol:
            break
    ys=ys[0,1:] #drop the start symbol
    return ys

def judge_is_correct(pred,label):
    #Judge whether the predicted results of the model are consistent with the label
    pred_len=pred.shape[0]
    label=label[:pred_len]
    is_correct=1 if label.equal(pred) else 0
    return is_correct     
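
The comparison rule can be checked on plain lists. This is a minimal sketch with hypothetical token ids (5, 6, 7 for characters, 2 for the end symbol, 0 for padding); is_prediction_correct is a list-based stand-in for the tensor version above, not a replacement for it.

```python
def is_prediction_correct(pred, label):
    """Compare pred with the prefix of label that has the same length."""
    return 1 if label[:len(pred)] == pred else 0

label = [5, 6, 7, 2, 0, 0]  # three characters, end symbol, then padding
print(is_prediction_correct([5, 6, 7, 2], label))  # 1: matches up to the end symbol
print(is_prediction_correct([5, 6, 2], label))     # 0: stopped one character early
```

Because the label is truncated to the prediction length, the padding after the end symbol never takes part in the comparison.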

  1. Hands-on CV, PyTorch version

Posted by ManWithNoName on Sun, 24 Oct 2021 08:41:35 -0700