When you define a model in PyTorch, you subclass `nn.Module`: submodules are created in `__init__` and then called in `forward`. Sometimes, however, modules are created in `__init__` that are never used in `forward` and never deleted, and this can hurt the convergence speed of the network. For example:
First case: `self.attention` and `self.decoder` are created in `__init__` but are neither deleted nor used in `forward`. This slows down convergence.
```python
class Bert_Ocr(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.transformer = Transformer(cfg, cfg.attention_layers)
        self.attention = Parallel_Attention(cfg)
        self.decoder = Two_Stage_Decoder(cfg)

    def forward(self, x, mask):
        x1 = self.transformer(x, mask)
        # x_atten = self.attention(x1)
        # glimpses = torch.bmm(x_atten.permute(0, 2, 1), x)
        # res1 = self.decoder(glimpses)
        return x1
```

> [500/300000] valid loss: 0.88969 accuracy: 0.000, norm_ED: 29.00
> [1000/300000] valid loss: 0.30434 accuracy: 39.773, norm_ED: 9.86
> [1500/300000] valid loss: 0.14993 accuracy: 70.455, norm_ED: 4.29
Second case: the unused modules are removed from `__init__`.
```python
class Bert_Ocr(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.transformer = Transformer(cfg, cfg.attention_layers)
        # self.attention = Parallel_Attention(cfg)
        # self.decoder = Two_Stage_Decoder(cfg)

    def forward(self, x, mask):
        x1 = self.transformer(x, mask)
        # x_atten = self.attention(x1)
        # glimpses = torch.bmm(x_atten.permute(0, 2, 1), x)
        # res1 = self.decoder(glimpses)
        return x1
```

> [500/300000] valid loss: 0.46041 accuracy: 27.273, norm_ED: 14.14
> [1000/300000] valid loss: 0.03799 accuracy: 93.182, norm_ED: 0.86
Summary: promptly delete (or never create) modules in `__init__` that are not used in `forward`.
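One quick way to spot this problem: parameters that never take part in `forward` receive no gradient after a backward pass. Below is a minimal, hypothetical sketch (the `Toy` module and its layer names are not from the original post) that counts parameters and flags the ones that were never used:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    """Hypothetical module with one used and one unused submodule."""
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 8)
        self.unused = nn.Linear(8, 8)  # created in __init__ but never called

    def forward(self, x):
        return self.used(x)

model = Toy()
# The unused submodule still registers parameters on the model ...
print(sum(p.numel() for p in model.parameters()))  # counts self.unused too

# ... but after one forward/backward pass its parameters have no gradient.
model(torch.randn(2, 8)).sum().backward()
for name, p in model.named_parameters():
    if p.grad is None:
        print("never used in forward:", name)  # unused.weight, unused.bias
```

Running a check like this before training makes it easy to notice submodules that should be deleted from `__init__`.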