TensorFlow model saving and loading

Keywords: network Session

1. Saving a TensorFlow model

Example of saving a model:

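import tensorflow as tf
import numpy as np
with tf.name_scope('train'):
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# By default the Saver saves all global variables (here w1 and w2)
saver = tf.train.Saver()
with tf.Session() as sess:
    for t in range(3):
        # Re-initialize so that each checkpoint holds new random values
        sess.run(tf.global_variables_initializer())
        # global_step=t appends "-t" to the file name, e.g. my_test_model-1
        saver.save(sess, 'D:\\tuxiang\\hhh\\my_test_model', global_step=t)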

Here t is the iteration number. It is passed as global_step and appended to the model name when the model is saved, so the three checkpoints are named my_test_model-0, my_test_model-1 and my_test_model-2.
The max_to_keep parameter of tf.train.Saver() sets how many of the most recent checkpoints are kept (5 by default). You can also save only selected parameters instead of all of them, as follows:

# Get the trainable variables whose names match the scope 'vgg_feat_fc'
need_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')
# Pass them as var_list so that this Saver saves (and restores) only these variables
saver = tf.train.Saver(var_list=need_save)
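In TF1, the scope argument of tf.get_collection() is treated as a regular expression and matched with re.match against the start of each item's name, so scope='vgg_feat_fc' would also pick up variables from a scope such as 'vgg_feat_fc2'. The sketch below is a plain-Python model of that filter, not TensorFlow code; the helper filter_by_scope and the variable names are hypothetical.

```python
import re


# Plain-Python model of the scope filter (not TensorFlow code):
# TF1's get_collection(key, scope) keeps items whose .name matches `scope`
# via re.match, i.e. a regex anchored at the start of the name.
def filter_by_scope(names, scope):
    pattern = re.compile(scope)
    return [name for name in names if pattern.match(name)]


# Hypothetical variable names, for illustration only
names = [
    "vgg_feat_fc/weights:0",
    "vgg_feat_fc/biases:0",
    "vgg_feat_fc2/weights:0",
    "conv1/weights:0",
]

print(filter_by_scope(names, "vgg_feat_fc"))   # also picks up vgg_feat_fc2/...
print(filter_by_scope(names, "vgg_feat_fc/"))  # only the vgg_feat_fc scope
```

Adding a trailing '/' to the scope is a simple way to restrict the match to exactly one scope.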

The results are as follows: [Figure: results of saving the model; image not preserved]
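Since the output of the saving example is not reproduced here, the following plain-Python sketch (not TensorFlow code) models what the saving loop produces, assuming default TF1 Saver settings: each save(..., global_step=t) writes files with the prefix my_test_model-t, and only the newest max_to_keep checkpoints (5 by default) are kept. The helpers checkpoint_prefix, kept_checkpoints and files_for are hypothetical, and the .data-00000-of-00001 suffix assumes a single data shard. Besides these files, the Saver also maintains a small text file named checkpoint that records the most recent checkpoint paths.

```python
# Plain-Python model (not TensorFlow code) of checkpoint naming and retention
# in a TF1 tf.train.Saver. Helper names are hypothetical.

# Files typically written per save in the V2 checkpoint format
# (a single data shard is assumed here)
CHECKPOINT_SUFFIXES = (".meta", ".index", ".data-00000-of-00001")


def checkpoint_prefix(save_path, global_step):
    # saver.save(sess, save_path, global_step=t) writes to "<save_path>-<t>"
    return "{}-{}".format(save_path, global_step)


def kept_checkpoints(save_path, steps, max_to_keep=5):
    # Only the newest max_to_keep checkpoints survive; older ones are deleted.
    # None or 0 is modelled here as "delete nothing".
    kept = []
    for step in steps:
        kept.append(checkpoint_prefix(save_path, step))
        if max_to_keep and len(kept) > max_to_keep:
            kept.pop(0)
    return kept


def files_for(prefix):
    return [prefix + suffix for suffix in CHECKPOINT_SUFFIXES]


# The saving loop calls save() with t = 0, 1, 2
for prefix in kept_checkpoints("my_test_model", range(3)):
    print(files_for(prefix))

# With max_to_keep=2 and five saves, only the last two remain
print(kept_checkpoints("my_test_model", range(5), max_to_keep=2))
```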

2. Loading a TensorFlow model

After saving the model, we can load it directly and test it on the target data set instead of training from scratch.
The model saved above can be loaded in two ways:

1. Re-create the same network (by copying the previous code) and restore the saved parameters into it.

Code:

import tensorflow as tf  
import numpy as np
with tf.name_scope('train'):
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')  
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
#     w3 = tf.Variable(tf.random_normal(shape=[5]), name='w3') 
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
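    # w1 after random initialization
    print(sess.run('train/w1:0'))
    # Overwrite the initialized values with those saved at step 1
    saver.restore(sess, 'D:\\tuxiang\\hhh\\my_test_model-1')
    # w1 after restoring from the checkpoint
    print(sess.run('train/w1:0'))
    # The same value, read through the Python variable w1
    print(sess.run(w1))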


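This restores the saved values of w1 and w2 into the re-created graph (the graph itself is rebuilt by the copied code, not loaded from the file). Because saver = tf.train.Saver() is created outside the session from the variables of the current graph, those variables must match the ones in the original model (same names and shapes) before loading. When I added a variable w3 under "train" (the commented-out line), restoring failed, because the checkpoint contains no saved value for w3.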
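saver.restore() looks up every variable the Saver covers by name in the checkpoint, so a variable such as w3 that was never saved cannot be restored, while extra entries in the checkpoint are simply left unused. The plain-Python sketch below (not TensorFlow code) models this with dictionaries; restore() is a hypothetical stand-in and the names and values are made up.

```python
# Plain-Python model (not TensorFlow code) of restoring by variable name.
# checkpoint: {variable name: saved values}; graph: {variable name: length}.
# All names and values are hypothetical.

def restore(checkpoint, graph_vars):
    restored = {}
    for name, length in graph_vars.items():
        if name not in checkpoint:
            # TF1 reports a similar "Key ... not found in checkpoint" error
            raise KeyError("Key {} not found in checkpoint".format(name))
        values = checkpoint[name]
        if len(values) != length:
            # TF1 also fails when the saved and current shapes differ
            raise ValueError("Shape mismatch for {}".format(name))
        restored[name] = list(values)
    return restored


checkpoint = {"train/w1": [0.1, -0.3], "train/w2": [0.5, 0.2, -1.0, 0.7, 0.0]}

same_graph = {"train/w1": 2, "train/w2": 5}
print(restore(checkpoint, same_graph)["train/w1"])  # [0.1, -0.3]

graph_with_w3 = {**same_graph, "train/w3": 5}
try:
    restore(checkpoint, graph_with_w3)
except KeyError as err:
    print(err)

# Restoring only a subset works as long as every requested name is saved
print(restore(checkpoint, {"train/w1": 2}))
```

The last call mirrors the filtered Saver in the note that follows: a partial restore succeeds as long as every requested variable exists in the checkpoint.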

Note: the parameters can also be restored inside the session, for a selected subset of variables only:

# 'train' is the name_scope of the variables to be loaded; the current graph must
# use the same name_scope (graph structure) as the saved model for these variables.
# The trailing '/' keeps scopes such as 'training' from matching as well.
train_vars = [var for var in tf.global_variables() if var.name.startswith('train/')]
tf.train.Saver(train_vars).restore(sess, 'D:\\tuxiang\\hhh\\my_test_model-1')

This approach is meant for loading the saved parameters into an already-constructed network and continuing to train with them. I used the code above when writing an LSTM, but it did not succeed in initializing w1 and w2 from the checkpoint.
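The order of initialization and restoring matters when only part of the graph is loaded: tf.global_variables_initializer() has to run before the restore, otherwise it overwrites the restored values. This is a general point worth checking in a setup like the LSTM one above, not a confirmed diagnosis of that problem. The plain-Python sketch below (not TensorFlow code) models variables as a dictionary; initialize() and restore_subset() are hypothetical stand-ins for the initializer op and Saver(var_list).restore().

```python
# Plain-Python model (not TensorFlow code). initialize() stands in for
# tf.global_variables_initializer() and restore_subset() for
# tf.train.Saver(var_list).restore(). Names and values are hypothetical.

def fresh_variables():
    # "train/..." variables exist in the checkpoint, "lstm/..." is new
    return {"train/w1": None, "train/w2": None, "lstm/kernel": None}


def initialize(variables):
    # Stand-in for random initialization: every variable is reset to 0.0
    for name in variables:
        variables[name] = 0.0


def restore_subset(variables, checkpoint, prefix):
    # Load saved values only for variables whose names start with `prefix`
    for name in variables:
        if name.startswith(prefix):
            variables[name] = checkpoint[name]


checkpoint = {"train/w1": 1.5, "train/w2": -0.7}

# Initialize first, then restore: the saved values survive
variables = fresh_variables()
initialize(variables)
restore_subset(variables, checkpoint, "train/")
print(variables)  # {'train/w1': 1.5, 'train/w2': -0.7, 'lstm/kernel': 0.0}

# Restore first, then initialize: the saved values are overwritten
variables = fresh_variables()
restore_subset(variables, checkpoint, "train/")
initialize(variables)
print(variables)  # {'train/w1': 0.0, 'train/w2': 0.0, 'lstm/kernel': 0.0}
```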

2. Use tf.train.import_meta_graph(path) to load the network defined in the .meta file into the current graph, and then use restore() to recover the parameters.

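import tensorflow as tf
import numpy as np
with tf.Session() as sess:
    # Rebuild the graph stored in the .meta file; this returns a Saver for it
    saver = tf.train.import_meta_graph('D:\\tuxiang\\hhh\\my_test_model-1.meta')
    # Load the saved variable values into the rebuilt graph
    saver.restore(sess, 'D:\\tuxiang\\hhh\\my_test_model-1')
    print(sess.run('train/w1:0'))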

Because tf.train.import_meta_graph() rebuilds the graph from the .meta file, there is no need to rewrite the network code here. saver.restore() then loads all the parameters saved in the checkpoint.
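The string 'train/w1:0' passed to sess.run() is a tensor name: the op name train/w1 (variable w1 created inside name_scope 'train') followed by :0, the index of that op's first output. With a graph imported by tf.train.import_meta_graph() there is no Python variable w1, so tensors are reached by such names, for example through tf.get_default_graph().get_tensor_by_name('train/w1:0'). The plain-Python sketch below (not TensorFlow code) only shows how such a name breaks down; split_tensor_name and the nested scope names are hypothetical.

```python
# Plain-Python sketch (not TensorFlow code): split a tensor name into its
# name_scope, op name and output index. split_tensor_name is hypothetical.

def split_tensor_name(tensor_name):
    # "<op name>:<output index>", where the op name is "<scope>/<name>"
    op_name, _, index = tensor_name.rpartition(":")
    scope, _, short_name = op_name.rpartition("/")
    return scope, short_name, int(index)


print(split_tensor_name("train/w1:0"))        # ('train', 'w1', 0)
print(split_tensor_name("outer/inner/w2:0"))  # ('outer/inner', 'w2', 0)
```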

Posted by Isomerizer on Tue, 27 Aug 2019 22:49:35 -0700