A new generative model to beat GANs: principles, network structure, applications, code, experiments and prospects of score-based models (diffusion models)

Keywords: neural networks, PyTorch, deep learning

Over the past two years at top conferences such as NeurIPS, ICCV and CVPR, twenty or thirty papers on score-based generative models, a new family of generative models, have appeared. Some of them directly claim to "beat GANs". The new generation mechanism, together with generation quality that leads GANs and VAEs in some domains, has attracted more and more people to take an interest and devote themselves to this research.

  • Will it be the next GANs? Can it solve the problems GANs currently face?
  • What are its advantages and disadvantages compared with existing generative models?
  • What network structures are currently used?
  • How is it implemented in code?
  • What are the commonly used datasets?
  • What are the commonly used evaluation metrics?
  • What fields can it be applied to?
  • What problems have been encountered?
  • What are the bottlenecks of its development?
  • What does its future look like?

This article discusses these questions.

Contents

Principle overview

Why is it called score-based?

Langevin dynamics

Score-based models and diffusion models

3D point cloud reconstruction task

Network structure

UNet

Denoising Score Matching

GANs, DPM, DDPM

Advantages of GANs

Disadvantages of GANs

Advantages of DDPM/DPM

Disadvantages of DDPM/DPM

Common evaluation metrics

Common datasets

One-dimensional sketches

Two-dimensional images

Three-dimensional models

Application areas

Principle overview

Estimate the score function from the data, then use Langevin dynamics to generate new samples. The core physical background of score-based models and diffusion models is therefore Langevin dynamics.

Because the estimated score function is inaccurate in regions without training data, Langevin dynamics may not converge correctly when the sampling trajectory passes through such regions. As a remedy, the data is perturbed with Gaussian noise of different intensities, and the score functions of all the noise-perturbed data distributions are estimated jointly. At inference time, the information from all noise scales is combined with Langevin dynamics to sample from each noise-perturbed distribution in turn, from the largest noise scale to the smallest.

Compared with GANs, the most significant advantages are:

  • High sample quality without adversarial training. As is well known, the difficulty of training GANs has long been a hard problem in the field, mainly because implicit generative models like GANs require adversarial training, which is usually very unstable. (That said, training a score-based model is not simple either.)
  • Flexible model architectures.
  • Exact log-likelihood computation.
  • Inverse problems can be solved without retraining the model. A trained model can be used directly for sampling-based reconstruction, without training a separate feature (encoder) network the way StyleGAN-style models do.

Why is it called score-based?

Likelihood-based generative models such as VAEs and normalizing flows must keep the normalizing constant tractable (more on this below) so that the likelihood can be computed. This usually places heavy restrictions on the network structure: one cannot freely compose and design architectures the way NAS does. Otherwise, one must rely on surrogate objectives to approximate maximum-likelihood training.

Score-based models, however, model the gradient of the log probability density, a quantity called the score function. Because the normalizing constant does not depend on the input, it vanishes when the gradient is taken, so score-based models never need to handle the normalizing constant that burdens likelihood-based models.

This score function is written $s_\theta(\mathbf{x}) \approx \nabla_\mathbf{x} \log p(\mathbf{x})$, and our task is to minimize the Fisher divergence between the model and the data distribution: $\mathbb{E}_{p(\mathbf{x})}\big[\lVert \nabla_\mathbf{x} \log p(\mathbf{x}) - s_\theta(\mathbf{x}) \rVert_2^2\big]$.

Langevin dynamics

Langevin dynamics samples from the true data distribution $p(\mathbf{x})$ by Markov chain Monte Carlo, using only the score function. The iterative process is as follows: $\mathbf{x}_{i+1} \leftarrow \mathbf{x}_i + \epsilon\, \nabla_\mathbf{x} \log p(\mathbf{x}_i) + \sqrt{2\epsilon}\, \mathbf{z}_i$, where $\mathbf{z}_i \sim \mathcal{N}(0, I)$ and $\epsilon$ is the step size.
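
As a concrete illustration, here is a minimal PyTorch sketch of this iteration. The score_fn argument stands in for whatever estimates the score (in practice, the trained scoring network); the step size and step count are illustrative. Running the loop once per noise scale, from largest to smallest, gives the annealed sampling procedure described in the principle overview.

import torch

def langevin_sample(score_fn, x_init, eps=1e-4, n_steps=1000):
    """Unadjusted Langevin dynamics:
    x_{i+1} = x_i + eps * score(x_i) + sqrt(2 * eps) * z_i,  z_i ~ N(0, I)."""
    x = x_init.clone()
    for _ in range(n_steps):
        z = torch.randn_like(x)
        x = x + eps * score_fn(x) + (2 * eps) ** 0.5 * z
    return x

# Sanity check on a standard 2-D Gaussian, whose score is exactly -x.
samples = langevin_sample(lambda x: -x, torch.randn(1000, 2), eps=1e-2, n_steps=500)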

Score-based models and diffusion models

The principles of score-based models and diffusion models are similar. Interested readers can refer to the previous article in this series:

 Diffusion Model and deep learning (with Python example)

That article focuses on the path from the physical background to deep learning, with the mathematical derivation and code examples of the general diffusion process; this article will not repeat that material.

3D point cloud reconstruction task

1. Point cloud generation is a conditional generation problem, because the Markov chain that generates the point cloud is conditioned on a latent representation of shape. The training and sampling schemes this conditioning requires differ significantly from previous work on diffusion probabilistic models.
2. DDPMs for two-dimensional images cannot be extended directly to point clouds, because points are sampled irregularly in three-dimensional space rather than on the regular grid underlying images.
3. Since a point cloud consists of discrete points in three-dimensional space, these points can be viewed as particles of a non-equilibrium thermodynamic system in contact with a heat bath. Under the action of the heat bath, the positions of the particles evolve stochastically, eventually diffusing throughout space.
4. By adding noise at each time step, the initial distribution of the particles is transformed into a simple noise distribution (a minimal sketch of this forward process follows this list).
5. The diffusion process thus connects the point distribution of the point cloud with the noise distribution. To model the point distribution for point cloud generation, the reverse diffusion process is considered, which recovers the distribution of target points from the noise distribution.
6. The reverse diffusion process is modeled as a Markov chain that transforms the noise distribution into the target distribution. The goal is to learn its transition kernel so that the Markov chain can reconstruct the desired shape. Because the Markov chain models the point distribution alone, it cannot by itself generate point clouds of varied shapes; therefore a shape latent is introduced as the condition of the transition kernel. In the generative setting, the shape latent follows a prior distribution parameterized by a normalizing flow to strengthen the model's expressiveness; in the auto-encoding setting, the shape latent is learned end-to-end.
7. The training objective is formulated as maximizing the variational lower bound of the point cloud's likelihood conditioned on the shape latent, which is further reduced to a tractable closed-form expression.
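
Below is a minimal sketch of the forward (noising) process from item 4, under the standard DDPM parameterization with a closed-form marginal q(x_t | x_0). The linear beta schedule, tensor shapes and all names here are illustrative assumptions, not the exact code of the point-cloud paper.

import torch

def diffuse_point_cloud(x0, t, alpha_bar):
    """Closed-form forward diffusion q(x_t | x_0) = N(sqrt(abar_t) * x0, (1 - abar_t) * I),
    applied independently to every point of every cloud.
    x0: [B, N, 3] clean point clouds; t: [B] integer timesteps;
    alpha_bar: [T] cumulative products of (1 - beta_t)."""
    abar = alpha_bar[t].view(-1, 1, 1)   # broadcast over points and coordinates
    noise = torch.randn_like(x0)         # the "heat bath" perturbation
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise, noise

# Illustrative linear schedule with T = 100 steps.
T = 100
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)
x_t, eps = diffuse_point_cloud(torch.randn(8, 2048, 3), torch.randint(0, T, (8,)), alpha_bar)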

Network structure

UNet

UNet is famous in the medical imaging field; its strength is that it can learn multi-scale information. The original paper is well worth reading: "U-Net: Convolutional Networks for Biomedical Image Segmentation". The UNet model used here consists of a stack of residual layers and downsampling convolutions, followed by a stack of residual layers and upsampling convolutions, with skip connections joining layers of the same spatial size. In addition, a single-head global attention layer is used at 16x16 resolution, and a projection of the timestep embedding is added to each residual block.

The first paper to use a UNet in a score-based model: Ho, J., Jain, A., and Abbeel, P. Denoising Diffusion Probabilistic Models, 2020.

Most subsequent work builds on the network structure proposed in that paper. The classic UNet model class is shown below; it can be inherited directly when reused.

class UNetModel(nn.Module):

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_heads_upsample = num_heads_upsample

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch, use_checkpoint=use_checkpoint, num_heads=num_heads
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
                )
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                        )
                    )
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    @property
    def inner_dtype(self):
        """
        Get the dtype used by the torso of the model.
        """
        return next(self.input_blocks.parameters()).dtype

    def forward(self, x, timesteps, y=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.inner_dtype)
        # Downsampling path: record each block's output for the skip connections.
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        # Middle (bottleneck) block at the lowest resolution.
        h = self.middle_block(h, emb)
        # Upsampling path: concatenate each skip connection along the channel dim.
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def get_feature_vectors(self, x, timesteps, y=None):
        """
        Apply the model and return all of the intermediate tensors.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a dict with the following keys:
                 - 'down': a list of hidden state tensors from downsampling.
                 - 'middle': the tensor of the output of the lowest-resolution
                             block in the model.
                 - 'up': a list of hidden state tensors from upsampling.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        result = dict(down=[], up=[])
        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            result["down"].append(h.type(x.dtype))
        h = self.middle_block(h, emb)
        result["middle"] = h.type(x.dtype)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
            result["up"].append(h.type(x.dtype))
        return result
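
A minimal usage sketch for the class above, assuming the helper modules it references (ResBlock, AttentionBlock, TimestepEmbedSequential, timestep_embedding, conv_nd, and so on) are available as in the OpenAI improved-diffusion repository this code comes from; the hyperparameters are illustrative. Note that attention_resolutions is matched against the downsampling factor ds, not the pixel resolution, so for 64x64 inputs ds=4 corresponds to the 16x16 attention layer mentioned earlier.

import torch as th

# Hypothetical configuration for 64x64 RGB images.
model = UNetModel(
    in_channels=3,
    model_channels=128,
    out_channels=3,
    num_res_blocks=2,
    attention_resolutions=(4,),  # ds values, i.e. attention at 64/4 = 16x16
)

x = th.randn(2, 3, 64, 64)       # a batch of noisy images
t = th.randint(0, 1000, (2,))    # one diffusion timestep per sample
eps_pred = model(x, t)           # predicted noise, same shape as x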

Denoising Score Matching

The UNet arrived in this field relatively late; the earliest such paper appeared in 2020. Before that, denoising score matching was already widely used.

This method first learns the score function through denoising score matching. Intuitively, this means training a neural network (called the scoring network) to denoise images blurred with Gaussian noise. A key point is to perturb the data with multiple noise scales, so that the scoring network captures both coarse-grained and fine-grained image features. However, how to choose these noise scales is a difficult problem.

Second, samples are generated by running Langevin dynamics: starting from white noise, the scoring network is used to gradually denoise it into an image.
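
To make this concrete, here is a hedged sketch of the multi-scale denoising score matching objective; score_net is an assumed scoring network conditioned on the noise level, and the sigma schedule is illustrative. For Gaussian noise, the score of the perturbation kernel is -(x_noisy - x) / sigma^2, so the regression target is known in closed form.

import math

import torch

def dsm_loss(score_net, x, sigmas):
    """Denoising score matching over multiple noise scales.
    With x_noisy = x + sigma * z, the target score is -z / sigma; the per-scale
    losses are weighted by sigma**2 so that all scales contribute comparably."""
    idx = torch.randint(0, len(sigmas), (x.shape[0],))
    sigma = sigmas[idx].view(-1, *([1] * (x.dim() - 1)))
    z = torch.randn_like(x)
    x_noisy = x + sigma * z
    target = -z / sigma
    pred = score_net(x_noisy, sigma)  # assumed to take the noise level as input
    return ((sigma ** 2) * (pred - target) ** 2).mean()

# Example geometric noise schedule, from coarse (large sigma) to fine.
sigmas = torch.exp(torch.linspace(math.log(1.0), math.log(0.01), 10))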

GANs, DPM, DDPM

Advantages of GANs

1. Wall-clock sampling time is faster.

Disadvantages of GANs

1. They are difficult to train and collapse without careful selection of hyperparameters and regularizers.
2. GANs can trade diversity for fidelity, producing high-quality samples that do not cover the whole distribution.
3. The training process of GANs can be unstable because of the adversarial loss. (Autoregressive models, by comparison, assume an unnatural generation order, which can limit model flexibility.)

Advantages of DDPM/DPM

DDPM = DPM + denoising score matching (denoising autoencoders); a sketch of the resulting training objective follows the list below.

1. They capture more diversity and are usually easier to scale and train than GANs.
2. Distribution coverage, a fixed training objective, and easy scalability.
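
The "DPM + denoising score matching" relation above cashes out in the simplified DDPM training objective, sketched here under the standard parameterization; eps_net stands for an assumed noise-prediction network (for example the UNet class shown earlier), and alpha_bar is a precomputed schedule as in the point-cloud sketch above.

import torch

def ddpm_loss(eps_net, x0, alpha_bar):
    """Simplified DDPM objective: L_simple = E || eps - eps_net(x_t, t) ||^2,
    with x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps. Predicting eps is,
    up to scaling, the same as predicting the score of the noised data."""
    t = torch.randint(0, len(alpha_bar), (x0.shape[0],))
    abar = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    x_t = abar.sqrt() * x0 + (1.0 - abar).sqrt() * eps
    return ((eps - eps_net(x_t, t)) ** 2).mean()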

Disadvantages of DDPM/DPM

1. Wall-clock sampling time is slower than for GANs.
2. Visual sample quality still falls short in some respects.
3. Multiple denoising steps (and therefore multiple forward passes) are needed, so sampling remains slower than GANs.

Common evaluation metrics

Most papers compare against GANs, so both the evaluation metrics and the datasets largely follow those used in the GAN literature.

  1. FID (proposed in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium"): captures diversity better than IS and agrees better with human judgment. It is a symmetric measure of the distance between two image distributions in the Inception feature space (a usage sketch follows this list).

  2. Inception Score (IS, proposed in "Improved Techniques for Training GANs"): measures the extent to which a model can produce convincing samples of individual classes while capturing the full distribution of ImageNet classes. One disadvantage of this metric is that it does not reward covering the whole distribution or diversity within classes: a model that memorizes only a small part of the full dataset can still achieve a high IS.

  3. Precision (from "Improved Precision and Recall Metric for Assessing Generative Models"): mainly describes accuracy, i.e. model fidelity.

  4. Recall (from the same paper): mainly describes coverage, measuring diversity and how much of the distribution is covered.

  5. Retrieval: retrieval comparison is also a common way to illustrate reconstruction quality.
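
As an illustration of how the FID from item 1 is computed in practice, here is a hedged sketch using the torchmetrics implementation (one of several third-party implementations; pytorch-fid is another common choice). The uint8 format and [N, 3, H, W] shape follow torchmetrics' default settings, and the random tensors merely stand in for real and generated batches.

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# FID compares Inception-v3 feature statistics of real vs. generated images.
fid = FrechetInceptionDistance(feature=2048)

# Stand-ins for real and generated batches (uint8, [N, 3, H, W] by default).
real = torch.randint(0, 256, (64, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 256, (64, 3, 299, 299), dtype=torch.uint8)

fid.update(real, real=True)
fid.update(fake, real=False)
print(fid.compute())  # scalar tensor; lower is better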

Common datasets

One-dimensional sketches

Two-dimensional images

  • ImageNet
  • LSUN (lmdb)
  • FFHQ
  • CelebA
  • CIFAR-10, which can be downloaded using the following code:
    import os
    import tempfile
    
    import torchvision
    from tqdm.auto import tqdm
    
    CLASSES = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )
    
    
    def main():
        for split in ["train", "test"]:
            out_dir = f"cifar_{split}"
            if os.path.exists(out_dir):
                print(f"skipping split {split} since {out_dir} already exists.")
                continue
    
            print("downloading...")
            with tempfile.TemporaryDirectory() as tmp_dir:
                # CIFAR10 loads all images into memory at construction time,
                # so the temporary download directory can be removed right away.
                dataset = torchvision.datasets.CIFAR10(
                    root=tmp_dir, train=split == "train", download=True
                )
    
            print("dumping images...")
            os.mkdir(out_dir)
            for i in tqdm(range(len(dataset))):
                image, label = dataset[i]
                filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png")
                image.save(filename)
    
    
    if __name__ == "__main__":
        main()
    

Three-dimensional models

Application areas

  • Audio modeling
    DiffWave: A Versatile Diffusion Model for Audio Synthesis
    PriorGrad: Improving Conditional Denoising Diffusion Models with Data-Driven Adaptive Prior
  • Speech synthesis
    Grad-TTS: A Diffusion Probabilistic Model for Text-to-Speech
  • Time series forecasting
    Autoregressive Denoising Diffusion Models for Multivariate Probabilistic Time Series Forecasting
  • 2D image generation
    Diffusion Models Beat GANs on Image Synthesis
    Improved Denoising Diffusion Probabilistic Models
    Denoising Diffusion Probabilistic Models
    Improved Techniques for Training Score-Based Generative Models
  • 3D point cloud reconstruction
    Diffusion Probabilistic Models for 3D Point Cloud Generation
