[MAE] Implementation of Masked Autoencoders and visualization of pre-training

Masked Autoencoders Are Scalable Vision Learners

MAE proposes a self-supervised pre-training method that trains the model effectively and improves its performance. This project implements the self-supervised pre-training part and visualizes how the reconstructions change during training.

Network structure

The structure of MAE is relatively simple: an encoder and a decoder, both built from Transformer blocks. The input image is divided into patches, and a fixed proportion of them is masked (75% in the paper). Only the unmasked patches are sent to the encoder, which produces encoded patches. The encoded patches are then combined with mask tokens and sent to the decoder. The decoder's target is the original image, and the loss is computed only on the masked patches.
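To make "the loss is computed only on the masked patches" concrete, here is a minimal NumPy sketch. The function name `masked_mse` and the toy shapes are illustrative assumptions, not part of the project code; the project itself uses Paddle's `F.mse_loss` on the gathered masked patches.

```python
import numpy as np

def masked_mse(pred, target, masked_idx):
    """Mean squared error restricted to the masked patches.

    pred, target: arrays of shape [num_patches, patch_dim]
    masked_idx:   1-D integer array with the positions of the masked patches
    """
    diff = pred[masked_idx] - target[masked_idx]
    return float(np.mean(diff ** 2))

# Toy example: 4 patches with 3 values each; patches 1 and 3 are masked.
target = np.arange(12, dtype=np.float64).reshape(4, 3)
pred = target.copy()
pred[0] += 100.0  # large error on a visible patch: ignored by the loss
pred[1] += 2.0    # error on a masked patch: counted by the loss
masked_idx = np.array([1, 3])

loss = masked_mse(pred, target, masked_idx)
print(loss)
```

The error on the visible patch does not affect the result, which is the behaviour described above: the visible patches are given to the model as input, so only the masked ones are scored.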

Details to note:
1. Masking: after the image is divided into non-overlapping patches, the patches to mask are sampled uniformly at random;
2. Encoder: the encoder operates only on unmasked patches. Position embeddings are added to the patch embeddings;
3. Decoder: the decoder input consists of the encoded patches and the mask tokens. The mask token is a single learnable vector shared by all masked positions. Position embeddings are added to the mask tokens so that they carry location information;
4. Reconstruction target: the decoder predicts the pixel values of each patch of the target image (that is, the original input image), and the loss is computed only on the masked patches;
5. Implementation:
(1) Generate a token for each patch;
(2) Shuffle all tokens, then remove a portion of them according to the masking ratio;
(3) After obtaining the encoded tokens, merge the mask tokens with the encoded tokens. Note that there is no need to unshuffle here; a simple concat is enough;
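The three implementation steps can be sketched with plain NumPy index arithmetic. This is only an illustration under stated assumptions: the token width of 4, the value -1 used as a stand-in mask token, and all variable names are made up for the example, and no encoder is actually run.

```python
import numpy as np

rng = np.random.default_rng(0)

num_patches = 64      # e.g. a 256x256 image cut into 32x32 patches
masking_ratio = 0.5
num_masked = int(masking_ratio * num_patches)

# (1) One token per patch. For the demo, every value of a token is its patch id.
tokens = np.arange(num_patches, dtype=np.float64)[:, None] * np.ones((1, 4))

# (2) Shuffle by argsorting uniform noise, then split into masked / visible.
shuffle = np.argsort(rng.random(num_patches))
masked_idx, visible_idx = shuffle[:num_masked], shuffle[num_masked:]
visible_tokens = tokens[visible_idx]  # only these would be sent to the encoder

# (3) Concatenate mask tokens and (stand-in) encoded tokens without unshuffling.
mask_tokens = np.full((num_masked, 4), -1.0)
decoder_in = np.concatenate([mask_tokens, visible_tokens], axis=0)

# Row j of decoder_in belongs to patch shuffle[j]. The inverse permutation puts
# the rows back in patch order, which matters when the output is drawn as an image.
restore = np.argsort(shuffle)
in_patch_order = decoder_in[restore]
print(in_patch_order[:, 0])
```

Skipping the unshuffle is fine for computing the loss, because predictions and targets are gathered with the same indices. Anything that interprets the decoder output by position, such as rebuilding an image from patches, needs the inverse permutation shown in the last step.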

Contents of the project

This project is trained on the validation set of ImageNet-1K: 40,000 of its 50,000 images are used as training data, and the remaining 10,000 are reserved for validation. Because training is relatively slow, only MAE pre-training is done here, with no fine-tuning. The masking ratio is 0.5, and training runs for only 200 epochs. With so little data and so few epochs the reconstructions are not very good, but the way the MAE output changes during training is still visible.

Output change

The original image is on the left, the masked image is in the middle, and the MAE prediction is on the right.
epoch 1:

epoch 10:

epoch 200:


References: vit-pytorch
Many thanks to teacher Zhu Er for the course (teacher Zhu is awesome): Learning visual Transformer from scratch
Aistudio homepage: https://aistudio.baidu.com/aistudio/personalcenter/thirdview/312316

# Processing data sets
%cd ~/data/data89857/
!tar -xf ILSVRC2012mini.tar
%cd ~/

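# The dataset's list file is malformed. Rewrite train_list.txt in place; run this cell only once
import os

train_file_path = '/home/aistudio/data/data89857/ILSVRC2012mini/train_list.txt'
data = []
with open(train_file_path, 'r') as f:
    lines = f.readlines()
    for line in lines:
        # Drop the directory prefix before the '/' and keep the rest of the line
        _, info = line.split('/')
        data.append(info)  # assumed: this step is missing from the original listing

with open(train_file_path, 'w') as f:
    f.writelines(data)  # assumed: each entry still ends with its own newline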

The ViT implementation is not explained in detail here. Note: because the reconstruction target of MAE is the pixel values of the original image, patch embedding is not done with a convolution. Instead, the image is first divided into patches, and each flattened patch then goes through a linear embedding, so the raw patches remain available as targets.
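The "divide into patches, then embed linearly" step can be checked in isolation with NumPy. The sketch below uses the same reshape and transpose pattern as the Paddle code in this post; the helper names `patchify` and `unpatchify`, the tiny 8x8 image, and the zero weight matrix are assumptions made for the example.

```python
import numpy as np

def patchify(x, ph, pw):
    """[N, C, H, W] -> [N, num_patches, C*ph*pw], patches in row-major order."""
    n, c, h, w = x.shape
    p1, p2 = h // ph, w // pw
    x = x.reshape(n, c, p1, ph, p2, pw).transpose(0, 2, 4, 1, 3, 5)
    return x.reshape(n, p1 * p2, c * ph * pw)

def unpatchify(patches, c, h, w, ph, pw):
    """Inverse of patchify: [N, num_patches, C*ph*pw] -> [N, C, H, W]."""
    n = patches.shape[0]
    p1, p2 = h // ph, w // pw
    x = patches.reshape(n, p1, p2, c, ph, pw).transpose(0, 3, 1, 4, 2, 5)
    return x.reshape(n, c, h, w)

image = np.random.default_rng(0).random((2, 3, 8, 8))
patches = patchify(image, 4, 4)   # raw pixels: the reconstruction target
weight = np.zeros((48, 16))       # stand-in for the linear embedding weight
tokens = patches @ weight         # what the encoder would receive
restored = unpatchify(patches, 3, 8, 8, 4, 4)
print(patches.shape, tokens.shape, np.array_equal(restored, image))
```

Keeping `patches` separate from `tokens` is the point of this design: the same array that feeds the linear layer is later indexed with the masked positions to form the regression target.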

import paddle 
from paddle import nn

class PreNorm(nn.Layer):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x):
        return self.fn(self.norm(x))

class Mlp(nn.Layer):
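    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # GELU and dropout follow the vit-pytorch feed-forward layout
        # (assumed; these lines are not shown in the original listing)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )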

    def forward(self, x):
        x = self.mlp(x)
        return x

class Identity(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class Attention(nn.Layer):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** (-0.5)

        self.attend = nn.Softmax(axis = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias_attr=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
        ) if project_out else Identity()

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, axis=-1)
        q, k, v = map(lambda t: t.reshape([B, N, self.heads, -1]).transpose([0, 2, 1, 3]), qkv)

        dots = paddle.matmul(q, k.transpose([0, 1, 3, 2])) * self.scale
        attn = self.attend(dots)

        out = attn.matmul(v)
        out = out.transpose([0, 2, 1, 3]).flatten(2)
        out = self.to_out(out)
        return out

class Transformer(nn.Layer):
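    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.LayerList()
        for _ in range(depth):
            # Each block is an (attention, feed-forward) pair, both with pre-normalization
            self.layers.append(
                nn.LayerList([
                    PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                    PreNorm(dim, Mlp(dim, mlp_dim, dropout=dropout)),
                ])
            )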

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class PatchEmbedding(nn.Layer):
    def __init__(self, image_size, patch_size, embed_dim=768, in_channels=3):
        super().__init__()
        image_height, image_width = image_size if isinstance(image_size, tuple) else (image_size, image_size) 
        self.patch_height, self.patch_width = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)

        assert image_height % self.patch_height == 0 and image_width % self.patch_width == 0, "Image dimensions must be divisible by the patch size."
        self.p1, self.p2 = (image_height // self.patch_height), (image_width // self.patch_width)
        self.num_patches = (image_height // self.patch_height) * (image_width // self.patch_width)

        self.patch_embed = nn.Linear(in_channels * self.patch_height * self.patch_width, embed_dim)

    def forward(self, x):
        N, C, H, W = x.shape
        patches = x.reshape([N, C, self.p1, self.patch_height, self.p2, self.patch_width]).transpose([0, 2, 4, 1, 3, 5]).reshape([N, self.num_patches, -1])
        x = self.patch_embed(patches)
        x = x.flatten(2)
        return x, patches

class ViT(nn.Layer):
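    def __init__(self, image_size, patch_size, num_classes, embed_dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., embed_dropout=0.):
        # Signature reconstructed from the names used below and at the call sites;
        # the default values are assumptions modeled on vit-pytorch.
        super().__init__()
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling).'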
        self.embed_dim = embed_dim
        self.patch_embedding = PatchEmbedding(image_size, patch_size, embed_dim=embed_dim, in_channels=channels)
        self.num_patches = self.patch_embedding.num_patches
        self.pos_embedding = self.create_parameter(shape=[1, self.num_patches + 1, embed_dim], default_initializer=nn.initializer.KaimingNormal(0.02))
        self.cls_token = self.create_parameter(shape=[1, 1, embed_dim])
        self.dropout = nn.Dropout(embed_dropout)

        self.transformer = Transformer(embed_dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = Identity()

        self.mlp_head = nn.Sequential(
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        x, patches = self.patch_embedding(x)

        B, N, _ = x.shape
        cls_tokens = paddle.tile(self.cls_token, [B, 1, 1])
        x = paddle.concat([cls_tokens, x], axis=1)
        x += self.pos_embedding[:, :(N + 1)]
        x = self.dropout(x)

        x = self.transformer(x)
        x = x.mean(axis=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        x = self.mlp_head(x)

        return x

# if __name__ == '__main__':
#     model = ViT(image_size=(256,256), 
#         patch_size=(32,32), 
#         num_classes=1000, 
#         embed_dim=1024,
#         heads=8,
#         depth=6, 
#         mlp_dim=2048, )
#     x = paddle.randn([2, 3, 256, 256])
#     y = model(x)
#     print(x.shape, y.shape)
#     paddle.summary(model, (4, 3, 256, 256))


The encoder of MAE is ViT, and the decoder is a transformer model.

import paddle
from paddle import nn
import paddle.nn.functional as F

class MAE(nn.Layer):
    def __init__(self, encoder, decoder_dim, masking_ratio=0.75, decoder_depth=1, decoder_heads=8, decoder_dim_head=64):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be in range (0, 1), but got {}.'.format(masking_ratio)
        self.masking_ratio = masking_ratio
        self.encoder = encoder

        patch_dim = self.encoder.patch_embedding.patch_embed.weight.shape[0] # dim of each patch after division

        self.enc_to_dec = nn.Linear(encoder.embed_dim, decoder_dim) if encoder.embed_dim != decoder_dim else Identity()
        self.mask_token = self.create_parameter(shape=(1, 1, decoder_dim))  # learnable mask token shared by all masked positions
        self.decoder = Transformer(dim=decoder_dim, depth=decoder_depth, heads=decoder_heads, dim_head=decoder_dim_head, mlp_dim=decoder_dim*4) # decoder
        self.decoder_pos_emb = nn.Embedding(encoder.num_patches, decoder_dim) # decoder position embedding
        self.to_pixels = nn.Linear(decoder_dim, patch_dim)

    def forward(self, x):
        tokens, patches = self.encoder.patch_embedding(x) # patches: raw pixel patches of the input image, used as the reconstruction target
        batch, num_patches, _ = tokens.shape # batch_size, num_patches, _
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # mask part of patches
        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = paddle.rand(shape=[batch, num_patches]).argsort(axis=-1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # unmasked tokens to be encoded
        batch_range = paddle.arange(batch)[:, None]
        tokens = tokens[batch_range, unmasked_indices]
        # masked_patches
        masked_patches = patches[batch_range, masked_indices] # Losses are calculated only in masked patches
        # transformer
        encoded_tokens = self.encoder.transformer(tokens)
        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # decoder embed
        mask_tokens = paddle.tile(self.mask_token, [batch, num_masked, 1]) # learned mask token, repeated for every masked position
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices) # add decoder position embedding

        decoder_tokens = paddle.concat([mask_tokens, decoder_tokens], axis=1) # unshuffle is not required
        decoded_tokens = self.decoder(decoder_tokens)
        if self.training:
            mask_tokens = decoded_tokens[:, :num_masked]
            pred = self.to_pixels(mask_tokens) # N, num_masked, patch_dim
            loss = F.mse_loss(pred, masked_patches)
            return loss
        else:
            image = patches.clone() # copy of the input patches, used to show the masked image
            image.stop_gradient = True
            image[batch_range, masked_indices] = 0 # zero out the masked patches
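            pred = self.to_pixels(decoded_tokens)
            # decoded_tokens are ordered [masked, unmasked]; put them back in patch order for visualization
            pred = pred[batch_range, rand_indices.argsort(axis=-1)]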
            return pred, image

# if __name__ == '__main__':
#     encoder = ViT(image_size=256, 
#         patch_size=32, 
#         num_classes=1000, 
#         embed_dim=1024,
#         heads=8,
#         depth=6, 
#         mlp_dim=2048)
#     model = MAE(encoder, masking_ratio=0.75, decoder_dim=512, decoder_depth=6)
#     x = paddle.randn([4, 3, 256, 256])
#     y = model(x)
#     print(x.shape, y.shape)
#     paddle.summary(model, (4, 3, 256, 256))


# Build dataset
from paddle.io import Dataset, DataLoader
import paddle.vision.transforms as T
import cv2
import os

class ImageNetDataset(Dataset):
    def __init__(self, data_dir, info_txt, mode='train', transforms=None):
        self.data_dir = data_dir
        self.image_paths, self.labels = self.get_info(info_txt)
        self.mode = mode
        self.transforms = transforms

    def get_info(self, file_path):
        paths = []
        labels = []
        with open(file_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                image_name, label = line.strip().split(' ')
                paths.append(os.path.join(self.data_dir, image_name))
                labels.append(int(label))  # assumed: the label is an integer class index
        return paths, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        image = cv2.imread(image_path)
        if self.transforms:
            image = self.transforms(image)
        if self.mode == 'train':
            return image, label
        else:
            return image

# mae_train_trans = T.Compose(
#     [
#         T.Resize((256, 256)),
#         T.RandomHorizontalFlip(),
#         T.RandomVerticalFlip(),
#         T.Transpose([2, 0, 1]),
#     ]
# )

# if __name__ == '__main__':
#     dataset = ImageNetDataset('/home/aistudio/data/data89857/ILSVRC2012mini/train', '/home/aistudio/data/data89857/ILSVRC2012mini/train_list.txt', mode='val', transforms=mae_train_trans)
#     print(len(dataset))
#     image = dataset[0]
#     import matplotlib.pyplot as plt
#     plt.imshow(image)
#     plt.show()

Pre-training MAE

# Auxiliary class
class AverageMeter:
    def __init__(self):
        self.val = 0.
        self.count = 0.

    def update(self, value, n=1):
        self.val += value
        self.count += n

    def reset(self):
        self.val = 0.
        self.count = 0.

    def __call__(self):
        return self.val / self.count
# Set relevant parameters
import time

epoches = 2000
batch_size = 256
learning_rate = 0.00001
grad_clip_value = 10

# encoder param
patch_size = (32, 32) 
image_size = (256, 256)
num_classes = 1000
encoder_embed_dim = 1024
encoder_heads = 8
encoder_depth = 6
encoder_mlp_dim = 2048

# decoder params
masking_ratio = 0.5
decoder_dim = 512
decoder_depth = 6

mae_train_trans = T.Compose(
    [
        T.Resize((256, 256)),
        T.Transpose([2, 0, 1]),
    ]
)
# mode='val' because labels are not required for pre-training (the label could also be returned)
mae_dataset = ImageNetDataset('/home/aistudio/data/data89857/ILSVRC2012mini/train', '/home/aistudio/data/data89857/ILSVRC2012mini/train_list.txt', mode='val', transforms=mae_train_trans)
mae_dataloader = DataLoader(
    mae_dataset,
    batch_size=batch_size,
    shuffle=True,  # assumed; the original arguments are not shown
)

# MAE model
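encoder = ViT(image_size=image_size,
              patch_size=patch_size,
              num_classes=num_classes,
              embed_dim=encoder_embed_dim,
              heads=encoder_heads,
              depth=encoder_depth,
              mlp_dim=encoder_mlp_dim)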
model = MAE(encoder, masking_ratio=masking_ratio, decoder_dim=decoder_dim, decoder_depth=decoder_depth)
# paddle.summary(model, (4, 3, 256, 256))
clip = paddle.nn.ClipGradByValue(min=-grad_clip_value, max=grad_clip_value)
optimizer = paddle.optimizer.Momentum(learning_rate=learning_rate, parameters=model.parameters(), grad_clip=clip)

# Test the function and visualize the training process with a picture
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def reconstruct(x, image_size, patch_size):
    """Reconstruct [batch_size, num_patches, patch_dim] -> [batch_size, channels, h, w]"""
    B, N, _ = x.shape  # batch_size, num_patches, dim

    p1, p2 = image_size[0] // patch_size[0], image_size[1] // patch_size[1] 
    x = x.reshape([B, p1, p2, -1, patch_size[0], patch_size[1]]).transpose([0, 3, 1, 4, 2, 5]).reshape([B, -1, image_size[0], image_size[1]])
    return x

def test(model):
    """Use the model to predict one picture, to see how the prediction changes during training."""
    model.eval()  # forward() returns (pred, masked image) only in eval mode
    image_path = '/home/aistudio/data/data89857/ILSVRC2012mini/val/ILSVRC2012_val_00040043.JPEG'
    source_image = cv2.imread(image_path)
    trans = T.Compose(
        [
            T.Resize((256, 256)),
            T.Transpose([2, 0, 1]),
        ]
    )
    source_image = trans(source_image)
    image = paddle.to_tensor(source_image, dtype='float32').unsqueeze(0)
    pred, masked_img = model(image)
    pred_img = reconstruct(pred, image_size, patch_size)
    masked_img = reconstruct(masked_img, image_size, patch_size)

    masked_img = masked_img[0].numpy()
    masked_img = np.clip(masked_img, 0, 255).astype('uint8')
    masked_img = np.transpose(masked_img, [1, 2, 0])

    pred_img = pred_img[0].numpy()
    pred_img = np.clip(pred_img, 0, 255).astype('uint8')
    pred_img = np.transpose(pred_img, [1, 2, 0])

    # Left: original, middle: masked input, right: prediction
    plt.subplot(1, 3, 1)
    plt.imshow(source_image.transpose([1, 2, 0]))
    plt.subplot(1, 3, 2)
    plt.imshow(masked_img)
    plt.subplot(1, 3, 3)
    plt.imshow(pred_img)
    plt.show()
    return pred_img

# train

for epoch in range(1, epoches + 1):
    losses = AverageMeter()
    for batch_id, image in enumerate(mae_dataloader):
        image = image.astype('float32')
        loss = model(image)
        # Standard optimization step (assumed; not shown in the original listing)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        losses.update(float(loss))

        lr = optimizer.get_lr()
        if batch_id % 50 == 0:
            print(time.asctime( time.localtime(time.time()) ), "Epoch: {}/{}, Batch id: {}, lr: {}, loss: {}".format(epoch, epoches, batch_id, lr, losses()))
    # Save the encoder alone (for later fine-tuning) and the full MAE model
    obj = {
        'model': encoder.state_dict(),
        'epoch': epoch,
    }
    paddle.save(obj, 'model.pdparams')
    obj = {
        'model': model.state_dict(),
        'epoch': epoch,
    }
    paddle.save(obj, 'mae.pdparams')

    test(model) # This switches the model to eval mode
    model.train() # Switch back to train mode

