Practical Tutorial | Saving and Migrating PyTorch Models

Keywords: Python, Machine Learning, Neural Networks, PyTorch, Deep Learning

In this article, the author first introduces several typical scenarios of model reuse, then shows how to inspect the parameters of a PyTorch model, and finally explains how to load a model for inference, additional training, and transfer learning.

1 Introduction
Hello, friends, and welcome to Yuelai Inn. Today I'd like to introduce how to save and load models in the PyTorch framework, as well as how to migrate and retrain them. Generally speaking, the most common scenario is inference after a model has been trained: a network usually needs to predict on new samples once training is finished, so it only needs to build the model's forward pass and then load the trained parameters to initialize the network.

The second scenario is retraining. After a model has been trained on one batch of data and saved locally, new data may be collected some time later, so the previously saved model needs to be loaded for incremental training on the new data (or full training on the combined data).

The third application scenario is transfer learning, where a model pre-trained by someone else is used to initialize part of your own network's parameters. For example, if you add several fully connected layers on top of a BERT model for a classification task, you need to load the parameters of the original BERT model to initialize the weights of the BERT part of your network.

In the rest of this article, the author takes the above three scenarios as examples and introduces how to complete each of them with the PyTorch framework.

2 Model Saving and Reuse
In PyTorch, the main steps in the above scenarios can be completed with torch.save() and torch.load(). Next, the author takes the previously introduced LeNet5 network as an example. But before that, let's look at how model parameters are stored in PyTorch.

2.1 Viewing Network Model Parameters
(1) View parameters

First, define the network model structure of LeNet5, as shown in the following code:

import torch
import torch.nn as nn

class LeNet5(nn.Module):
    def __init__(self, ):
        super(LeNet5, self).__init__()
        self.conv = nn.Sequential(  # [n,1,28,28]
            nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
            nn.ReLU(),  # [n,6,28,28]
            nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
            nn.Conv2d(6, 16, 5),  # [n,16,10,10]
            nn.ReLU(),
            nn.MaxPool2d(2, 2))  # [n,16,5,5]
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10))

    def forward(self, img):
        output = self.conv(img)
        output = self.fc(output)
        return output

Once the LeNet5 class is defined, instantiating it also initializes the corresponding weight parameters of the network, i.e. they already have initial values. We can then access them as follows:

model = LeNet5()
# Print the model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

The output (truncated here) is:

conv.0.weight   torch.Size([6, 1, 5, 5])
conv.0.bias   torch.Size([6])
conv.3.weight   torch.Size([16, 6, 5, 5])

It can be seen that the network's parameters, model.state_dict(), are actually stored as a dictionary (essentially an OrderedDict from the collections module):

print(model.state_dict().keys())

odict_keys(['conv.0.weight', 'conv.0.bias', 'conv.3.weight',

'conv.3.bias', 'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias',
'fc.5.weight', 'fc.5.bias'])
(2) Custom parameter prefix

At the same time, two points are worth noting here: ① the fc and conv prefixes in the parameter names come from the attribute names used for the nn.Sequential() containers above (self.fc and self.conv); ② the number in each parameter name indicates the layer's position within its Sequential() container. For example, suppose the network structure is defined as follows:

class LeNet5(nn.Module):
    def __init__(self, ):
        super(LeNet5, self).__init__()
        self.moon = nn.Sequential(  # [n,1,28,28]
            nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
            nn.ReLU(),  # [n,6,28,28]
            nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
            nn.Conv2d(6, 16, 5),  # [n,16,10,10]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10))

Then the parameter names are:

print(model.state_dict().keys())
odict_keys(['moon.0.weight', 'moon.0.bias', 'moon.3.weight',
 'moon.3.bias', 'moon.7.weight', 'moon.7.bias', 'moon.9.weight', 
'moon.9.bias', 'moon.11.weight', 'moon.11.bias'])

Understanding this is very helpful when analyzing and loading pre-trained models.
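As a concrete (and purely hypothetical) illustration of why the key names matter, suppose a checkpoint was saved from a model whose Sequential container was named features, while the current model names the same container conv; the keys can then be remapped before loading. The names and file path here are assumptions for illustration only:

from collections import OrderedDict
import torch

loaded_paras = torch.load('model.pt')  # e.g. keys like 'features.0.weight', 'features.0.bias', ...
renamed_paras = OrderedDict(
    (k.replace('features.', 'conv.', 1), v) for k, v in loaded_paras.items())
# model.load_state_dict(renamed_paras)  # load once the prefixes match the current model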

In addition, the optimizer in torch.optim also has a corresponding state_dict() method for obtaining its parameters, for example:

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Optimizer's state_dict:
state   {}
param_groups   [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 
'weight_decay': 0, 'nesterov': False, 
'params': [140239245300504, 140239208339784, 140239245311360, 
140239245310856, 140239266942480, 140239266942552, 140239266942624, 
140239266942696, 140239266942912, 140239267041352]}]

Having introduced how to view model parameters, we can now move on to model reuse.

2.2 Loading a Model for Inference
(1) Model saving

In PyTorch, saving a model is very simple and can generally be done with the following two lines of code:

import os

model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save(model.state_dict(), model_save_path)

When naming the saved model, the suffixes officially recommended by PyTorch are .pt or .pth (although this is not mandatory). Finally, you only need to place the second line of code in the appropriate spot to save the model.

At the same time, if you want to save the best model obtained under some condition during training, you should do it as follows:

from copy import deepcopy

best_model_state = deepcopy(model.state_dict())
torch.save(best_model_state, model_save_path)

instead of:

best_model_state = model.state_dict() 
torch.save(best_model_state, model_save_path)

This is because in the latter case best_model_state is only a reference to model.state_dict(), so it will keep changing as training continues, whereas deepcopy() takes a snapshot of the parameters at that moment.
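For context, here is a minimal sketch of how the best model might be tracked inside a training loop; evaluate(), test_iter, epochs and the other names are placeholders borrowed from the surrounding examples, not a complete program:

from copy import deepcopy

best_acc, best_model_state = 0.0, None
for epoch in range(epochs):
    # ... one epoch of training ...
    acc = evaluate(test_iter, model, device)  # assumed evaluation helper
    if acc > best_acc:
        best_acc = acc
        best_model_state = deepcopy(model.state_dict())  # snapshot of the current weights
torch.save(best_model_state, model_save_path)  # save the best snapshot at the end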

(2) Reuse the model for inference

During inference, we first initialize the network and then load the saved parameters to overwrite the network's weights. Example code is as follows:

def inference(data_iter, device, model_save_dir='./MODEL'):
    model = LeNet5()  # Initialize the model with fresh weight parameters
    model.to(device)
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        model.load_state_dict(loaded_paras)  # Re-initialize the network weights from the saved model
    model.eval()  # Be careful not to forget this
    with torch.no_grad():
        acc_sum, n = 0.0, 0
        for x, y in data_iter:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            acc_sum += (logits.argmax(1) == y).float().sum().item()
            n += len(y)
        print("Accuracy in test data is :", acc_sum / n)

In the above code, lines 4-7 load the local model parameters and overwrite the original parameters of the network, after which inference can be carried out:

Accuracy in test data is : 0.8851
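For reference, here is a hedged sketch of how data_iter and device might be constructed before calling the inference() function above; the dataset (FashionMNIST via torchvision), the paths and the batch size are assumptions and not part of the original post:

import torch
import torchvision
from torch.utils.data import DataLoader

test_data = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True,
    transform=torchvision.transforms.ToTensor())
test_iter = DataLoader(test_data, batch_size=64, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inference(test_iter, device, model_save_dir='./MODEL')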
2.3 Loading a Model for Additional Training
With saving and reusing covered, additional training of the network is straightforward. The simplest approach is to save only the network weights during training, and then, when additional training is needed later, load those weights to initialize the network before training continues. An example is as follows (see [2] for the complete code):

def train(self):
    # ......
    model_save_path = os.path.join(self.model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        self.model.load_state_dict(loaded_paras)
        print("#### Successfully load the existing model for additional training... ")
    optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  # Define the optimizer
    # ......
    for epoch in range(self.epochs):
        for i, (x, y) in enumerate(train_iter):
            x, y = x.to(device), y.to(device)
            logits = self.model(x)
            # ......
        print("Epochs[{}/{}]--acc on test {:.4}".format(
            epoch, self.epochs, self.evaluate(test_iter, self.model, device)))
        torch.save(self.model.state_dict(), model_save_path)

In this way, additional training of the model can be carried out:

Successfully load the existing model for additional training

Epochs[0/5]—batch[938/0]—acc 0.9062—loss 0.2926
Epochs[0/5]—batch[938/100]—acc 0.9375—loss 0.1598
...
In addition, you can also save the optimizer's parameters and the loss value together with the model parameters, and then restore them all when reloading the model. An example is as follows:

model_save_path = os.path.join(model_save_dir, 'model.pt')
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, model_save_path)

The loading method is as follows:

checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
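After restoring these objects, training can simply be continued from where it left off. A minimal sketch is shown below; the number of extra epochs, loss_fn, train_iter and device are illustrative assumptions rather than code from [2]:

model.train()  # switch back to training mode before continuing
loss_fn = torch.nn.CrossEntropyLoss()  # assumed criterion
start_epoch = epoch + 1  # continue from the next epoch
for epoch in range(start_epoch, start_epoch + 5):  # 5 extra epochs, illustrative
    for x, y in train_iter:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        l = loss_fn(logits, y)
        l.backward()
        optimizer.step()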

2.4 Loading a Model for Transfer Learning
(1) Define new model

So far, the first two application scenarios have been covered, and overall they are not complicated; scenario 3, however, is a little more involved.

Suppose there is a LeNet6 network that adds one more fully connected layer to LeNet5. Its definition is as follows:

class LeNet6(nn.Module):
    def __init__(self, ):
        super(LeNet6, self).__init__()
        self.conv = nn.Sequential(  # [n,1,28,28]
            nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
            nn.ReLU(),  # [n,6,28,28]
            nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
            nn.Conv2d(6, 16, 5),  # [n,16,10,10]
            nn.ReLU(),
            nn.MaxPool2d(2, 2))  # [n,16,5,5]
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 64), 
            nn.ReLU(),
            nn.Linear(64, 10))  # Newly added fully connected layer

Next, we want to migrate the weights trained for LeNet5 into the LeNet6 network. From the definition of LeNet6 above, we can see that although only one fully connected layer was added, the shape of the parameters of the second-to-last Linear layer has also changed. Therefore, only the parameters of the first four weight layers of LeNet5 can be reused in LeNet6.

(2) View model parameters

Given a saved set of model parameters, we can first load it and inspect the relevant parameter information:

model_save_path = os.path.join('./MODEL', 'model.pt')
loaded_paras = torch.load(model_save_path)
for param_tensor in loaded_paras:
    print(param_tensor, "\t", loaded_paras[param_tensor].size())
#----Reusable part
conv.0.weight    torch.Size([6, 1, 5, 5])
conv.0.bias      torch.Size([6])
conv.3.weight    torch.Size([16, 6, 5, 5])
conv.3.bias      torch.Size([16])
fc.1.weight      torch.Size([120, 400])
fc.1.bias    torch.Size([120])
fc.3.weight      torch.Size([84, 120])
fc.3.bias    torch.Size([84])
#-----Non reusable part
fc.5.weight      torch.Size([10, 84])
fc.5.bias    torch.Size([10])

Meanwhile, the parameter information of the LeNet6 network is:

model = LeNet6()
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
#
conv.0.weight    torch.Size([6, 1, 5, 5])
conv.0.bias      torch.Size([6])
conv.3.weight    torch.Size([16, 6, 5, 5])
conv.3.bias      torch.Size([16])
fc.1.weight      torch.Size([120, 400])
fc.1.bias    torch.Size([120])
fc.3.weight      torch.Size([84, 120])
fc.3.bias    torch.Size([84])
#------Newly added part
fc.5.weight      torch.Size([64, 84])
fc.5.bias    torch.Size([64])
fc.7.weight      torch.Size([10, 64])
fc.7.bias    torch.Size([10])

Having clarified the parameters of both the old and the new model, we can extract the parameters we need from LeNet5 and load them into the LeNet6 network.

(3) Model migration

Although both the locally loaded parameters (loaded_paras above) and the freshly initialized parameters (model.state_dict() above) are dictionaries, we cannot directly modify the weights inside model.state_dict(). We first need to construct a new state_dict and then reinitialize the network parameters with the model.load_state_dict() method.

At the same time, in this process we need to filter out the parts of the local model that cannot be reused. The specific code is as follows:

def para_state_dict(model, model_save_dir):
    state_dict = deepcopy(model.state_dict())
    model_save_path = os.path.join(model_save_dir, 'model.pt')
    if os.path.exists(model_save_path):
        loaded_paras = torch.load(model_save_path)
        for key in state_dict:  # Traverse the corresponding parameters in the new network model
            if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                print("Parameters initialized successfully:", key)
                state_dict[key] = loaded_paras[key]
    return state_dict

In the above code, line 2 first copies the original parameters of the new network (LeNet6); lines 6-9 replace the corresponding parts of LeNet6 with the reusable parts of the local (LeNet5) parameters, and line 7 is the condition used to judge reusability. Note that the filtering rule may differ from case to case, so the specific situation needs specific analysis, but the overall logic is the same.

Finally, we only need to call this function before training and reinitialize part of LeNet6's weight parameters with the result (see [2] for the complete code):

state_dict = para_state_dict(self.model, self.model_save_dir)
self.model.load_state_dict(state_dict)

The training results are as follows:

Successfully initialized parameter: conv.0.weight
Successfully initialized parameter: conv.0.bias
Successfully initialized parameter: conv.3.weight
Successfully initialized parameter: conv.3.bias
Successfully initialized parameter: fc.1.weight
Successfully initialized parameter: fc.1.bias
Successfully initialized parameter: fc.3.weight
Successfully initialized parameter: fc.3.bias

Successfully load the existing model for additional training

Epochs[0/5]—batch[938/0]—acc 0.1094—loss 2.512
Epochs[0/5]—batch[938/100]—acc 0.9375—loss 0.2141
Epochs[0/5]—batch[938/200]—acc 0.9219—loss 0.2729
Epochs[0/5]—batch[938/300]—acc 0.8906—loss 0.2958
...
Epochs[0/5]—batch[938/900]—acc 0.8906—loss 0.2828
Epochs[0/5]–acc on test 0.8808
It can be seen that after about 100 batches, the accuracy of the model has already improved noticeably.
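As a side note, once the keys with mismatched shapes have been filtered out, a similar effect can also be obtained by passing the partial dictionary to load_state_dict() with strict=False, which tells PyTorch to ignore the missing keys. The following is a hedged alternative sketch to the para_state_dict() helper above, not the code from [2]:

model_save_path = os.path.join(self.model_save_dir, 'model.pt')
loaded_paras = torch.load(model_save_path)
model_state = self.model.state_dict()
usable_paras = {k: v for k, v in loaded_paras.items()
                if k in model_state and v.size() == model_state[k].size()}
self.model.load_state_dict(usable_paras, strict=False)  # missing keys keep their initial values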

3 Summary
In this article, the author first introduced several typical scenarios of model reuse, then showed how to view the relevant parameter information of a PyTorch model, and finally explained how to load a model and how to carry out additional training and transfer learning.

This is the end of the content. Thank you for reading!

References
[1] Saving and Loading Models, https://pytorch.org/tutorials/beginner/saving_loading_models.html

[2] Sample code, https://github.com/moon-hotel/DeepLearningWithMe
