TensorFlow introduced the Estimator API in version 1.3, and support for this high-level style of programming has grown with each release; for example, multi-GPU training can be set up on an Estimator with very little effort. In my previous blog post I built and trained the model with the low-level API. The advantage of that approach is flexibility and a clear view of the model's underlying details; the disadvantage is that the code is long and cumbersome, with many details to implement by hand. So I tried the high-level API in the latest release, TensorFlow 1.14, to test whether it is really easy to use and whether it can match the performance of the low-level API.
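As a quick illustration of the multi-GPU point: with Estimator, distribution only has to be declared in the RunConfig. The sketch below assumes TensorFlow 1.14; build_estimator is a hypothetical helper of my own, and its arguments are placeholders, not part of the code later in this post.

import tensorflow as tf

def build_estimator(keras_model, model_dir):
    # MirroredStrategy replicates the model on every visible GPU;
    # each replica processes a slice of the batch and gradients are
    # all-reduced before the variable update.
    strategy = tf.distribute.MirroredStrategy()
    config = tf.estimator.RunConfig(train_distribute=strategy)
    return tf.keras.estimator.model_to_estimator(
        keras_model=keras_model, model_dir=model_dir, config=config)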
My test is based on the ImageNet image classification data; the preparation of the ImageNet TFRecord files is described in my previous blog post. The full code is below. It includes two models: Darknet-53, the pre-training backbone used in YOLO v3, and AlexNet.
import os
import tensorflow as tf

imageWidth = 224
imageHeight = 224
imageDepth = 3
batch_size = 32
resize_min = 256

train_files_names = os.listdir('/data/AI/train_tf/')
train_files = ['/data/AI/train_tf/' + item for item in train_files_names]
valid_files_names = os.listdir('/data/AI/valid_tf/')
valid_files = ['/data/AI/valid_tf/' + item for item in valid_files_names]

# Parse a TFRecord and distort the image for training
def _parse_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")}
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    # Resize so the shorter side equals resize_min, keeping the aspect ratio
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(
        height < width,
        lambda: (resize_min,
                 tf.cast(tf.multiply(tf.cast(width, tf.float64),
                                     tf.divide(resize_min, height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),
                                     tf.divide(resize_min, width)), tf.int32),
                 resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Random crop from the resized image
    cropped = tf.random_crop(resized, [imageHeight, imageWidth, 3])
    # Flip to add a little more random distortion
    flipped = tf.image.random_flip_left_right(cropped)
    # Standardize the image
    image_train = tf.image.per_image_standardization(flipped)
    return image_train, tf.one_hot(parsed_features["label"][0], 1000)

def train_input_fn():
    dataset_train = tf.data.TFRecordDataset(train_files)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=4)
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(batch_size)
    return dataset_train

# Parse a TFRecord for evaluation: resize and center crop, no random distortion
def _parse_test_function(example_proto):
    features = {"image": tf.FixedLenFeature([], tf.string, default_value=""),
                "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
                "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
                "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
                "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
                "bbox_xmin": tf.VarLenFeature(tf.float32),
                "bbox_xmax": tf.VarLenFeature(tf.float32),
                "bbox_ymin": tf.VarLenFeature(tf.float32),
                "bbox_ymax": tf.VarLenFeature(tf.float32),
                "text": tf.FixedLenFeature([], tf.string, default_value=""),
                "filename": tf.FixedLenFeature([], tf.string, default_value="")}
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(
        height < width,
        lambda: (resize_min,
                 tf.cast(tf.multiply(tf.cast(width, tf.float64),
                                     tf.divide(resize_min, height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),
                                     tf.divide(resize_min, width)), tf.int32),
                 resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    image_resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Calculate how much to crop for a center crop
    shape = tf.shape(image_resized)
    height, width = shape[0], shape[1]
    amount_to_be_cropped_h = (height - imageHeight)
    crop_top = amount_to_be_cropped_h // 2
    amount_to_be_cropped_w = (width - imageWidth)
    crop_left = amount_to_be_cropped_w // 2
    image_cropped = tf.slice(image_resized, [crop_top, crop_left, 0],
                             [imageHeight, imageWidth, -1])
    image_valid = tf.image.per_image_standardization(image_cropped)
    return image_valid, tf.one_hot(parsed_features["label"][0], 1000)

def val_input_fn():
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=4)
    dataset_valid = dataset_valid.batch(batch_size)
    dataset_valid = dataset_valid.prefetch(batch_size)
    return dataset_valid

def darknet_53():
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3))
    l = tf.keras.layers

    def _conv(inputs, filters, kernel_size, strides, padding, bias=False,
              normalize=True, activation='leaky_relu'):
        output = inputs
        padding_str = 'same'
        if padding > 0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        output = l.Conv2D(filters, kernel_size, strides, padding_str,
                          use_bias=bias,
                          kernel_initializer='he_normal',
                          kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
        if normalize:
            output = l.BatchNormalization(axis=3)(output, training=True)
        if activation == 'leaky_relu':
            output = l.LeakyReLU(alpha=0.1)(output)
        return output

    def _residual(inputs, filters):
        output = _conv(inputs, filters, 1, (1, 1), 0)
        output = _conv(output, filters * 2, 3, (1, 1), 1)
        output = tf.add(inputs, output)
        return output

    net = _conv(image, 32, 3, (1, 1), 1)
    net = _conv(net, 64, 3, (2, 2), 1)
    net = _residual(net, 32)
    net = _conv(net, 128, 3, (2, 2), 1)
    for _ in range(2):
        net = _residual(net, 64)
    net = _conv(net, 256, 3, (2, 2), 1)
    for _ in range(8):
        net = _residual(net, 128)
    # add route1
    net = _conv(net, 512, 3, (2, 2), 1)
    for _ in range(8):
        net = _residual(net, 256)
    # add route2
    net = _conv(net, 1024, 3, (2, 2), 1)
    for _ in range(4):
        net = _residual(net, 512)
    # add route3
    net = l.GlobalAveragePooling2D()(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    return model

def alexnet():
    image = tf.keras.Input(shape=(imageHeight, imageWidth, 3))
    l = tf.keras.layers

    def _conv(inputs, filters, kernel_size, strides, padding, bias=True):
        output = inputs
        padding_str = 'same'
        if padding > 0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        output = l.Conv2D(filters, kernel_size, strides, padding_str,
                          use_bias=bias,
                          kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1),
                          kernel_regularizer=tf.keras.regularizers.l2(l=5e-4))(output)
        output = l.BatchNormalization(axis=3)(output, training=True)
        output = l.ReLU()(output)
        return output

    net = _conv(image, 96, 11, 4, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 256, 5, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 256, 3, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = l.Flatten()(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1/4096))(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1/4096))(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1/1000))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    return model

# Custom loss that adds the L2 regularization terms explicitly.
# Not used in main() below, since Keras already folds kernel_regularizer
# losses into the total loss (see the note after the code).
def my_loss(y_true, y_pred):
    l2_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    crossentropy = tf.keras.losses.categorical_crossentropy(y_true=y_true, y_pred=y_pred)
    loss = tf.add(l2_loss, crossentropy) / batch_size
    return loss

def main(_):
    # 1281167 images in the ImageNet train set
    epoch_steps = 1281167 // batch_size
    boundaries = [epoch_steps * 5, epoch_steps * 8]
    values = [0.01, 0.001, 0.0001]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries, values)
    model = darknet_53()
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['categorical_accuracy', 'top_k_categorical_accuracy'])
    est_model = tf.keras.estimator.model_to_estimator(model, model_dir='imagenet_model_darknet53/')
    # Alternate training and evaluation
    for _ in range(5):
        est_model.train(input_fn=train_input_fn, steps=5000)
        eval_results = est_model.evaluate(input_fn=val_input_fn)
        print('\nEvaluation:\n\t%s\n' % eval_results)

if __name__ == "__main__":
    tf.app.run(main)
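Instead of the explicit loop in main(), the same alternation of training and evaluation could also be written with tf.estimator.train_and_evaluate. A minimal sketch against the same input functions (run_train_and_evaluate is my own helper name, and the step count is illustrative):

def run_train_and_evaluate(est_model):
    # Train for 25000 steps in total, evaluating whenever
    # a new checkpoint is saved.
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=25000)
    eval_spec = tf.estimator.EvalSpec(input_fn=val_input_fn)
    tf.estimator.train_and_evaluate(est_model, train_spec, eval_spec)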
As the code above shows, building and training a model with Estimator is very convenient. However, during testing I found that if I did not one-hot encode the labels in the input data, that is, if I represented the image category as an integer from 0 to 999 and chose 'sparse_categorical_crossentropy' as the loss, training seemed to go nowhere. In addition, when kernel_regularizer is specified on a Keras Conv2D layer, the regularization term appears to be added to the loss automatically. Also, when using Keras BatchNormalization, training=True must be specified.
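For reference, the sparse-label variant I tried looks roughly like this; compile_for_sparse_labels is a hypothetical helper, and this is the configuration that did not train well for me:

def compile_for_sparse_labels(model, learning_rate_fn):
    # In _parse_function / _parse_test_function the label would be returned
    # as a plain integer, e.g.: return image_train, parsed_features["label"][0]
    # The loss and metric then switch to their sparse variants.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])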