I. Environment
Python 3.7.3 (Anaconda 3)
TensorFlow 1.14.0
II. Methods
Saving and restoring TensorFlow models is handled mainly by the tf.train.Saver class, combined with some methods for loading the model graph.
Official documentation for the relevant methods:
https://www.tensorflow.org/guide/saved_model?hl=zh-cN
https://tensorflow.google.cn/api_docs/python/tf/train/import_meta_graph
https://tensorflow.google.cn/api_docs/python/tf/train/Saver
1. Model Preservation
This stage is generally called the training stage and mainly includes:
- Building the model
- Training the model
- Saving the model
Saving is done mainly with the save method of a tf.train.Saver object. It generates four types of files in the specified directory:
saved_models_directory:
******.meta
******.index
******.data-00000-of-00001
checkpoint
Among them,
(1) checkpoint file: records the list of the most recently saved model files, up to five by default (the default value of the max_to_keep parameter when creating a tf.train.Saver object); it is plain text and can be viewed in any text editor;
(2) meta file: stores the graph of the model (the network structure);
(3) index file and .data-00000-of-00001 file: many existing sources state that these two files store the variable values of the trained model (weights, biases, etc.), but none of them cites an official description. The author has not yet found an official document that confirms this; to be updated later.
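Because the checkpoint file is plain text, it can also be read with a few lines of Python. The following is a minimal sketch, not the way TensorFlow itself reads the file (tf.train.latest_checkpoint does that for you). The sample text and the checkpoint names in it are assumptions made up for illustration, and parse_checkpoint_text is a hypothetical helper.

```python
import re

# Sample text in the layout TensorFlow 1.x writes to the "checkpoint" file.
# The checkpoint names are illustrative assumptions, not real output.
SAMPLE = '''model_checkpoint_path: "model.ckpt-3000"
all_model_checkpoint_paths: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-2000"
all_model_checkpoint_paths: "model.ckpt-3000"
'''

# One entry per line: key, colon, quoted path.
_LINE = re.compile(r'^(\w+):\s*"(.*)"\s*$')


def parse_checkpoint_text(text):
    """Return (latest_prefix, all_prefixes) from the text of a checkpoint file."""
    latest = None
    all_paths = []
    for line in text.splitlines():
        match = _LINE.match(line.strip())
        if not match:
            continue  # skip blank or unrecognised lines
        key, value = match.groups()
        if key == "model_checkpoint_path":
            latest = value
        elif key == "all_model_checkpoint_paths":
            all_paths.append(value)
    return latest, all_paths


latest, all_paths = parse_checkpoint_text(SAMPLE)
print(latest)
print(all_paths)
```

In real code, tf.train.latest_checkpoint("/saved_models_directory/") is the simpler choice; the sketch only makes the file's contents concrete.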
It should be noted that:
When the model is saved many times during training and the tf.train.Saver() object is created with the default max_to_keep of 5, the save directory contains a single checkpoint file plus five sets of the other three files (.meta, .index and .data-00000-of-00001), one set for each of the last five saved models.
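The retention behaviour can be pictured with a small pure-Python simulation. This is a sketch that models only "keep the most recent N"; it is not TensorFlow's implementation. The step numbers are hypothetical, and the "-<step>" suffix assumes the model was saved with the global_step argument of saver.save.

```python
from collections import deque


def simulate_saves(steps, max_to_keep=5, prefix="model.ckpt"):
    """Return the checkpoint prefixes that survive a sequence of saves.

    Only models the "keep the most recent N" rule; not TensorFlow's code.
    """
    kept = deque(maxlen=max_to_keep)  # the oldest entry drops out automatically
    for step in steps:
        kept.append("%s-%d" % (prefix, step))
    return list(kept)


def files_on_disk(prefixes):
    """List the files expected in the directory for the surviving checkpoints."""
    files = ["checkpoint"]  # one shared checkpoint file
    for p in prefixes:
        files += [p + ".meta", p + ".index", p + ".data-00000-of-00001"]
    return files


# Eight hypothetical saves at steps 1000, 2000, ..., 8000.
survivors = simulate_saves(range(1000, 8001, 1000))
print(survivors)
print(len(files_on_disk(survivors)))
```

With eight saves and max_to_keep=5, only the last five prefixes remain, giving 1 + 5 × 3 = 16 files in the directory.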
2. Model Reuse
This is generally called the inference stage and mainly consists of:
(1) Building the graph (simply put, creating the network structure), in one of the following two ways:
- Manually rebuilding the graph
- Automatically restoring the graph from the saved files
(2) Restoring the variables (weights, biases, hyperparameters, etc.)
(3) Running the relevant operations
3. Explanation of Examples
Official website reference: https://www.tensorflow.org/guide/saved_model?hl=zh_cn
1. Model Preservation
>>> import tensorflow as tf
>>> # Create two variables initialized to zeros
>>> v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
>>> v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
>>> # Ops that add 1 to v1 and subtract 1 from v2
>>> inc_v1 = v1.assign(v1+1)
>>> inc_v2 = v2.assign(v2-1)
>>> # Op that initializes all variables
>>> init_op = tf.global_variables_initializer()
>>> # Create the Saver object used to save and restore the variables
>>> saver = tf.train.Saver()
>>> # Create a session
>>> with tf.Session() as sess:
...     # Initialize the variables
...     sess.run(init_op)
...     # Run the assign ops (the result of assign() exposes its operation as .op)
...     inc_v1.op.run()
...     inc_v2.op.run()
...     # Save the variables under the given checkpoint prefix
...     save_path = saver.save(sess, '/saved_models_directory/model.ckpt')
...     print("Model saved in path: %s" % save_path)
...
Model saved in path: /saved_models_directory/model.ckpt
The following four files are generated in the specified directory:
model.ckpt.meta
model.ckpt.index
model.ckpt.data-00000-of-00001
checkpoint
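Before reusing a model, it can help to confirm that all four files are actually present. The following is a minimal sketch using only the standard library; missing_checkpoint_files is a hypothetical helper, and it assumes a single data shard named as in the listing above. The demonstration uses empty placeholder files in a temporary directory, not a real saved model.

```python
import os
import tempfile


def missing_checkpoint_files(prefix):
    """Return the expected files for `prefix` that are not present on disk.

    Assumes a single data shard (".data-00000-of-00001").
    """
    directory = os.path.dirname(prefix)
    expected = [
        prefix + ".meta",
        prefix + ".index",
        prefix + ".data-00000-of-00001",
        os.path.join(directory, "checkpoint"),
    ]
    return [path for path in expected if not os.path.isfile(path)]


# Demonstration: create three of the four files as empty placeholders.
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "model.ckpt")
    for name in ("model.ckpt.meta", "model.ckpt.index", "checkpoint"):
        open(os.path.join(tmp, name), "w").close()
    missing = [os.path.basename(p) for p in missing_checkpoint_files(prefix)]
    print(missing)
```

An empty result means every expected file was found for that prefix.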
2. Model Reuse
Method 1: Manually rebuilding the graph
>>> # Reset the graph
>>> tf.reset_default_graph()
>>> # Build the same variables as in the training stage
>>> v1 = tf.get_variable("v1", shape=[3])
>>> v2 = tf.get_variable("v2", shape=[5])
>>> # Create the tf.train.Saver() object
>>> saver = tf.train.Saver()
>>> with tf.Session() as sess:
...     # Restore the model's variables; note that the second argument
...     # is the checkpoint prefix, without any file suffix
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     # Run the relevant operations
...     print("v1 : %s" % v1.eval())
...     print("v2 : %s" % v2.eval())
...
Model restored.
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
PS: Manually rebuilding the graph (network structure) and then restoring the saved values of the trained variables is generally error-free, but relatively tedious.
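Since restore expects the checkpoint prefix rather than the name of one of the files on disk, a small helper can derive the prefix from any of them. This is a sketch with a hypothetical function name, checkpoint_prefix; it assumes the suffix patterns shown in the file listing earlier.

```python
import re

# Suffixes of the files written next to a checkpoint prefix.
_SUFFIX = re.compile(r"\.(meta|index|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Strip a .meta/.index/.data-* suffix, leaving the prefix restore() expects."""
    return _SUFFIX.sub("", path)


print(checkpoint_prefix("/saved_models_directory/model.ckpt.meta"))
print(checkpoint_prefix("/saved_models_directory/model.ckpt.data-00000-of-00001"))
print(checkpoint_prefix("/saved_models_directory/model.ckpt"))
```

All three calls print the same prefix, /saved_models_directory/model.ckpt, which is the value passed to saver.restore in the example above.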
Method 2: Automatically restoring the graph from the saved files
Restoring both the network structure of the trained model and the values of its variables directly from the saved files is simpler and faster, but beginners may run into a few problems.
First, one mistake you may encounter is forgetting to restore the TensorFlow graph (the network structure of the trained model) before restoring the variable values.
>>> tf.reset_default_graph()
>>> with tf.Session() as sess:
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     all_vars = tf.global_variables()
...     for v in all_vars:
...         print(v.name)
...
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py", line 1286, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/******/Anaconda/anaconda3/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1098, in _run
    raise RuntimeError('The Session graph is empty. Add operations to the '
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
Error message: the session graph is empty, so a graph must first be built or imported from the saved model. In other words, the graph (network structure) must be restored before the trained variable values can be restored.
Second, pay special attention to the names used to look up variables in a model restored from the saved files. The variables were created as "v1" and "v2" in the training stage, but the names reported by the restored graph carry a ":0" suffix, for example "v1:0". The suffix is the output index of the operation that produces the tensor (tensor names have the form "<op_name>:<output_index>"), so beginners who look up the plain name "v1" may encounter errors.
The following example shows the suffix on the restored variable names:
>>> tf.reset_default_graph()
>>> # latest_model_file = tf.train.latest_checkpoint("/saved_models_directory/")
>>> saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
>>> with tf.Session() as sess:
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     all_vars = tf.global_variables()
...     for v in all_vars:
...         print(v.name)
...
Model restored.
v1:0
v2:0
As the output shows, the names of the variables in the restored model end with ":0", so the variables v1 and v2 must be looked up with the names "v1:0" and "v2:0".
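The naming rule can be made concrete with two small pure-Python helpers. This is a sketch that only manipulates strings and does not call TensorFlow; split_tensor_name and to_tensor_name are hypothetical names introduced for illustration.

```python
def split_tensor_name(name):
    """Split 'v1:0' into ('v1', 0); a bare 'v1' is treated as output index 0."""
    op_name, sep, index = name.rpartition(":")
    if not sep:
        # No colon: the whole string is the operation name.
        return name, 0
    return op_name, int(index)


def to_tensor_name(op_name, index=0):
    """Build the '<op_name>:<output_index>' form expected by get_tensor_by_name."""
    return "%s:%d" % (op_name, index)


print(split_tensor_name("v1:0"))
print(to_tensor_name("v2"))
```

Here to_tensor_name("v2") yields "v2:0", the form that get_tensor_by_name expects.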
Therefore, after restoring the model and its variables, fetch the tensors by these suffixed names; their values can then be printed with v1.eval() or sess.run(v1).
>>> tf.reset_default_graph()
>>> saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
>>> with tf.Session() as sess:
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
...     v2 = tf.get_default_graph().get_tensor_by_name('v2:0')
...     print("v1 : %s" % v1.eval())
...     print("v2 : %s" % sess.run(v2))
...
Model restored.
v1 : [1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
The lookups can also be passed directly to sess.run:
>>> tf.reset_default_graph()
>>> # Note that the graph structure of the model is stored in the meta file
>>> saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
>>> with tf.Session() as sess:
...     # Note that the second argument is the checkpoint prefix, without any file suffix
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     print(sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
...     print(sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))
...
Model restored.
[1. 1. 1.]
[-1. -1. -1. -1. -1.]
Error example: using v1 and v2 directly after restoring the model, without first looking the tensors up by their suffixed names "v1:0" and "v2:0"
>>> with tf.Session() as sess:
...     saver = tf.train.import_meta_graph("/saved_models_directory/model.ckpt.meta")
...     saver.restore(sess, '/saved_models_directory/model.ckpt')
...     print("Model restored.")
...     print("v1 : %s" % v1.eval())
...     print("v2 : %s" % v2.eval())
...
# The error report is as follows:
Traceback (most recent call last):
  File "<stdin>", line 5, in <module>
NameError: name 'v1' is not defined