Fitting delayed (rolled) and reversed sequences with a TensorFlow RNN

Keywords: network, Lambda, Session

Using an RNN to fit delayed and reversed sequences

In fact, a single-layer RNN like this is no better here than a plain fully connected network. When the sequence is long, the final fully connected layer needs more nodes; when the sequence is short the model has more parameters than it needs, but the results are very accurate...
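Put another way: for a fixed 32-bit input, a plain fully connected network fits these mappings just as well. A minimal sketch of such a baseline, in the same TF 1.x style as the code below (the names x_fc, y_fc, hidden_fc, pred_fc are invented here for illustration):

import tensorflow as tf

x_fc = tf.placeholder(tf.float32, (None, 32))                    # flat 32-bit input
y_fc = tf.placeholder(tf.float32, (None, 32))                    # 32-bit target
hidden_fc = tf.layers.dense(x_fc, 64, activation=tf.nn.relu)     # one hidden layer
pred_fc = tf.layers.dense(hidden_fc, 32, activation=tf.nn.relu)  # ReLU keeps outputs non-negative
loss_fc = tf.reduce_mean((pred_fc - y_fc) ** 2)                  # same MSE loss as the RNN version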

Rolled sequence: shift the sequence 3 elements to the right; the 3 elements pushed off the end wrap around to the front (a circular shift). In the samples below, i is the input, j the target, and k the network's rounded prediction.

i  [1 0 1 1 0 1 1 1 1 1 0 1 0 0 1 1 0 1 1 0 1 1 1 0 0 0 1 1 1 0 1 1]
j  [0 1 1 1 0 1 1 0 1 1 1 1 1 0 1 0 0 1 1 0 1 1 0 1 1 1 0 0 0 1 1 1]
k  [0 1 1 1 0 1 1 0 1 1 1 1 1 0 1 0 0 1 1 0 1 1 0 1 1 1 0 0 0 1 1 1]
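
Under the hood this is just np.roll; a tiny standalone sketch (matching the commented example inside get_data below):

import numpy as np

a = np.arange(10)     # [0 1 2 3 4 5 6 7 8 9]
print(np.roll(a, 3))  # [7 8 9 0 1 2 3 4 5 6] -- the tail wraps to the front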


Reversed sequence

i  [0 1 1 1 1 1 1 0 1 0 1 1 1 0 0 1 1 0 0 1 1 0 0 0 0 1 0 1 1 1 1 0]
j  [0 1 1 1 1 0 1 0 0 0 0 1 1 0 0 1 1 0 0 1 1 1 0 1 0 1 1 1 1 1 1 0]
k  [0 1 1 1 1 0 1 0 0 0 0 1 1 0 0 1 1 0 0 1 1 1 0 1 0 1 1 1 1 1 1 0]
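
Reversal needs no special function at all; a negative-stride slice does it:

import numpy as np

a = np.arange(10)  # [0 1 2 3 4 5 6 7 8 9]
print(a[::-1])     # [9 8 7 6 5 4 3 2 1 0]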


import tensorflow as tf
import numpy as np
import random


# Convert an integer to a binary numpy array of shape (1, 1, num_len)
def int2bin(x, num_len):
    x = bin(x)[2:].zfill(num_len)
    return np.array(
        list(map(lambda i: float(i), x))
    ).reshape((1, 1, num_len))
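
# e.g. int2bin(5, 8) -> [[[0. 0. 0. 0. 0. 1. 0. 1.]]], shape (1, 1, 8)
# (this helper is defined but not used in the training loop below)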


# Generate a batch: batch_size random bit sequences, num_len bits each
def get_data(batch_size, num_len):
    # Delayed (rolled) sequence
    # [0 1 2 3 4 5 6 7 8 9]
    # [7 8 9 0 1 2 3 4 5 6]
    # a = np.random.randint(0, 2, (batch_size, num_len))
    # b = np.roll(a, 3, axis=1)
    # a = a.reshape((batch_size, 1, num_len))

    # Reversed sequence
    # [0 1 2 3 4 5 6 7 8 9]
    # [9 8 7 6 5 4 3 2 1 0]
    a = np.random.randint(0, 2, (batch_size, num_len))
    b = a[:, ::-1]
    a = a.reshape((batch_size, 1, num_len))
    return a, b


# x, y = get_data(10, 8)
# print(x.shape, y.shape) # (10, 1, 8) (10, 8)

output_n = input_n = 32
train_steps = 40000
show_steps = 500
batch_size = 32
hidden_size = 64
full_size = 64
time_step = 1
learning_rate = .001

x_in = tf.placeholder(tf.float32, (None, time_step, input_n))
y_in = tf.placeholder(tf.float32, (None, output_n))
print(x_in.shape, y_in.shape)  # (?, 1, 32) (?, 32)
weight = tf.Variable(tf.truncated_normal([full_size, output_n], stddev=.1))
bias = tf.Variable(tf.constant(.1, shape=[output_n]))

# Build the network: LSTM -> fully connected -> ReLU (kept for reference, commented out)
# cell = tf.contrib.rnn.BasicLSTMCell(hidden_size)
# outputs, final_state = tf.nn.dynamic_rnn(cell, x_in, dtype=tf.float32)
# print(final_state[1].shape)  # (?, 64)
# full = tf.layers.dense(final_state[1], full_size)
# print(full.shape)  # (?, 64)
# result = tf.matmul(full, weight) + bias
# result = tf.nn.relu(result)
# print(result.shape)  # (?, 32)


# Simpler network: the LSTM's final hidden state feeds one fully connected output layer; ReLU keeps outputs non-negative
cell = tf.contrib.rnn.BasicLSTMCell(hidden_size)
outputs, final_state = tf.nn.dynamic_rnn(cell, x_in, dtype=tf.float32)
print(final_state[1].shape)  # (?, 64); final_state is an LSTMStateTuple (c, h), so [1] is the hidden state h
result = tf.contrib.slim.fully_connected(final_state[1],
                                         input_n,
                                         activation_fn=tf.nn.relu)
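# (slim.fully_connected(h, n, activation_fn=f) is roughly equivalent to
#  tf.layers.dense(h, n, activation=f); ReLU is in fact slim's default activation)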

# Loss and training op
loss = tf.reduce_mean((result - y_in) ** 2)  # mean squared error over the output bits
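# An alternative not used in this post: treat each output bit as an independent
# binary label and use sigmoid cross-entropy on raw logits (i.e. a final layer
# with activation_fn=None instead of ReLU), e.g.:
# logits = tf.contrib.slim.fully_connected(final_state[1], input_n, activation_fn=None)
# loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_in, logits=logits))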
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(train_steps):
        x_data, y_data = get_data(batch_size, input_n)
        sess.run(train_op, feed_dict={
            x_in: x_data,
            y_in: y_data,
        })

        if not (epoch + 1) % show_steps:  # every show_steps steps, evaluate on a fresh batch
            x_data, y_data = get_data(batch_size, input_n)
            loss_val, res = sess.run([loss, result], feed_dict={
                x_in: x_data,
                y_in: y_data,
            })
            print('epoch:', epoch + 1, 'loss:', loss_val)
            for i, j, k in zip(x_data[:4], y_data[:4], res[:4]):
                print('i ', i.flatten())                  # input sequence
                print('j ', j)                            # target sequence
                print('k ', np.rint(k).astype(np.uint8))  # prediction rounded to bits
                print('k ', k)                            # raw network output
                print('m ', np.mean(
                    np.rint(k) == j                       # per-bit accuracy
                ))

Posted by brainardp on Thu, 02 Jan 2020 14:57:36 -0800