SPL (Self-Paced Learning for Latent Variable Models) code reproduction - PyTorch version

  1. Background: I recently read articles on CL (curriculum learning) and came across the SPL learning strategy. In short, the idea is to have the model learn the data from easy to hard. When reading about SPL, you cannot get around this paper: Self-Paced Learning for Latent Variable Models

  2. Difficulty: the paper was written in 2010, and a PyTorch implementation is hard to find. After searching for two days, I finally found something usable

Topic: theory

Introduce formula 1: basic formula

  1. The parameter optimized by this formula is: w

\[\mathbf{w}_{t+1}=\underset{\mathbf{w} \in \mathbb{R}^{d}}{\operatorname{argmin}}\left(r(\mathbf{w})+\sum_{i=1}^{n} f\left(\mathbf{x}_{i}, \mathbf{y}_{i} ; \mathbf{w}\right)\right) \]

  1. The first term in the brackets, r(w), is the regularization function. For convenience it will not be expanded here (it has no impact on the subsequent results)
  2. The second term in the brackets, $$\sum_{i=1}^{n} f()$$, is the part we usually write ourselves: use w and x to get the prediction, use y and the loss function to compute the loss, then update the parameters (see the sketch right after this list)
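To make formula 1 concrete, here is a minimal sketch of the ordinary training step it describes (the model, optimizer and tensors below are toy placeholders I made up, not taken from the original code): use w and x to get the prediction, compute the loss against y, and update w; weight_decay stands in for r(w).

import torch
import torch.nn as nn
from torch import optim

model = nn.Linear(2, 2)                                   # w: the parameters being optimized
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)  # weight_decay acts as r(w)
x = torch.randn(8, 2)                                     # x_i
y = torch.randint(0, 2, (8,))                             # y_i

optimizer.zero_grad()
pred = model(x)                                           # prediction from w and x
loss = nn.functional.cross_entropy(pred, y)               # average of f(x_i, y_i; w) over the batch
loss.backward()
optimizer.step()                                          # update w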

Introduce formula 2: basic formula + SPL

  1. Introduce the parameter v to control learning; each v_i takes the value 0 or 1, marking the sample as hard (0) or easy (1)
  2. The parameters optimized by this formula are: w and v

\[\left(\mathbf{w}_{t+1}, \mathbf{v}_{t+1}\right)=\underset{\mathbf{w} \in \mathbb{R}^{d}, \mathbf{v} \in\{0,1\}^{n}}{\operatorname{argmin}}\left(r(\mathbf{w})+\sum_{i=1}^{n} v_{i} f\left(\mathbf{x}_{i}, \mathbf{y}_{i} ; \mathbf{w}\right)-\frac{1}{K} \sum_{i=1}^{n} v_{i}\right) \]

  1. The third term in the brackets, $$-\frac{1}{K} \sum_{i=1}^{n} v_{i}$$, is the self-paced regularization term, and it has many variants; see reference [0]. In the material I found, the regularizer is replaced by the hard version, $$-\lambda \sum_{i=1}^{n} v_{i}$$ (i.e. $$\lambda$$ plays the role of $$\frac{1}{K}$$)

Introduce formula 3: basic formula + SPL (hard variant)

  1. Based on formula 2, with the regularization term modified
  2. The parameters optimized by this formula are: w and v

\[L=r(w)+\sum_{i=1}^{n} v_{i} f\left(x_{i}, y_{i} ; w\right)-\lambda \sum_{i=1}^{n} v_{i} \]

  1. The first term in the brackets, r(w), is the regularization function. For convenience it will not be expanded here (it has no impact on the subsequent results)
  2. The second term in the brackets, $$\sum_{i=1}^{n} v_{i} f()$$, now includes v_i. If v_i is 0, the sample is considered hard and its term becomes 0, so it has no effect on the update and the hard sample is not learned; if v_i is 1, the opposite holds
  3. In the third term, $$-\lambda \sum_{i=1}^{n} v_{i}$$, the threshold $$\lambda$$ is compared against each sample's loss to decide whether the sample is easy. The code is as follows
    def spl_loss(super_loss, lambda_a):
        # If a sample's loss < lambda --> v = 1, the sample is easy
        # Otherwise --> v = 0, the sample is hard and is not selected
        v = super_loss < lambda_a
        return v.int()
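The comparison in the code follows directly from minimizing L over each v_i with w fixed: v_i appears in L only through $$v_{i} f\left(x_{i}, y_{i} ; w\right)-\lambda v_{i}=v_{i}\left(f\left(x_{i}, y_{i} ; w\right)-\lambda\right)$$, so the best choice over v_i in {0, 1} is

\[v_{i}^{*}=\left\{\begin{array}{ll}1, & f\left(x_{i}, y_{i} ; w\right)<\lambda \\ 0, & \text { otherwise }\end{array}\right. \]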

And $$\lambda$$ needs to grow as the epochs progress: as the epoch count increases, more samples are selected. The code is as follows

def increase_threshold(lambda_a, growing_factor):
    lambda_a *= growing_factor
    return lambda_a
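For example, with the starting values used in the class below (threshold 0.1, growing factor 1.35), the threshold grows roughly as follows, so more and more samples pass the comparison in spl_loss:

lambda_a = 0.1
for epoch in range(5):
    lambda_a = increase_threshold(lambda_a, growing_factor=1.35)
    print(epoch, lambda_a)   # ~0.135, ~0.182, ~0.246, ~0.332, ~0.448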

Topic: code. Only the core structure is shown here; the complete runnable code is in reference [1]

SPL-LOSS part

import torch
from torch import Tensor
import torch.nn as nn


class SPLLoss(nn.NLLLoss):
    def __init__(self, *args, n_samples=0, **kwargs):
        super(SPLLoss, self).__init__(*args, **kwargs)
        self.threshold = 0.1
        self.growing_factor = 1.35
        self.v = torch.zeros(n_samples).int()

    def forward(self, input: Tensor, target: Tensor, index: Tensor) -> Tensor:
        # Per-sample loss (no reduction), so each sample can be selected individually
        super_loss = nn.functional.nll_loss(input, target, reduction="none")
        v = self.spl_loss(super_loss)
        # Remember the selection for each sample by its dataset index
        self.v[index] = v
        # Unselected samples (v = 0) contribute nothing to the training loss
        return (super_loss * v).mean()

    # Increase the threshold each epoch so that more samples are selected as training progresses
    def increase_threshold(self):
        self.threshold *= self.growing_factor

    def spl_loss(self, super_loss):
        # If a sample's loss < threshold --> v = 1, the sample is easy
        # Otherwise --> v = 0, the sample is hard and is not selected
        v = super_loss < self.threshold
        return v.int()
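A quick sanity check of the class (the tensors below are toy values I made up, not from the original repository): samples whose per-sample NLL loss is below the current threshold get v = 1 and contribute to the returned loss; the rest are masked out.

criterion = SPLLoss(n_samples=4)
log_probs = torch.log_softmax(torch.randn(4, 2), dim=1)   # model outputs for 4 samples (log-probabilities)
targets = torch.tensor([0, 1, 0, 1])
indices = torch.tensor([0, 1, 2, 3])                       # dataset indices of these samples
loss = criterion(log_probs, targets, indices)
print(criterion.v)   # 0/1 selection per sample
print(loss)          # average of the masked losses (unselected samples contribute 0)

Note that with the initial threshold of 0.1 it is quite possible that no sample is selected at first, in which case the loss is 0; the model only starts fitting more data once increase_threshold has been called a few times.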

train part

import tqdm
from torch import optim

# Model, get_dataloader, plot and camera are defined in the full code in reference [1]
def train():
    model = Model(2, 2)
    dataloader = get_dataloader()
    criterion = SPLLoss(n_samples=len(dataloader.dataset))
    optimizer = optim.Adam(model.parameters())

    for epoch in range(10):
        for index, data, target in tqdm.tqdm(dataloader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target, index)
            loss.backward()
            optimizer.step()
        # After each epoch, raise the threshold so that more samples are selected
        criterion.increase_threshold()
        plot(dataloader.dataset, model, criterion)

    animation = camera.animate()
    animation.save("plot.gif")
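The loop unpacks index, data, target from the dataloader, so get_dataloader must wrap a dataset that returns each sample's index alongside the data (SPLLoss uses that index to store v). The real version is in reference [1]; a minimal sketch of such a dataset might look like this (IndexedDataset is my own name for it, not from the original code):

from torch.utils.data import Dataset, DataLoader

class IndexedDataset(Dataset):
    # Hypothetical wrapper: yields (index, x, y) so SPLLoss can record v[index]
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return i, self.x[i], self.y[i]

# e.g. DataLoader(IndexedDataset(torch.randn(100, 2), torch.randint(0, 2, (100,))), batch_size=16, shuffle=True)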

References
[0]: Shu J, Meng D Y, Xu Z B. Meta self-paced learning (in Chinese). Sci Sin Inform, 2020, 50: 781-793, doi: 10.1360/SSI-2020-0005
[1]: GitHub
