Create endless images from words: getting started with the pre-trained Russian DALL-E model

Keywords: Computer Vision, Deep Learning, PaddlePaddle

(Automatic translation is included, so prompts can also be entered in Chinese.)

Take a look: this DALL-E is really wonderful. It draws whatever you write, so you can paint whatever comes to mind.

Let your imagination run wild.

"Avocado chair" "Beautiful anime girl" "A group of Shiba Inu flying in the sky"

Introduction

  • DALL-E: OpenAI, an artificial intelligence research organization, trained a neural network called DALL-E, which lets users describe a picture in natural language and generates images whose content matches the description. The name DALL-E is a blend of the Spanish (Catalan) surrealist painter Salvador Dalí and the Pixar animated character WALL-E.
  • CLIP: A model released at the same time as DALL-E that maps images to categories described by text. It can be used for image retrieval and for ranking how well images match a text.
  • VQGAN: Part of Taming Transformers. It encodes images into tokens and, thanks to adversarial training, decodes tokens back into sharp, recognizable pictures.
  • RealESRGAN: Builds on ESRGAN by generating low-resolution pictures in a differentiable way that imitates how high-resolution pictures degrade in the real world, which gives better super-resolution results.

This project uses ruDALL-E (Source Repo), which builds on the components above to generate images from text, filter them, and upscale them with super resolution. Because it uses VQGAN and RealESRGAN, its results look better than those of the original OpenAI model. Its training method also follows CogView (which is essentially a Chinese version of DALL-E with some improvements; the overall idea is unchanged). OpenAI's generative model is not open source, and although CogView has released its model, it does not combine several models the way ruDALL-E does to improve generation quality.

If you find this project interesting, remember to give it a star on GitHub!

Model architecture

In fact, the overall architecture of the DALL-E model is not complex. Apart from the codec, which encodes an image into tokens and decodes tokens back into an image, the rest is essentially the same as the well-known text generation model GPT-3. In effect, the image is turned into a string of tokens and a GPT-like model predicts the next token, with the tokens of the text description placed before the image tokens as the condition.
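
The idea described above can be sketched with a toy loop. This is only an illustration, not the ruDALL-E code: `toy_next_token_logits` is a hypothetical stand-in for the transformer, and the vocabulary size and token count are made-up small numbers.

```python
import math
import random

IMAGE_VOCAB = 8      # hypothetical image codebook size
IMAGE_TOKENS = 4     # hypothetical number of image tokens per picture


def toy_next_token_logits(sequence):
    """Stand-in for the transformer: returns one score per image token.

    A real model would attend over the whole sequence; this toy only makes
    the scores depend on the sequence so that the conditioning is visible.
    """
    total = sum(sequence)
    return [float((total + 3 * i) % IMAGE_VOCAB) for i in range(IMAGE_VOCAB)]


def generate_image_tokens(text_tokens, rng):
    """Append image tokens one at a time, conditioned on the text prefix."""
    sequence = list(text_tokens)            # the text comes first, as the condition
    for _ in range(IMAGE_TOKENS):
        logits = toy_next_token_logits(sequence)
        weights = [math.exp(value) for value in logits]
        token = rng.choices(range(IMAGE_VOCAB), weights=weights, k=1)[0]
        sequence.append(token)              # the new token becomes context too
    return sequence[len(text_tokens):]      # only the image part goes to the decoder


image_tokens = generate_image_tokens([3, 7, 1], random.Random(0))
print(image_tokens)
```

In the real pipeline the returned image tokens would be handed to the VQGAN decoder to obtain pixels.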

Inference

However, as is well known, models like GPT-3 only generate well when they have a large number of parameters, that is, many hidden layers, each of which is also large. As a result, the ruDALL-E model weighs in at a staggering 2 GB or so, which an ordinary graphics card cannot handle, so we have to turn to the powerful V100 for help.
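
For a rough feel of the numbers, here is a back-of-the-envelope sketch of how much memory the weights alone occupy. The 1.3 billion parameter count is an assumption taken from the message the library prints when the Malevich checkpoint loads; activations and cached attention states during generation come on top of this.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Memory needed just to hold the weights, in GiB."""
    return num_params * bytes_per_param / 1024 ** 3


params = 1_300_000_000  # assumption: a 1.3-billion-parameter model
for name, size in [("fp32", 4), ("fp16", 2)]:
    print(f"{name}: {weight_memory_gib(params, size):.2f} GiB")
# fp32: 4.84 GiB
# fp16: 2.42 GiB
```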

Even a card as strong as the V100 struggles to complete a full inference run, so we do not expect to train on this platform.

Let's start the experiment in order!

# Unzip the model weights
!cd data && unzip -qq data116979/pretrained_models.zip
# Install dependencies
!pip install -r requirements.txt > /dev/null 2> /dev/null
!pip install ipywidgets translators==4.9.5 > /dev/null 2> /dev/null
# View environment information
import multiprocessing
import paddle
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PaddlePaddle version:", paddle.__version__)
print("CUDA version:", paddle.version.cuda())
print("cuDNN version:", paddle.device.get_cudnn_version())
device = 'cuda:0' if len(paddle.static.cuda_places()) > 0 else 'cpu'
print("device:", device)

!nvidia-smi
CPU: 24
RAM GB: 110.2
PaddlePaddle version: 2.2.0
CUDA version: 10.1
cuDNN version: 7605
device: cuda:0
Tue Nov 23 10:22:33 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:05:00.0 Off |                    0 |
| N/A   61C    P0   208W / 300W |      0MiB / 16384MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
# Import module dependencies
from rudalle_paddle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle_paddle.utils import seed_everything
# Load the DALL-E model
# fp16 can be enabled, but so far we have not seen any speedup from it
device = 'cuda'
dalle = get_rudalle_model('Malevich-paddle', pretrained=True, fp16=False, device=device, cache_dir='data/pretrained_models')
W1123 10:22:52.852972  2335 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1123 10:22:52.853024  2335 device_context.cc:465] device: 0, cuDNN Version: 7.6.


◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.
# Load the codec (VAE) model and the super-resolution model
realesrgan = get_realesrgan('x2-paddle', device=device, cache_dir='data/pretrained_models') # x2/x4/x8
tokenizer = get_tokenizer(cache_dir='data/pretrained_models')
vae = get_vae('vqgan.gumbelf8-sber.paddle', cache_dir='data/pretrained_models').to(device) # dwt is not supported yet
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5-paddle', cache_dir='data/pretrained_models')
ruclip = ruclip.to(device)
x2-paddle --> ready
tokenizer --> ready
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
vae --> ready
ruclip --> ready
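
The shape printed for `z` above hints at how many tokens describe one picture. A small sketch of the arithmetic, assuming (this is not stated in the log) that the `f8` in the checkpoint name means a downsampling factor of 8 and that the pictures are 256 pixels on each side:

```python
def latent_grid(image_size, downsample_factor):
    """Return (grid side, number of tokens) for a square image."""
    side = image_size // downsample_factor
    return side, side * side


# Assumptions: 256x256 pixel images and a downsampling factor of 8.
side, num_tokens = latent_grid(image_size=256, downsample_factor=8)
print(side, num_tokens)     # 32 1024
print(256 * side * side)    # 262144, the "dimensions" figure for 256 channels
```

Under these assumptions every picture is a 32 x 32 grid, so the generator has to predict 1024 image tokens per picture.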

Generation with ruDALL-E

# Load translator
import translators as ts

def translate(txt, backend):
    return getattr(ts, backend)(txt, from_language='auto', to_language='ru')
Using China server backend.
# Text description of the image

backend = 'google' # google/bing/alibaba/tencent/sogou
source_text = 'Avocado chair' # In auto mode, the description can be written in any language

print('Source text:', source_text)
try:
    target_text = translate(source_text, backend)
except Exception as err:
    raise Exception(
        'Failed to call the translator, please try another backend.'
    ) from err
print('Target text:', target_text)
Source text: Avocado chair
Target text: Стул авокадо
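
Since a translation backend can fail, one option is to try several in turn instead of switching by hand. The sketch below is a hypothetical helper, not part of the project: `translate_with_fallback` and the two stand-in backends are made up so the example runs without network access.

```python
def translate_with_fallback(text, translators_by_name, order):
    """Try each backend in turn and return (backend_name, translation)."""
    errors = {}
    for name in order:
        try:
            return name, translators_by_name[name](text)
        except Exception as exc:  # network errors, rate limits, and so on
            errors[name] = exc
    raise RuntimeError(f"All backends failed: {errors}")


# Hypothetical stand-ins so the sketch runs offline.
def broken_backend(text):
    raise ConnectionError("backend unreachable")


def fake_backend(text):
    return {"Avocado chair": "Стул авокадо"}[text]


used, result = translate_with_fallback(
    "Avocado chair",
    {"google": broken_backend, "bing": fake_backend},
    order=["google", "bing"],
)
print(used, result)
```

In the notebook, the dictionary values could be small wrappers around the `translate` function defined earlier, for example `lambda t: translate(t, 'bing')`.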
# Random seed. Within a single run, the same seed and the same text description generate the same image
seed_everything(42)
# Main part of model generation
# Large model generation takes time, please wait patiently

text = target_text

pil_images = []
scores = []
for top_k, top_p, images_num in [
    (2048, 0.995, 3), # 8 rounds in total; if you are in a hurry, you can comment out 6 of them
    (1536, 0.99, 3),
    (1024, 0.99, 3),
    (1024, 0.98, 3),
    (512, 0.97, 3),
    (384, 0.96, 3),
    (256, 0.95, 3),
    (128, 0.95, 3), 
]:
    _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
    pil_images += _pil_images
    scores += _scores
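
Each round of the loop above uses a different `top_k` and `top_p`, which control how adventurous the sampling is: larger values allow rarer tokens, smaller values stay closer to the most likely ones. A minimal sketch of the usual top-k plus nucleus (top-p) filtering on a toy distribution follows. It shows the general technique only; the library's own implementation may differ in details such as tie handling.

```python
import math


def top_k_top_p_filter(logits, top_k, top_p):
    """Return {token_index: probability} after top-k, then top-p filtering."""
    # Softmax over all logits (subtract the max for numerical stability).
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep only the k most probable tokens.
    ranked = sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k]

    # Top-p: keep the shortest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalise so the surviving probabilities sum to 1.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


filtered = top_k_top_p_filter([2.0, 1.0, 0.5, -1.0, -3.0], top_k=3, top_p=0.8)
print(sorted(filtered))  # [0, 1]
```

The next token is then sampled from the surviving tokens only.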
# Show the generated results, best score first
show([pil_image for pil_image, score in sorted(zip(pil_images, scores), key=lambda x: -x[1])] , 6)

Use CLIP to pick the images that best match the description (automatic cherry-picking with ruCLIP)

top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6)
show(top_images, 3)
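
Conceptually, this kind of cherry-picking ranks the images by how close their embedding is to the embedding of the text. The sketch below uses cosine similarity on made-up 3-dimensional vectors. Both the helper `pick_best` and the use of cosine similarity are assumptions for illustration, not a description of what `cherry_pick_by_clip` does internally.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def pick_best(image_embeddings, text_embedding, count):
    """Return (indices, scores) of the `count` images closest to the text."""
    scores = [cosine(e, text_embedding) for e in image_embeddings]
    order = sorted(range(len(scores)), key=lambda i: -scores[i])[:count]
    return order, [scores[i] for i in order]


# Made-up embeddings; real CLIP embeddings have hundreds of dimensions.
text_vec = [1.0, 0.0, 0.0]
image_vecs = [[0.0, 1.0, 0.0], [0.9, 0.1, 0.0], [0.5, 0.5, 0.0]]
indices, best_scores = pick_best(image_vecs, text_vec, count=2)
print(indices)  # [1, 2]
```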

Super resolution of images

sr_images = super_resolution(top_images, realesrgan)
show(sr_images, 3)
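
To see what a factor of 2 means for the pixel grid, here is the most naive baseline possible, nearest-neighbour upscaling, on a tiny made-up image. This is deliberately not RealESRGAN: a learned super-resolution model fills in plausible detail, while this sketch only repeats pixels.

```python
def upscale_nearest(pixels, factor):
    """Repeat every pixel `factor` times horizontally and vertically."""
    out = []
    for row in pixels:
        wide = [value for value in row for _ in range(factor)]
        out.extend(list(wide) for _ in range(factor))
    return out


small = [[1, 2], [3, 4]]        # a made-up 2x2 grayscale image
big = upscale_nearest(small, 2)
for row in big:
    print(row)
```

The output has twice as many rows and columns, so an x2 model quadruples the number of pixels it has to produce.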

Posted by thatsme on Sat, 27 Nov 2021 19:20:51 -0800