Preface
Of the two schemes below, Scheme One is the best: the most efficient and the most concise.
Scheme One
Step 1: Freeze the base network
Code template:

    # Get the state_dict of the part to be frozen:
    pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # Load it into the model (remember strict=False):
    model.load_state_dict(pre_state_dict, strict=False)
    print('Load model from %s success.' % model_path)
    # Freeze the base network:
    model = freeze_model(model=model, to_freeze_dict=pre_state_dict)

Here freeze_model is defined as follows:

    def freeze_model(model, to_freeze_dict, keep_step=None):
        for (name, param) in model.named_parameters():
            if name in to_freeze_dict:
                param.requires_grad = False
            else:
                pass

        # # Print the current freezing status (optional):
        # freezed_num, pass_num = 0, 0
        # for (name, param) in model.named_parameters():
        #     if param.requires_grad == False:
        #         freezed_num += 1
        #     else:
        #         pass_num += 1
        # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num))
        return model
Note:
- If the preloaded checkpoint comes from a model trained with model = nn.DataParallel(model) (distributed / multi-GPU mode), every parameter name is prefixed with module. by default.
- As a result, the checkpoint cannot be loaded directly into a single-machine (non-DataParallel) model, and the module. prefix has to be stripped. Change:

    # Get the state_dict of the part to be frozen:
    pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

to:

    # Get the state_dict of the part to be frozen:
    pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    pre_state_dict = {k.replace('module.', ''): v for k, v in pre_state_dict.items()}
Step 2: Keep the frozen parameters out of the optimizer
Code template:

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=1, momentum=0.9, weight_decay=1e-4)
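As a quick sanity check (a minimal sketch, assuming model and optimizer are the objects from the templates above), you can confirm that only the trainable parameters were handed to the optimizer:

    # Count trainable tensors vs. tensors registered in the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    in_optimizer = sum(len(group['params']) for group in optimizer.param_groups)
    print('trainable: %d, in optimizer: %d' % (len(trainable), in_optimizer))
    assert len(trainable) == in_optimizer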
Step 3: Keep the frozen part in .eval() mode during training
Reason: even with requires_grad = False set on a BN layer, calling model.train() will still quietly update its running statistics; only model.eval() stops those updates. (See the discussion of BN for details.)
So: at the start of every training epoch, the modes need to be re-asserted uniformly, otherwise problems are easy to run into:

    model.eval()
    model.stage4_xx.train()
    model.pred_xx.train()
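For context, here is a minimal sketch of how this looks inside the epoch loop (stage4_xx and pred_xx are the example trainable sub-modules from the snippet above; dataloader, criterion, optimizer and num_epochs are assumed to exist):

    for epoch in range(num_epochs):
        # Re-assert the modes at the start of every epoch:
        # the frozen part stays in eval mode, the trainable heads go back to train mode.
        model.eval()
        model.stage4_xx.train()
        model.pred_xx.train()

        for images, targets in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()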
Scheme Two
In this scheme, the freeze operation in PyTorch takes the following steps.
Step 1: Freeze the base network
Code template:

    # Get the state_dict of the part to be frozen:
    pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    # Load it into the model (remember strict=False):
    model.load_state_dict(pre_state_dict, strict=False)
    print('Load model from %s success.' % model_path)
    # Freeze the base network:
    model = freeze_model(model=model, to_freeze_dict=pre_state_dict)

Here freeze_model is defined as follows:

    def freeze_model(model, to_freeze_dict, keep_step=None):
        for (name, param) in model.named_parameters():
            if name in to_freeze_dict:
                param.requires_grad = False
            else:
                pass

        # # Print the current freezing status (optional):
        # freezed_num, pass_num = 0, 0
        # for (name, param) in model.named_parameters():
        #     if param.requires_grad == False:
        #         freezed_num += 1
        #     else:
        #         pass_num += 1
        # print('\n Total {} params, miss {} \n'.format(freezed_num + pass_num, pass_num))
        return model
Step 2: Keep the frozen parameters out of the optimizer
Code template:

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=1, momentum=0.9, weight_decay=1e-4)
Step 3: Fix the BN layers
(See the discussion of BN above.) Even if requires_grad = False has been set on the BN layers in the previous steps, calling model.train() will still quietly update their running statistics; only model.eval() stops those updates.
Therefore, the BN layers additionally need to be fixed at a deeper level:
- Fix the momentum: momentum=0.0
- Disable track_running_stats: track_running_stats=False
For example, change:

    self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

to:

    self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)
However, track_running_stats=False has a side effect: every affected BN layer loses three key-value pairs from the state_dict (for each BN the keys are xx.xx.bn.running_mean, xx.xx.bn.running_var and xx.xx.bn.num_batches_tracked).
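To see exactly which keys disappear, one option (a sketch, assuming origin_state_dict is a checkpoint of the model saved before this BN change) is to diff the key sets:

    # Keys present in the original checkpoint but missing from the current model:
    missing_keys = set(origin_state_dict.keys()) - set(model.state_dict().keys())
    for key in sorted(missing_keys):
        print(key)  # e.g. xx.xx.bn.running_mean / running_var / num_batches_tracked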
Step 4: Normal training
During training, remember to check regularly that the frozen part really has stayed unchanged:
- For example, at every eval, also check the prediction accuracy of the frozen part (a sketch of a direct parameter check follows below).
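A minimal sketch of such a direct check, assuming model is the partially frozen model from Step 1: snapshot the frozen tensors once right after freezing, then compare them at every eval.

    # Right after freezing: snapshot every frozen parameter.
    frozen_snapshot = {name: param.detach().clone()
                       for name, param in model.named_parameters()
                       if not param.requires_grad}

    # At every eval: verify the frozen parameters have not moved.
    for name, param in model.named_parameters():
        if name in frozen_snapshot:
            assert torch.equal(param.detach().cpu(), frozen_snapshot[name].cpu()), \
                'frozen parameter %s has changed!' % name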
Step 5: Post-processing
5.1 Re-enable track_running_stats
For example, change:

    self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

to:

    self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)
At this point, every previously affected BN layer gets the three missing key-value pairs back in its state_dict (but their values are freshly initialized and still need to be filled in).
Note:
- Training runs on the frozen (modified-BN) network definition, but for offline testing be sure to switch back to the unmodified, non-frozen definition. Otherwise the results will be inconsistent, and both the frozen and the non-frozen tasks will perform worse!
5.2 Recover the missing values
To undo the side effect of track_running_stats=False, the final model is obtained by merging the original state_dict with the newly trained state_dict, with the former supplying the key-value pairs missing from the latter.

    # Original state_dict:
    origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
    # Trained state_dict:
    new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))
    # Fill in the key-value pairs missing from the latter using the former:
    final_dict = new_state_dict.copy()
    for (key, val) in origin_state_dict.items():
        if key not in final_dict:
            final_dict[key] = val
    # Load the merged state_dict; this must pass with strict=True:
    model.load_state_dict(final_dict, strict=True)

Then save the model once more; the resulting file is the final, directly usable model.
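That final save is the usual torch.save call (a minimal sketch; final_model_path is a hypothetical output path):

    # Save the merged, directly usable checkpoint:
    torch.save(model.state_dict(), final_model_path)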