Using an ESPCN with a residual network to enlarge flower images 4 times

Use an ESPCN network combined with a residual network to restore pictures that were reduced by a factor of 4.


Reduce the height and width of the input image to 1/4 of their original size.

Add a residual network by adding two functions: `leaky_relu`, and `residual_block`, which generates the network's residual blocks; each residual block has two convolution layers in the middle. In the overall network, image features are transformed by one convolution layer followed by a residual stage, which consists of 16 residual blocks and one convolution layer. After this feature transformation, the final feature restoration is done by a 5-layer neural network.


The last 5 layers restore the feature data. The first layer is a convolution layer; the second layer rearranges the output of the first layer into pixel blocks of size 2x2; the third layer is the same as the first; the fourth layer is the same as the second; and the fifth layer is also a convolution layer. The two consecutive rearrangements enlarge the image 4 times, and the final convolution, with 3 output channels, generates the restored image.


Image shape change

(16, 256, 256, 3)
(16, 64, 64, 3)
(16, 64, 64, 64)
(16, 64, 64, 64)
(16, 64, 64, 64)
(16, 64, 64, 256)
(16, 128, 128, 64)
(16, 128, 128, 64)
(16, 128, 128, 256)
(16, 256, 256, 64)
(16, 256, 256, 64)
(16, 256, 256, 3)


import tensorflow as tf
from datasets import flowers
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow.contrib.slim as slim

def batch_mse_psnr(dbatch):
    """Return the mean MSE and mean PSNR between the two halves of a batch.

    Args:
        dbatch: array-like whose first axis stacks the reference images
            followed by the same number of images to compare, so its length
            along axis 0 must be even. Pixel values are treated as lying on a
            0-255 scale (the PSNR peak value is 255).

    Returns:
        A ``(mse, psnr)`` tuple. The squared error is averaged over axes 1
        and 2 for each image (and each remaining trailing axis), then ``mse``
        is the mean of those values and ``psnr`` is the mean of the
        corresponding ``20 * log10(255 / sqrt(mse))`` values. Identical
        halves give an MSE of 0 and therefore an infinite PSNR.
    """
    dbatch = np.asarray(dbatch)
    # Integer images (e.g. uint8) would wrap around in the subtraction and
    # squaring below; promote them to float first. Float input is left as-is
    # so its results are unchanged.
    if not np.issubdtype(dbatch.dtype, np.floating):
        dbatch = dbatch.astype(np.float64)
    im1, im2 = np.split(dbatch, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return np.mean(mse), psnr

def batch_y_psnr(dbatch):
    """Return the mean PSNR of the luminance channel for the two batch halves.

    Args:
        dbatch: 4-D array ``(2 * N, height, width, channels)``; the first N
            images are the references and the last N the images to compare.
            Axis 3 is split into three equal parts treated as R, G and B.

    Returns:
        The mean over the batch of ``20 * log10(255 / sqrt(mse))``, where the
        MSE is computed on ``0.3 * R + 0.59 * G + 0.11 * B``.
    """
    r, g, b = np.split(dbatch, 3, axis=3)
    # Keep the trailing channel axis instead of calling np.squeeze(): a bare
    # squeeze also removes a height or width of 1, after which the
    # mean over axes (1, 2) below would fail. The final np.mean flattens the
    # extra axis, so the result is unchanged for ordinary images.
    y = 0.3 * r + 0.59 * g + 0.11 * b
    im1, im2 = np.split(y, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return psnr

def batch_ssim(dbatch):
    """Return the mean global SSIM between the two halves of a batch.

    The statistics (mean, sample standard deviation, sample covariance) are
    taken over axes 1 and 2 of each image as a whole, not over sliding
    windows.

    Args:
        dbatch: array whose first axis stacks the reference images followed
            by the same number of images to compare; axes 1 and 2 are the
            ones the statistics are computed over. Values are treated as
            lying on a 0-255 scale.

    Returns:
        The SSIM value averaged over all image pairs (and over any trailing
        axes such as channels).
    """
    im1, im2 = np.split(dbatch, 2)
    imgsize = im1.shape[1] * im1.shape[2]
    # keepdims so the means broadcast against the images in the covariance.
    avg1 = im1.mean((1, 2), keepdims=True)
    avg2 = im2.mean((1, 2), keepdims=True)
    std1 = im1.std((1, 2), ddof=1)
    std2 = im2.std((1, 2), ddof=1)
    # Rescale the population covariance to the sample (n - 1) covariance so
    # it matches the ddof=1 standard deviations.
    cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)
    # Drop only the two reduced axes. A bare np.squeeze() would also drop a
    # single-channel axis, leaving the means with a different shape from
    # std/cov and silently broadcasting to a cross-image matrix.
    avg1 = np.squeeze(avg1, axis=(1, 2))
    avg2 = np.squeeze(avg2, axis=(1, 2))
    k1 = 0.01
    k2 = 0.03
    c1 = (k1 * 255) ** 2
    c2 = (k2 * 255) ** 2
    c3 = c2 / 2
    # 2 * (cov + c3) equals 2 * cov + c2, the usual SSIM contrast/structure
    # numerator.
    return np.mean(
        (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))

def showresult(subplot, title, orgimg, thisimg, dopsnr=True):
    """Draw the first image of a batch in a subplot, optionally with metrics.

    Args:
        subplot: subplot specifier passed straight to ``plt.subplot``.
        title: text placed at the start of the subplot title.
        orgimg: batch of reference images.
        thisimg: batch of images to display and compare; only the first
            image is drawn, after conversion to uint8.
        dopsnr: when true, PSNR, luminance PSNR and SSIM of ``thisimg``
            against ``orgimg`` are computed over the whole batch and shown
            in the title. The two batches must then have the same shape.
    """
    p = plt.subplot(subplot)
    p.imshow(np.asarray(thisimg[0], dtype='uint8'))
    if dopsnr:
        # The metric helpers expect the references and the candidates
        # stacked along the first axis.
        conimg = np.concatenate((orgimg, thisimg))
        _, psnr = batch_mse_psnr(conimg)
        ypsnr = batch_y_psnr(conimg)
        ssim = batch_ssim(conimg)
        # PSNR values are truncated to whole numbers; the SSIM label is
        # written once (it used to be emitted twice as " s: s:").
        p.set_title("%s%d y:%d s:%.4f" % (title, int(psnr), int(ypsnr), ssim))

# Size of the full-resolution images and number of images per batch.
height = width = 256
batch_size = 16
# Directory handed to the flowers dataset loader (machine-specific path).
DATA_DIR = "D:/tmp/data/flowers"

# Select the 'validation' split of the flowers dataset.
dataset = flowers.get_split('validation', DATA_DIR)
# Create a data provider over the dataset (num_readers=32).
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=32)
# Fetch the image and label tensors from the provider.
[image, label] = provider.get(['image', 'label'])

# Bring every picture to a uniform height x width size.
distorted_image = tf.image.resize_image_with_crop_or_pad(image, height, width)  # crop if too large, pad if too small
images, labels = tf.train.batch([distorted_image, label], batch_size=batch_size)
print(images.shape)  # (16, 256, 256, 3)

# Low-resolution version of the batch: height and width are each reduced to 1/4.
x_smalls = tf.image.resize_images(images, (np.int32(height / 4), np.int32(width / 4)))  # 256x256 -> 64x64
# Network input: the small images divided by 255 (the loss scales the targets the same way).
x_smalls2 = x_smalls / 255.0
print(x_smalls2.shape)  # (16, 64, 64, 3)
# Conventional upscalings of the small images back to full size; they are
# plotted next to the network output at the end of the script.
x_nearests = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.NEAREST_NEIGHBOR)
x_bilins = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BILINEAR)
x_bicubics = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BICUBIC)

# Disabled alternative network: convolutions and depth_to_space only, with
# tanh activations and no residual blocks.
# net = slim.conv2d(x_smalls2, 64, 5,activation_fn = tf.nn.tanh)
# net =slim.conv2d(net, 256, 3,activation_fn = tf.nn.tanh)
# net = tf.depth_to_space(net,2) #64
# net =slim.conv2d(net, 64, 3,activation_fn = tf.nn.tanh)
# net = tf.depth_to_space(net,2) #16
# y_predt = slim.conv2d(net, 3, 3,activation_fn = None)#2*2*3

def leaky_relu(x, alpha=0.1, name='lrelu'):
    """Leaky ReLU activation: element-wise ``max(x, alpha * x)``.

    Args:
        x: input tensor.
        alpha: factor applied to ``x`` for the second argument of the max.
        name: name scope under which the ops are created.

    Returns:
        The tensor produced by ``tf.maximum(x, alpha * x)``.
    """
    with tf.name_scope(name):
        return tf.maximum(x, alpha * x)

def residual_block(nn, i, name='resblock'):
    """Build one residual block: two 3x3, 64-filter convolutions plus a skip.

    Args:
        nn: input tensor; it is added to the block output, so its shape must
            match the output of the 64-filter convolutions.
        i: index appended to ``name`` to form the block's variable scope.
        name: prefix of the variable scope.

    Returns:
        ``nn`` plus the output of the second convolution.
    """
    # Both convolutions share the same activation and normalizer.
    conv_kwargs = dict(activation_fn=leaky_relu, normalizer_fn=slim.batch_norm)
    with tf.variable_scope(name + str(i)):
        features = nn
        for _ in range(2):
            features = slim.conv2d(features, 64, 3, **conv_kwargs)
        return tf.add(nn, features)

# --- Feature extraction -----------------------------------------------------
# First convolution: 5x5 kernel, 64 filters, applied to the small input images.
net = slim.conv2d(x_smalls2, 64, 5, activation_fn=leaky_relu)
print(net.shape)  # (16, 64, 64, 64)
# Chain of 16 residual blocks; the first takes `net`, each later one takes the
# output of the previous block.
block = []
for i in range(16):
    block.append(residual_block(block[-1] if i else net, i))
# One more batch-normalized convolution after the last residual block.
conv2 = slim.conv2d(block[-1], 64, 3, activation_fn=leaky_relu, normalizer_fn=slim.batch_norm)
print(conv2.shape)  # (16, 64, 64, 64)
# Long skip connection around the whole residual stage.
sum1 = tf.add(conv2, net)
print(sum1.shape)  # (16, 64, 64, 64)

# --- Upscaling: two rounds of (conv to 256 channels -> depth_to_space x2) ----
conv3 = slim.conv2d(sum1, 256, 3, activation_fn=None)
print(conv3.shape)  # (16, 64, 64, 256)
# Rearrange channels into 2x2 pixel blocks: spatial size x2, channels / 4.
ps1 = tf.depth_to_space(conv3, 2)
print(ps1.shape)  # (16, 128, 128, 64)
relu2 = leaky_relu(ps1)
print(relu2.shape)  # (16, 128, 128, 64)
conv4 = slim.conv2d(relu2, 256, 3, activation_fn=None)
print(conv4.shape)  # (16, 128, 128, 256)
ps2 = tf.depth_to_space(conv4, 2)  # second x2 step, giving x4 overall
print(ps2.shape)  # (16, 256, 256, 64)
relu3 = leaky_relu(ps2)
print(relu3.shape)  # (16, 256, 256, 64)
# Final convolution down to 3 output channels; no activation.
y_predt = slim.conv2d(relu3, 3, 3, activation_fn=None)  # network output, on the /255 scale used by the loss
print(y_predt.shape)  # (16, 256, 256, 3)
# Back to the 0-255 scale and clamp to that range for display and metrics.
y_pred = y_predt * 255.0
y_pred = tf.maximum(y_pred, 0)
y_pred = tf.minimum(y_pred, 255)

# Original images followed by the clamped predictions, stacked along the batch
# axis; this is the layout the batch_* metric helpers split in half.
dbatch = tf.concat([tf.cast(images, tf.float32), y_pred], 0)

learn_rate = 0.001

# Mean squared error between the originals scaled by 1/255 and the raw
# (unclamped) network output.
cost = tf.reduce_mean(tf.pow(tf.cast(images, tf.float32) / 255.0 - y_predt, 2))
# NOTE(review): the residual blocks use slim.batch_norm; nothing here ties the
# train op to the UPDATE_OPS collection - confirm whether that is intended.
optimizer = tf.train.AdamOptimizer(learn_rate).minimize(cost)
# Longer schedule, currently disabled:
# training_epochs =100000
# display_step =5000

# Number of training steps, and how often (in steps) metrics are reported.
training_epochs = 10000
display_step = 400

# Tag identifying this configuration; it becomes part of the checkpoint
# directory name. NOTE: under Python 3, `height / 4` is a float, so the tag
# contains "64.0" - changing it would point at a different directory.
flags = 'b' + str(batch_size) + '_h' + str(height / 4) + '_r' + str(
    learn_rate) + '_res'  # edit to try different setups
# flags='b'+str(batch_size)+'_r'+str(height/4)+'_depth_conv2d'#set for practicers to try different setups
# flags='b'+str(batch_size)+'_r'+str(height/4)+'_depth_conv2d'#set for practicers to try different setups
# Checkpoints live in save/tf_<flags>; create the directory (and its parent
# 'save' directory) before the Saver writes into it.
save_path = 'save/tf_' + flags
if not os.path.exists(save_path):
    os.makedirs(save_path)
saver = tf.train.Saver(max_to_keep=1)  # keep only the most recent checkpoint

# sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess = tf.Session()
# Give every variable an initial value; a checkpoint restored below
# overwrites these values.
sess.run(tf.global_variables_initializer())

# Resume from the latest checkpoint, if there is one.
kpt = tf.train.latest_checkpoint(save_path)
startepo = 0
if kpt is not None:
    saver.restore(sess, kpt)
    # The checkpoint path ends in "-<global_step>". Search from the right so
    # a '-' earlier in the path (e.g. a learning rate such as 1e-05 inside
    # `flags`) is not mistaken for the separator.
    ind = kpt.rfind("-")
    startepo = int(kpt[ind + 1:])
    print("startepo=", startepo)

# Start the queue runners that feed tf.train.batch; without them every
# sess.run that needs a batch would block forever.
tf.train.start_queue_runners(sess=sess)

# Training loop.
# `epoch` is pre-set so the final save below still has a step number when the
# loop body never runs (restored step >= training_epochs).
epoch = startepo
for epoch in range(startepo, training_epochs):

    _, c = sess.run([optimizer, cost])

    # Periodically report metrics and write a checkpoint.
    if epoch % display_step == 0:
        # Evaluated through `sess` explicitly: a plain tf.Session is not
        # registered as the default session, so Tensor.eval() with no
        # argument would fail. This run dequeues a fresh batch.
        d_batch = sess.run(dbatch)
        mse, psnr = batch_mse_psnr(d_batch)
        ypsnr = batch_y_psnr(d_batch)
        ssim = batch_ssim(d_batch)
        print("Epoch:", '%04d' % (epoch + 1),
              "cost=", "{:.9f}".format(c), "psnr", psnr, "ypsnr", ypsnr, "ssim", ssim)
        saver.save(sess, save_path + "/tfrecord.cpkt", global_step=epoch)
print("complete!")
saver.save(sess, save_path + "/tfrecord.cpkt", global_step=epoch)

# Fetch one batch of originals, labels, the downscaled inputs, the three
# conventional upscalings and the network prediction in a single run, so
# they all come from the same images.
imagesv, label_batch, x_smallv, x_nearestv, x_bilinv, x_bicubicv, y_predv = sess.run(
    [images, labels, x_smalls, x_nearests, x_bilins, x_bicubics, y_pred])
print("primary", np.shape(imagesv), "Zoomed", np.shape(x_smallv), label_batch)

#        print(np.max(imagesv[0]),np.max(x_bilinv[0]),np.max(x_bicubicv[0]),np.max(y_predv[0]))
#        print(np.min(imagesv[0]),np.min(x_bilinv[0]),np.min(x_bicubicv[0]),np.min(y_predv[0]))

# One row of six panels: the original, the downscaled input, the three
# conventional upscalings, and the network prediction.
plt.figure(figsize=(20, 10))

# (subplot position, title prefix, batch to draw, add metrics to the title?)
panels = (
    (161, "org", imagesv, False),
    (162, "small/4", x_smallv, False),
    (163, "near", x_nearestv, True),
    (164, "biline", x_bilinv, True),
    (165, "bicubicv", x_bicubicv, True),
    (166, "pred", y_predv, True),
)
for position, caption, shown_batch, with_metrics in panels:
    showresult(position, caption, imagesv, shown_batch, with_metrics)

