- Cause: I recently read some articles on curriculum learning (CL) and came across the SPL (self-paced learning) strategy. In short, the idea is to have the model learn the data from easy to hard. Anyone reading about SPL will inevitably run into this paper: Self-Paced Learning for Latent Variable Models
- Difficulty: the paper was written in 2010, and a PyTorch implementation is hard to find. After searching for two days, I finally found something.
Topic: theoretical part
Introduce Formula 1: the basic formula
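Pieced together from the term-by-term description below (following the notation of the 2010 paper), the formula reads roughly:

$$w^{*} = \arg\min_{w} \left( r(w) + \sum_{i=1}^{n} f(x_{i}, y_{i}; w) \right)$$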
- The parameter optimized by the formula is w
- The first term in the brackets, the function r(w), is a regularization term. For brevity it is not expanded here (it has no impact on what follows)
- The second term in the brackets, $$\sum_{i=1}^{n} f(\cdot)$$, is the part we usually write: use w and x to get y_pred, compute the loss between y_pred and y with the loss function, then update the parameters
Introduce Formula 2: basic formula + SPL
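Again reconstructed from the description below, Formula 2 reads roughly:

$$(w^{*}, v^{*}) = \arg\min_{w,\, v \in \{0,1\}^{n}} \left( r(w) + \sum_{i=1}^{n} v_{i} f(x_{i}, y_{i}; w) - \frac{1}{K} \sum_{i=1}^{n} v_{i} \right)$$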
- Introduce the parameter v to control learning; v takes the value 0 or 1, marking a sample as hard or easy
- The parameters optimized by the formula are w and v
- The third term in the brackets, $$-\frac{1}{K} \sum_{i=1}^{n} v_{i}$$, is a regularization term, and it has many variants; see reference [0]. In the materials I found, it is modified into the hard version, i.e. $$-\lambda \sum_{i=1}^{n} v_{i}$$
Introduce Formula 3: basic formula + SPL (hard variant)
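Reconstructed the same way, Formula 3 reads roughly:

$$(w^{*}, v^{*}) = \arg\min_{w,\, v \in \{0,1\}^{n}} \left( r(w) + \sum_{i=1}^{n} v_{i} f(x_{i}, y_{i}; w) - \lambda \sum_{i=1}^{n} v_{i} \right)$$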
- Based on Formula 2, with the regularization term modified
- The parameters optimized by the formula are w and v
- The first term in the brackets, the function r(w), is a regularization term. For brevity it is not expanded here (it has no impact on what follows)
- The second term in the brackets, $$\sum_{i=1}^{n} v_{i} f(\cdot)$$, adds the factor $$v_{i}$$. If $$v_{i}$$ is 0, the sample is treated as hard and the whole term becomes 0, so the sample has no effect on the update, i.e. hard data is not learned. If $$v_{i}$$ is 1, the opposite holds
- The third term in the brackets is $$-\lambda \sum_{i=1}^{n} v_{i}$$; $$\lambda$$ is compared against each sample's loss to decide whether that sample counts as easy. The code below illustrates this
```python
def spl_loss(super_loss, lambda_a):
    # loss < lambda  -->  v = 1, the sample is easy
    # otherwise      -->  v = 0, the sample is hard
    v = super_loss < lambda_a
    return v.int()
```
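A quick sanity check of the selection rule (the loss values here are made up for illustration):

```python
import torch

losses = torch.tensor([0.05, 0.30, 0.12])
print(spl_loss(losses, lambda_a=0.2))  # tensor([1, 0, 1], dtype=torch.int32)
```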
$$\lambda$$ also needs to grow as the epochs go on, so that more and more samples are selected over time. The code below illustrates this:
```python
def increase_threshold(lambda_a, growing_factor):
    lambda_a *= growing_factor
    return lambda_a
```
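With the starting values used in the SPLLoss class below (threshold 0.1, growing factor 1.35), the threshold grows geometrically, e.g.:

```python
lambda_a = 0.1
for epoch in range(3):
    lambda_a = increase_threshold(lambda_a, growing_factor=1.35)
    print(lambda_a)  # 0.135, then ~0.182, then ~0.246
```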
Topic: code. The structural part shown here is complete on its own; the fully runnable code is in reference [1]
SPL-LOSS part
```python
import torch
from torch import Tensor
import torch.nn as nn


class SPLLoss(nn.NLLLoss):
    def __init__(self, *args, n_samples=0, **kwargs):
        super(SPLLoss, self).__init__(*args, **kwargs)
        self.threshold = 0.1
        self.growing_factor = 1.35
        self.v = torch.zeros(n_samples).int()

    def forward(self, input: Tensor, target: Tensor, index: Tensor) -> Tensor:
        super_loss = nn.functional.nll_loss(input, target, reduction="none")
        v = self.spl_loss(super_loss)
        self.v[index] = v
        return (super_loss * v).mean()

    # Select more samples each epoch by raising the threshold
    def increase_threshold(self):
        self.threshold *= self.growing_factor

    def spl_loss(self, super_loss):
        # loss < threshold --> v = 1, the sample is easy
        # otherwise        --> v = 0, the sample is hard
        v = super_loss < self.threshold
        return v.int()
```
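A minimal smoke test of the class on random data (the shapes, seed, and tensors here are my own assumptions, not from reference [1]):

```python
torch.manual_seed(0)
criterion = SPLLoss(n_samples=8)
log_probs = nn.functional.log_softmax(torch.randn(8, 2), dim=1)
targets = torch.randint(0, 2, (8,))
index = torch.arange(8)
loss = criterion(log_probs, targets, index)
print(loss.item(), criterion.v)  # criterion.v marks which samples count as easy
```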
train part
```python
import tqdm
from torch import optim


def train():
    # Model, get_dataloader, plot and camera are defined in the full script in reference [1]
    model = Model(2, 2)
    dataloader = get_dataloader()
    criterion = SPLLoss(n_samples=len(dataloader.dataset))
    optimizer = optim.Adam(model.parameters())

    for epoch in range(10):
        for index, data, target in tqdm.tqdm(dataloader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target, index)
            loss.backward()
            optimizer.step()
        # once per epoch: raise the threshold so more samples are selected
        criterion.increase_threshold()
        plot(dataloader.dataset, model, criterion)

    animation = camera.animate()
    animation.save("plot.gif")
```
References
[0]: Shu J, Meng D Y, Xu Z B. Meta self-paced learning (in Chinese). Sci Sin Inform, 2020, 50: 781–793, doi: 10.1360/SSI-2020-0005
[1]: GitHub