[PyTorch] Freezing part of a network

Keywords: Python Pytorch Deep Learning

Preface

Of the two schemes below, Scheme One is the best: the most efficient and the most concise.

Scheme One

Step 1: Freeze the base network

Code template:

import torch

# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Load it into the model (remember strict=False, since the new layers are not in it):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# Freeze the base network:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
where freeze_model is defined as follows:
def freeze_model(model, to_freeze_dict, keep_step=None):
    # Freeze every parameter whose name appears in the pretrained state_dict
    # (keep_step is not used).
    for name, param in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False

    # Optionally print how many parameters are frozen (can be ignored):
    # frozen_num, trainable_num = 0, 0
    # for name, param in model.named_parameters():
    #     if not param.requires_grad:
    #         frozen_num += 1
    #     else:
    #         trainable_num += 1
    # print('\n Total {} params, {} not frozen \n'.format(frozen_num + trainable_num, trainable_num))

    return model
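To see Step 1 end to end, here is a minimal sketch on a toy model. The `ToyNet` class with its `backbone` (pretrained, to be frozen) and `head` (new, trainable) is hypothetical, and the "checkpoint" is built in memory instead of being read with torch.load. It loads the partial state_dict with strict=False and freezes exactly the loaded parameters, using a trimmed copy of freeze_model.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    # Hypothetical model: "backbone" is pretrained and frozen, "head" is new and trainable.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


def freeze_model(model, to_freeze_dict):
    # Freeze every parameter whose name is a key of the pretrained state_dict.
    for name, param in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False
    return model


# Pretend the pretrained checkpoint only contains the backbone weights.
pretrained = ToyNet()
pre_state_dict = {k: v for k, v in pretrained.state_dict().items() if k.startswith('backbone.')}

model = ToyNet()
result = model.load_state_dict(pre_state_dict, strict=False)  # head keys are reported as missing
model = freeze_model(model, pre_state_dict)

frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(frozen)     # ['backbone.weight', 'backbone.bias']
print(trainable)  # ['head.weight', 'head.bias']
```

The value returned by load_state_dict is worth printing in real code too: with strict=False, a name mismatch is not an error, so an empty load would otherwise go unnoticed.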

Note:

  • If the pretrained model was trained wrapped as model = nn.DataParallel(model), every parameter name in its state_dict carries a "module." prefix.
  • As a result, its names do not match those of a single-GPU model: with strict=False nothing would be loaded (and nothing frozen), without any error. The following statement:
# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
has to be changed to:
# Get the state_dict of the part to be frozen and strip the leading "module." prefix:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
pre_state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in pre_state_dict.items()}
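A small sketch of the prefix stripping, using plain dicts so it runs without a checkpoint. The helper name `strip_module_prefix` and the key `module.submodule.bias` are hypothetical. It shows why removing only the leading "module." is safer than str.replace, which removes every occurrence and can corrupt names of submodules that happen to contain "module.".

```python
def strip_module_prefix(state_dict, prefix='module.'):
    # Remove the prefix added by nn.DataParallel, but only at the start of each key.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Hypothetical keys as saved from a model wrapped in nn.DataParallel.
dp_keys = {'module.backbone.weight': 1, 'module.submodule.bias': 2}

print(strip_module_prefix(dp_keys))
# {'backbone.weight': 1, 'submodule.bias': 2}

# For comparison: str.replace removes *every* occurrence, which breaks 'submodule.'.
print({k.replace('module.', ''): v for k, v in dp_keys.items()})
# {'backbone.weight': 1, 'subbias': 2}
```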

Step 2. Pass only the trainable parameters to the optimizer

Code template:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)
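As a quick sanity check of Step 2, the sketch below builds a hypothetical two-layer model, freezes the first layer, and takes one optimizer step. It uses lr=0.1 instead of the template's lr=1 purely for the demo. Only the head's two tensors reach the optimizer, and the frozen weights come out bit-for-bit unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model: index 0 plays the frozen "backbone", index 1 the trainable head.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False

# Only parameters with requires_grad=True are handed to the optimizer.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=0.1, momentum=0.9, weight_decay=1e-4)

frozen_before = model[0].weight.detach().clone()
head_before = model[1].weight.detach().clone()

x, y = torch.randn(16, 4), torch.randn(16, 2)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

n_opt_params = sum(len(g['params']) for g in optimizer.param_groups)
print(n_opt_params)                                 # 2 (head weight and bias)
print(torch.equal(frozen_before, model[0].weight))  # True
print(torch.equal(head_before, model[1].weight))    # False
```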

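Step 3. Keep the frozen part in .eval() mode during training

Reason: even with requires_grad = False on its parameters, a BN layer still updates its running statistics (running_mean and running_var) on every forward pass once model.train() has been called; only model.eval() mode stops these updates. (See BN for more details.)
So before each training epoch, set the modes explicitly: put the whole model in eval mode, then switch only the trainable parts back to train mode. Otherwise problems arise easily.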

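model.eval()             # frozen part: BN uses its running statistics and stops updating them
model.stage4_xx.train()  # trainable parts (stage4_xx / pred_xx stand for your own submodules)
model.pred_xx.train()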
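The BN behaviour behind Step 3 can be reproduced in a few lines. In this sketch the `backbone` block is a hypothetical stand-in for the frozen part, while `pred_xx` reuses the placeholder name from the text. Despite requires_grad = False, a forward pass in train mode moves the BN running_mean; after the eval-then-train pattern above, it stays put.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the frozen backbone and the trainable head.
model = nn.Sequential()
model.add_module('backbone', nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4)))
model.add_module('pred_xx', nn.Conv2d(4, 2, 1))

for p in model.backbone.parameters():
    p.requires_grad = False

bn = model.backbone[1]
x = torch.randn(8, 3, 16, 16)

# 1) Plain model.train(): BN still updates running_mean although requires_grad=False.
model.train()
before = bn.running_mean.clone()
model(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# 2) Whole model in eval mode, then only the trainable part back in train mode.
model.eval()
model.pred_xx.train()
before = bn.running_mean.clone()
model(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train)  # True
print(changed_in_eval)   # False
```

Since a typical training loop calls model.train() at the start of each epoch, the eval/train calls have to come after it every time.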

Scheme Two

The freeze operation in PyTorch generally takes five steps.

Step 1: Freeze the base network

Code template:

import torch

# Get the state_dict of the part to be frozen:
pre_state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Load it into the model (remember strict=False, since the new layers are not in it):
model.load_state_dict(pre_state_dict, strict=False)
print('Load model from %s success.' % model_path)

# Freeze the base network:
model = freeze_model(model=model, to_freeze_dict=pre_state_dict)
where freeze_model is defined as follows:
def freeze_model(model, to_freeze_dict, keep_step=None):
    # Freeze every parameter whose name appears in the pretrained state_dict
    # (keep_step is not used).
    for name, param in model.named_parameters():
        if name in to_freeze_dict:
            param.requires_grad = False

    # Optionally print how many parameters are frozen (can be ignored):
    # frozen_num, trainable_num = 0, 0
    # for name, param in model.named_parameters():
    #     if not param.requires_grad:
    #         frozen_num += 1
    #     else:
    #         trainable_num += 1
    # print('\n Total {} params, {} not frozen \n'.format(frozen_num + trainable_num, trainable_num))

    return model

Step 2. Pass only the trainable parameters to the optimizer

Code template:

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1, momentum=0.9, weight_decay=1e-4)

Step 3. Freeze BN

(See "bn".) Even if Step 1 sets requires_grad = False on the BN parameters, BN still updates its running statistics as soon as model.train() is called, and only stops doing so in model.eval() mode.
Therefore, BN needs to be frozen more thoroughly:

  • Set the momentum to zero: momentum=0.0
  • Turn off running-statistics tracking: track_running_stats=False
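The effect of each setting can be checked on a standalone BN layer (the tensor shapes are arbitrary). With momentum=0.0, PyTorch updates the running statistics as (1 - 0) * old + 0 * batch, so they stay at their current values even in train mode. With track_running_stats=False there are no running buffers at all, and BN normalizes with the current batch statistics in both train and eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4, 5, 5) * 3 + 2  # per-channel data far from mean 0 / variance 1

# momentum=0.0: running stats keep their initial values even in train mode.
bn_m0 = nn.BatchNorm2d(4, eps=1e-5, momentum=0.0)
bn_m0.train()
bn_m0(x)
print(torch.equal(bn_m0.running_mean, torch.zeros(4)))  # True
print(torch.equal(bn_m0.running_var, torch.ones(4)))    # True

# track_running_stats=False: no running buffers; batch statistics are used even in eval().
bn_nt = nn.BatchNorm2d(4, eps=1e-5, momentum=0.0, track_running_stats=False)
bn_nt.eval()
out = bn_nt(x)
print(bn_nt.running_mean is None)                         # True
print(out.mean(dim=(0, 2, 3)).abs().max().item() < 1e-4)  # True: normalized per batch
```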

For example:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

Modify to:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

However, track_running_stats=False has a side effect: every affected BN loses three key-value pairs from the state_dict, namely xx.xx.bn.running_mean, xx.xx.bn.running_var and xx.xx.bn.num_batches_tracked.
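The missing keys are easy to list by comparing the state_dict of two standalone BN layers. Inside a real model the same keys would carry the module path prefix (the xx.xx.bn. part above).

```python
import torch.nn as nn

bn_tracked = nn.BatchNorm2d(4, eps=1e-5)
bn_untracked = nn.BatchNorm2d(4, eps=1e-5, momentum=0.0, track_running_stats=False)

print(sorted(bn_tracked.state_dict().keys()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
print(sorted(bn_untracked.state_dict().keys()))
# ['bias', 'weight']

# Keys that disappear when tracking is turned off:
lost = set(bn_tracked.state_dict()) - set(bn_untracked.state_dict())
print(sorted(lost))
# ['num_batches_tracked', 'running_mean', 'running_var']
```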

Step 4. Normal training

During training, regularly check that the frozen part really stays unchanged:

  • For example, at every evaluation, check the prediction accuracy of the frozen part.
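Besides checking accuracy, the frozen weights can be compared directly against a snapshot. The sketch below assumes hypothetical helpers `snapshot_frozen` and `frozen_part_unchanged`. It covers parameters only: BN running statistics are buffers and would need a similar check over named_buffers().

```python
import torch
import torch.nn as nn


def snapshot_frozen(model):
    # Copy every frozen parameter (requires_grad=False) so it can be compared later.
    return {n: p.detach().clone() for n, p in model.named_parameters() if not p.requires_grad}


def frozen_part_unchanged(model, snapshot):
    # True if every frozen parameter still equals its snapshot exactly.
    params = dict(model.named_parameters())
    return all(torch.equal(params[n], v) for n, v in snapshot.items())


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False

snap = snapshot_frozen(model)
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
for _ in range(5):
    loss = model(torch.randn(16, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(frozen_part_unchanged(model, snap))  # True

# Simulate an accidental update of the frozen part:
with torch.no_grad():
    model[0].weight.add_(1.0)
print(frozen_part_unchanged(model, snap))  # False
```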

Step 5. Post-processing

5.1 Re-enable track_running_stats

For example:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0, track_running_stats=False)

Modify to:

self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.0)

At this point, each previously affected BN regains the three missing key-value pairs in its state_dict, but they do not hold trained values yet and still have to be filled in.
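Loading the trained weights into the redefined network shows the gap concretely. The small conv + BN block here is hypothetical. With strict=False, load_state_dict reports the BN running statistics as missing keys, and the buffers keep their default initialization (running_mean zeros, running_var ones) rather than any trained values.

```python
import torch
import torch.nn as nn

# Hypothetical block as used during training, with tracking disabled.
train_net = nn.Sequential(nn.Conv2d(3, 4, 3),
                          nn.BatchNorm2d(4, momentum=0.0, track_running_stats=False))
trained_state = train_net.state_dict()

# The same block redefined with tracking re-enabled.
deploy_net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4, momentum=0.0))
result = deploy_net.load_state_dict(trained_state, strict=False)

print(sorted(result.missing_keys))  # includes '1.running_mean' and '1.running_var'
print(torch.equal(deploy_net[1].running_mean, torch.zeros(4)))  # True: default, not trained
print(torch.equal(deploy_net[1].running_var, torch.ones(4)))    # True: default, not trained
```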

Note:

  • Train with the frozen network definition, but for offline testing faithfully switch back to the unfrozen definition. Otherwise the results will not only be inconsistent but also worse on both the frozen and the non-frozen tasks!

5.2 Recovering the missing values

To undo the side effect of track_running_stats=False, the final model combines the original state_dict with the trained state_dict: the original supplies the values that are missing from the trained one.

import torch

# Original state_dict:
origin_state_dict = torch.load(origin_model_path, map_location=torch.device('cpu'))
# Trained state_dict:
new_state_dict = torch.load(new_model_path, map_location=torch.device('cpu'))

# Fill in the key-value pairs missing from the trained state_dict using the original one:
final_dict = new_state_dict.copy()
for key, val in origin_state_dict.items():
    if key not in final_dict:
        final_dict[key] = val

# Load the merged state_dict; it must now pass strict=True:
model.load_state_dict(final_dict, strict=True)
Then save the model once more; the result is a model file that can be used directly.
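The whole recovery step can be rehearsed in memory before touching real checkpoint files. In this sketch `make_net` and its layers are hypothetical, the "pretrained" BN statistics are filled in by hand (0.5 and 2.0), and the state_dicts are kept in memory instead of going through torch.load. The merged dict loads with strict=True, the BN running statistics come from the original model, and the head weights come from the trained one.

```python
import torch
import torch.nn as nn


def make_net(track=True):
    # Hypothetical network: conv + BN (frozen part) followed by a linear head (trained part).
    bn = nn.BatchNorm2d(4, momentum=0.0, track_running_stats=track)
    return nn.Sequential(nn.Conv2d(3, 4, 3), bn, nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 2))


torch.manual_seed(0)

# "Original" model: pretend its BN running stats were learned during pretraining.
origin_net = make_net(track=True)
with torch.no_grad():
    origin_net[1].running_mean.fill_(0.5)
    origin_net[1].running_var.fill_(2.0)
origin_state_dict = origin_net.state_dict()

# "Trained" model: built with track_running_stats=False, so its BN buffers are absent.
trained_net = make_net(track=False)
trained_net.load_state_dict(origin_state_dict, strict=False)  # extra BN keys are ignored
new_state_dict = trained_net.state_dict()

# Supplement the trained state_dict with the keys it is missing.
final_dict = new_state_dict.copy()
for key, val in origin_state_dict.items():
    if key not in final_dict:
        final_dict[key] = val

model = make_net(track=True)  # deployment definition with tracking re-enabled
model.load_state_dict(final_dict, strict=True)

print(sorted(set(final_dict) - set(new_state_dict)))
# ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```

The final save would then be something like torch.save(model.state_dict(), final_model_path), where final_model_path is a placeholder.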

Posted by tauchai83 on Tue, 14 Sep 2021 09:46:30 -0700