Create endless images from words: getting started with the Russian DALL-E pretrained model (automatic translation is built in, so Chinese input works too).
Look, this DALL-E really is wonderful: it draws whatever you write and paints whatever you imagine.
Let your imagination run wild.
"Avocado chair" "Beautiful anime girl" "A pack of Shiba Inu flying in the sky"

Introduction
- DALL-E: OpenAI, the artificial intelligence research organization, trained a neural network called DALL-E that generates images consistent with a natural-language text description supplied by the user. The name DALL-E is a blend of the Spanish Catalan surrealist painter Salvador Dalí and the Pixar animated character WALL·E.
- CLIP: A model released alongside DALL-E that maps images to categories described by text. It can be used for image retrieval and for ranking image-text matches.
- VQGAN: Part of Taming Transformers; it encodes images into tokens and, through adversarial training, decodes those tokens back into sharp, recognizable pictures.
- Real-ESRGAN: Synthesizes low-resolution images in a differentiable way to imitate how high-resolution pictures degrade in the real world, achieving better super-resolution results on top of ESRGAN.
This project uses RuDalle (Source Repo), which builds on all of the above to generate images from text, filter them, and super-resolve them. Because it uses VQGAN and Real-ESRGAN, its outputs look better than those of the original OpenAI pipeline. Its training recipe also draws on CogView (so yes, it is essentially a Russian-language DALL-E with some improvements; the overall idea is unchanged). OpenAI's generation model is not open source, and although CogView released its model, it does not combine multiple models the way RuDalle does to achieve a better generation effect.
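Before diving in, here is a hedged sketch of that three-stage pipeline. The helper names match the rudalle_paddle calls used in the cells further below, but the wrapper function itself is only an illustration, not part of the project:

def text_to_images(text, tokenizer, dalle, vae, ruclip, ruclip_processor, realesrgan, device):
    # 1. DALL-E generates candidate images from the text description.
    pil_images, scores = generate_images(text, tokenizer, dalle, vae,
                                         top_k=2048, top_p=0.995, images_num=3)
    # 2. CLIP keeps the candidates that best match the description.
    top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip,
                                                  ruclip_processor, device=device, count=6)
    # 3. Real-ESRGAN upscales the survivors to a higher resolution.
    return super_resolution(top_images, realesrgan)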
If you find this project interesting, remember to give it a Star on GitHub!
Model architecture
In fact, the overall architecture of the DALL-E model is not complex. Apart from the codec, which encodes the image into tokens and can decode them back, the rest is essentially the same as the famous text-generation model GPT-3: the image is turned into a sequence of tokens, a GPT-like model predicts the next token, and the tokens of the picture's text description are prepended to the image tokens as the condition.
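As a minimal illustration of that layout (the token IDs below are made up; the real model uses a fixed block of text positions followed by 32×32 = 1024 VQGAN image positions, matching the z-shape printed when the VAE loads below):

import numpy as np

text_tokens = np.array([17, 504, 2041])      # tokenized text description (hypothetical IDs)
image_tokens = np.array([5012, 7733, 1024])  # VQGAN codes for image patches (hypothetical IDs)
sequence = np.concatenate([text_tokens, image_tokens])

# At inference the model starts from the text tokens alone and autoregressively
# samples image tokens until every image position is filled (pseudocode):
#     while len(sequence) < total_positions:
#         sequence = append(sequence, sample(gpt_like_model(sequence)))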
Inference
However, as we all know, for a GPT-3-like model to generate well, the parameter count has to be large: many hidden layers, each of considerable size. RuDalle's model weights therefore reach an astonishing 2 GB, more than a typical consumer graphics card can hold, so we turn to the powerful V100 for help.
Even a card as strong as the V100 struggles to run full inference, and we don't expect to train on this platform.
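A back-of-the-envelope calculation makes the point. The 1.3-billion-parameter figure comes from the Malevich model's own banner printed when it loads below; the overheads mentioned in the comments are rough assumptions:

# Rough memory estimate for the 1.3B-parameter Malevich model
params = 1.3e9
print(f'fp32 weights: {params * 4 / 1024**3:.1f} GiB')  # ~4.8 GiB
print(f'fp16 weights: {params * 2 / 1024**3:.1f} GiB')  # ~2.4 GiB
# On top of the weights come activations, attention caches, and the VQGAN,
# ruCLIP, and Real-ESRGAN models, so a 16 GiB V100 fills up quickly.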
Let's run the experiment step by step!
# Unzip the model weights
!cd data && unzip -qq data116979/pretrained_models.zip
# Install dependencies
!pip install -r requirements.txt > /dev/null 2> /dev/null
!pip install ipywidgets translators==4.9.5 > /dev/null 2> /dev/null
# View environment information
import multiprocessing
import paddle
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)
print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PaddlePaddle version:", paddle.__version__)
print("CUDA version:", paddle.version.cuda())
print("cuDNN version:", paddle.device.get_cudnn_version())
device = 'cuda:0' if len(paddle.static.cuda_places()) > 0 else 'cpu'
print("device:", device)
!nvidia-smi
CPU: 24
RAM GB: 110.2
PaddlePaddle version: 2.2.0
CUDA version: 10.1
cuDNN version: 7605
device: cuda:0
Tue Nov 23 10:22:33 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:05:00.0 Off |                    0 |
| N/A   61C    P0   208W / 300W |      0MiB / 16384MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
# Import module dependencies
from rudalle_paddle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle_paddle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle_paddle.utils import seed_everything
# Load the DALL-E model
# fp16 can be enabled, but we haven't seen any acceleration effect from it so far
device = 'cuda'
dalle = get_rudalle_model('Malevich-paddle', pretrained=True, fp16=False, device=device, cache_dir='data/pretrained_models')
W1123 10:22:52.852972  2335 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W1123 10:22:52.853024  2335 device_context.cc:465] device: 0, cuDNN Version: 7.6.
◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.
# Load the codec model, the super-resolution model, and ruCLIP
realesrgan = get_realesrgan('x2-paddle', device=device, cache_dir='data/pretrained_models')  # x2/x4/x8
tokenizer = get_tokenizer(cache_dir='data/pretrained_models')
vae = get_vae('vqgan.gumbelf8-sber.paddle', cache_dir='data/pretrained_models').to(device)  # still does not support dwt
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5-paddle', cache_dir='data/pretrained_models')
ruclip = ruclip.to(device)
x2-paddle --> ready
tokenizer --> ready
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
vae --> ready
ruclip --> ready
Generation with RuDalle
# Load the translator
import translators as ts

def translate(txt, backend):
    return getattr(ts, backend)(txt, from_language='auto', to_language='ru')
Using China server backend.
# Text description of the image
backend = 'google'  # google/bing/alibaba/tencent/sogou
source_text = 'Avocado chair'  # auto mode: the text description can be entered in any language

print('Source text:', source_text)
try:
    target_text = translate(source_text, backend)
except:
    raise Exception('Failed to call the translator, please try another backend.')
print('Target text:', target_text)
Source text: Avocado chair
Target text: Стул авокадо
# Random seed: within a single run, the same seed and the same text description will generate the same image
seed_everything(42)
# Main generation loop
# Generation with a large model takes time, please be patient
text = target_text
pil_images = []
scores = []

for top_k, top_p, images_num in [
    (2048, 0.995, 3),  # 8 rounds in total; if you are in a hurry, you can comment out 6 of them
    (1536, 0.99, 3),
    (1024, 0.99, 3),
    (1024, 0.98, 3),
    (512, 0.97, 3),
    (384, 0.96, 3),
    (256, 0.95, 3),
    (128, 0.95, 3),
]:
    _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
    pil_images += _pil_images
    scores += _scores
[output: 8 progress bars, one per generation round, each running over the 1024 image token positions]
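The loop above sweeps progressively more conservative (top_k, top_p) settings: smaller values concentrate sampling on high-probability tokens, trading diversity for fidelity. For reference, here is a generic NumPy sketch of what top-k / top-p (nucleus) filtering does to a token distribution; it illustrates the idea only and is not the rudalle_paddle implementation:

import numpy as np

def top_k_top_p_filter(logits, top_k, top_p):
    """Generic top-k / top-p (nucleus) filtering; illustration only."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)                    # indices sorted by descending probability
    keep = np.zeros_like(probs, dtype=bool)
    keep[order[:top_k]] = True                    # top-k: keep only the k most likely tokens
    cumulative = np.cumsum(probs[order])
    mask = np.zeros_like(probs, dtype=bool)
    mask[order[cumulative <= top_p]] = True       # top-p: keep the smallest set covering mass p
    keep &= mask
    if not keep.any():                            # always keep at least the most likely token
        keep[order[0]] = True
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=16384)               # illustrative vocabulary of image tokens
dist = top_k_top_p_filter(logits, top_k=512, top_p=0.97)
next_token = rng.choice(len(dist), p=dist)    # sample the next image token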
# Show the generated results, sorted by score
show([pil_image for pil_image, score in sorted(zip(pil_images, scores), key=lambda x: -x[1])], 6)
Use CLIP to filter the images that best match the description (auto cherry-pick by ruCLIP)
# Keep the 6 candidates whose ruCLIP score best matches the text
top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6)
show(top_images, 3)
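Conceptually, cherry_pick_by_clip embeds the text and each candidate image with ruCLIP and keeps the images whose embeddings have the highest cosine similarity to the text embedding. A minimal sketch of that ranking step, with hypothetical embedding arrays (the real ruCLIP API differs):

import numpy as np

def rank_by_clip_similarity(image_embs, text_emb, count):
    # Normalize so that the dot product becomes cosine similarity
    image_embs = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb)
    sims = image_embs @ text_emb        # one similarity score per image
    best = np.argsort(-sims)[:count]    # indices of the best-matching images
    return best, sims[best]

# e.g. image_embs with shape (n_images, dim), text_emb with shape (dim,)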
Super-resolution of images
# Upscale the selected images with Real-ESRGAN (x2)
sr_images = super_resolution(top_images, realesrgan)
show(sr_images, 3)