[pytorch] Freeze part of a network

Preface

Of the two schemes below, "Scheme I" is the best: it is the most efficient and the most concise.

Scheme I

Step 1: freeze the base network

Code template:

import torch

# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Load it (remember strict=False, because the checkpoint only covers the part to be frozen):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# Freeze the base network:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
where freeze_model is defined as follows:
def freeze_model(model, to_freeze_dict, keep_step=None):
    # keep_step is not used; it is kept only to preserve the original signature.
    for (name, param) in model.named_parameters():
        # Freeze every parameter that was loaded from the pretrained state_dict.
        if name in to_freeze_dict:
            param.requires_grad = False

    # # Optionally print how many parameters are frozen:
    # frozen_num, trainable_num = 0, 0
    # for (name, param) in model.named_parameters():
    #     if not param.requires_grad:
    #         frozen_num += 1
    #     else:
    #         trainable_num += 1
    # print('\n Total {} params, {} still trainable \n'.format(frozen_num + trainable_num, trainable_num))

    return model
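To see Step 1 end to end, here is a minimal sketch that runs on its own. It assumes a hypothetical toy model (`ToyNet`, with a `backbone` to freeze and a `head` to train) and builds the "pretrained" state_dict in memory instead of calling `torch.load(model_path, ...)`; `freeze_model` is repeated so the block is self-contained.

```python
import torch.nn as nn


def freeze_model(model, to_freeze_dict, keep_step=None):
    # Freeze every parameter whose name appears in the pretrained state_dict.
    for name, param in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False
    return model


class ToyNet(nn.Module):
    # Hypothetical stand-in for the real network.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
        self.head = nn.Conv2d(8, 2, 1)

    def forward(self, x):
        return self.head(self.backbone(x))


# Fake "pretrained" checkpoint that only covers the backbone.
pre_state_dict = {k: v for k, v in ToyNet().state_dict().items() if k.startswith('backbone.')}

model = ToyNet()
result = model.load_state_dict(pre_state_dict, strict=False)  # head.* is missing on purpose
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)

frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print('frozen:', frozen)
print('trainable:', trainable)
print('not loaded:', result.missing_keys)
```

The value returned by load_state_dict lists the keys that were not loaded (missing_keys), which is a cheap way to confirm that strict=False skipped only the parts you intend to train.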

Note:

  • If the pretrained model was trained as model = nn.DataParallel(model) (multi-GPU), every parameter name is prefixed with module. by default.
  • As a result, its keys do not match a single-GPU (unwrapped) model, and with strict=False the mismatch is silent: nothing is loaded and nothing gets frozen. In that case, replace:
# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
with:
# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Strip the leading 'module.' from every key (only at the start of the name):
pre_state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in pre_state_dict.items()}
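A small sketch of the prefix stripping, assuming a plain `nn.Linear` and a dict that imitates what a checkpoint saved from `nn.DataParallel` looks like (`strip_prefix` is a hypothetical helper name). Removing only a leading `module.` avoids mangling keys that merely contain that string, such as `submodule.weight`.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix='module.'):
    # Drop `prefix` only where a key starts with it; other keys are kept as-is.
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


net = nn.Linear(4, 2)
# Imitate a checkpoint saved from nn.DataParallel(net): every key gets 'module.'.
dp_state = OrderedDict(('module.' + k, v) for k, v in net.state_dict().items())
print(list(dp_state))  # ['module.weight', 'module.bias']

single = nn.Linear(4, 2)
single.load_state_dict(strip_prefix(dp_state), strict=True)
print(torch.equal(single.weight, net.weight))  # True
```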

Step 2: exclude the frozen parameters from the optimizer

Code template:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)
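The sketch below, assuming a toy `nn.Sequential` whose first layer plays the frozen base network (and lr=0.1 instead of the template's lr=1), checks that only the trainable parameters reach the optimizer and that the frozen layer comes out of an SGD step with momentum and weight decay unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

# Freeze the first layer (stand-in for the pretrained base network).
for p in model[0].parameters():
    p.requires_grad = False

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=0.1, momentum=0.9, weight_decay=1e-4)

frozen_before = model[0].weight.clone()
head_before = model[2].weight.clone()

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

n_in_optimizer = sum(len(g['params']) for g in optimizer.param_groups)
print('params in optimizer:', n_in_optimizer)  # 2 (head weight and bias)
print('frozen layer unchanged:', torch.equal(frozen_before, model[0].weight))
print('head updated:', not torch.equal(head_before, model[2].weight))
```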

Step 3: use .eval() to keep the frozen part frozen during training

Reason: even with requires_grad = False set on the BN layers, BN silently keeps updating its running statistics (running_mean, running_var) as soon as model.train() is called; the updates stop only in model.eval() mode (see the post [pytorch] bn for details). Therefore, at the start of every training epoch, reset the modes of the whole model in one place as below; otherwise problems easily creep in.

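model.eval()             # put the whole network, including the frozen BN layers, into eval mode
model.stage4_xx.train()  # then switch only the trainable submodules back to train mode
model.pred_xx.train()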
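A minimal sketch of why this matters, assuming a hypothetical `ToyNet` whose `base` (conv + BN) is frozen and whose `head` is trained: with a plain model.train(), the BN running mean moves even though requires_grad is False; with the eval()/train() pattern above it stays put.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    # Hypothetical model: `base` is the frozen part, `head` is trained.
    def __init__(self):
        super().__init__()
        self.base = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4))
        self.head = nn.Conv2d(4, 2, 1)

    def forward(self, x):
        return self.head(self.base(x))


torch.manual_seed(0)
model = ToyNet()
for p in model.base.parameters():
    p.requires_grad = False
bn = model.base[1]
x = torch.randn(2, 3, 8, 8)

# Case 1: plain model.train(); the BN buffers still move.
model.train()
before = bn.running_mean.clone()
model(x)
moved_in_train = not torch.equal(before, bn.running_mean)

# Case 2: everything in eval first, then only the trainable part back in train.
model.eval()
model.head.train()
before = bn.running_mean.clone()
model(x)
moved_after_fix = not torch.equal(before, bn.running_mean)

print(moved_in_train, moved_after_fix)  # True False
```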

Scheme II

To freeze part of a network in PyTorch this way, you generally need to go through the following steps.

Step 1: freeze the base network

Same as Step 1 of Scheme I: load the pretrained state_dict with strict=False, then freeze the loaded parameters with freeze_model.

Step 2: exclude the frozen parameters from the optimizer

Code template:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)

Step 3: freeze BN

(Refer to [pytorch] bn.) Even if requires_grad = False is set for BN in Step 1, BN silently resumes updating its running statistics once model.train() is called (updates stop only in model.eval() mode). Therefore BN has to be frozen more thoroughly:

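  • Fix the momentum: momentum=0.0 (the running statistics are then never updated)
  • Turn off track_running_stats: track_running_stats=False (BN then keeps no running statistics and normalizes with the statistics of the current batch, in both train and eval mode)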

For example, change:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

to:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)
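A quick sketch of what the two settings do to a standalone BatchNorm2d layer, using toy shapes and random input: momentum=0.0 leaves the running statistics at their initial values even in train mode, while track_running_stats=False removes them, so even eval() normalizes with the batch statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5) * 2 + 1  # batch with non-zero mean

# momentum=0.0: running = (1 - 0) * running + 0 * batch_stat, so nothing moves.
bn_momentum0 = nn.BatchNorm2d(3, eps=1e-5, momentum=0.0)
bn_momentum0.train()
bn_momentum0(x)
print(bn_momentum0.running_mean)         # tensor([0., 0., 0.])
print(bn_momentum0.num_batches_tracked)  # the counter still increments: tensor(1)

# track_running_stats=False: no running buffers at all.
bn_no_stats = nn.BatchNorm2d(3, eps=1e-5, momentum=0.0, track_running_stats=False)
print(bn_no_stats.running_mean)          # None
bn_no_stats.eval()
out = bn_no_stats(x)                     # still normalized with this batch's statistics
print(out.mean(dim=(0, 2, 3)))           # approximately zero per channel
```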

However, track_running_stats=False has a side effect: every affected BN layer loses three key-value pairs in its state_dict (for each layer, the keys are xx.xx.bn.running_mean, xx.xx.bn.running_var and xx.xx.bn.num_batches_tracked).
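The lost keys are easy to confirm by comparing two state_dicts. The sketch below assumes a hypothetical conv + BN block built once with the default BN and once with the frozen settings.

```python
import torch.nn as nn

normal = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8, eps=1e-5))
frozen = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                       nn.BatchNorm2d(8, eps=1e-5, momentum=0.0, track_running_stats=False))

# Keys present in the normal definition but absent from the frozen one.
lost = sorted(set(normal.state_dict()) - set(frozen.state_dict()))
print(lost)  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```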

Step 4: normal training

During training, remember to check regularly that the frozen part really stays unchanged:

  • For example, every time you run evaluation, also check the prediction accuracy of the frozen part.
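Besides checking accuracy, one can compare the frozen part's state_dict against a snapshot taken before training. The sketch below uses a hypothetical helper pair (`snapshot`, `changed_keys`) on a toy model; because state_dict includes buffers, it catches both weight changes and BN running-statistics drift.

```python
import torch
import torch.nn as nn


def snapshot(module):
    # Copy every parameter and buffer (e.g. BN running stats) of `module`.
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def changed_keys(module, reference):
    # Names of entries that differ from the snapshot.
    current = module.state_dict()
    return [k for k, v in reference.items() if not torch.equal(v, current[k])]


torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU(), nn.Conv2d(4, 2, 1))
base = model[:2]  # hypothetical frozen part: conv + BN (shares modules with `model`)
for p in base.parameters():
    p.requires_grad = False
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
ref = snapshot(base)


def train_step(freeze_bn):
    model.train()
    if freeze_bn:
        base.eval()
    loss = model(torch.randn(2, 3, 8, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


train_step(freeze_bn=True)
after_frozen = changed_keys(base, ref)
train_step(freeze_bn=False)
after_unfrozen = changed_keys(base, ref)
print(after_frozen)    # []
print(after_unfrozen)  # ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```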

Step 5: post-processing

4.1 Re-enable track_running_stats

For example, change:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

to:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)

Each previously affected BN layer now gets its three key-value pairs back in state_dict, but they hold only default values (running_mean = 0, running_var = 1, num_batches_tracked = 0) and still have to be filled in.
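For reference, this is what a freshly constructed BN layer (toy size of 4 channels) holds in those recovered keys: PyTorch's default initialization, not the pretrained statistics.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4, eps=1e-5, momentum=0.0)
sd = bn.state_dict()
print(sd['running_mean'])         # tensor([0., 0., 0., 0.])
print(sd['running_var'])          # tensor([1., 1., 1., 1.])
print(sd['num_batches_tracked'])  # tensor(0)
```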

Note:

  • Although training uses the frozen network definition, for offline testing you must faithfully switch back to the normal (unfrozen) definition. Otherwise the results will not only be inconsistent, but both the frozen and the non-frozen tasks will perform worse!

4.2 Recover the missing values

To undo the side effect of track_running_stats=False, the final model is assembled from two state_dicts: the "original state_dict" and the "trained state_dict". The former supplies the values missing from the latter.

import torch

# Original state_dict:
origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
# Trained state_dict:
new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))

# Fill in the key-value pairs missing from the trained state_dict using the original one:
final_dict = new_state_dict.copy()
for (key, val) in origin_state_dict.items():
    if key not in final_dict:
        final_dict[key] = val

# When loading the merged state_dict, you can use strict=True:
model.load_state_dict(final_dict, strict=True)
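An end-to-end sketch of Step 5 under stated assumptions: a hypothetical three-layer network `build()` where only the BN flag differs between the training and testing definitions, fake "pretrained" BN statistics (0.5 and 2.0), a fake "trained" head, and an in-memory buffer instead of files on disk.

```python
import io

import torch
import torch.nn as nn


def build(track_running_stats=True):
    # Hypothetical network; only the BN flag differs between the two definitions.
    return nn.Sequential(
        nn.Conv2d(3, 4, 3, padding=1),
        nn.BatchNorm2d(4, eps=1e-5, momentum=0.0, track_running_stats=track_running_stats),
        nn.Conv2d(4, 2, 1),
    )


def merge_state_dicts(new_state_dict, origin_state_dict):
    # Keep every trained entry; take only the keys that training dropped.
    final_dict = new_state_dict.copy()
    for key, val in origin_state_dict.items():
        if key not in final_dict:
            final_dict[key] = val
    return final_dict


# "Original" checkpoint with (fake) pretrained BN statistics.
origin = build()
origin[1].running_mean.fill_(0.5)
origin[1].running_var.fill_(2.0)
origin_state_dict = origin.state_dict()

# Network as defined during frozen training; pretend its head was trained.
trained = build(track_running_stats=False)
trained.load_state_dict(origin_state_dict, strict=False)
with torch.no_grad():
    trained[2].weight.add_(1.0)
new_state_dict = trained.state_dict()  # lacks the three BN buffer keys

# Back to the normal definition for testing; strict=True now succeeds.
model = build()
model.load_state_dict(merge_state_dicts(new_state_dict, origin_state_dict), strict=True)

# Save the merged weights once so later runs can load a single complete file.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
```

The merge never overwrites a trained entry, so the trained head survives while the BN statistics come from the original checkpoint.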
Then save the model once more; from then on, the resulting model file can be used directly.

Posted by arcarocket on Mon, 06 Dec 2021 11:28:46 -0800