ImageNet Image Recognition with the TensorFlow Estimator API

TensorFlow introduced the Estimator API in version 1.3, and support for this high-level programming model has grown with each release; for example, multi-GPU training is easy to set up with an Estimator. In my previous blog posts I used the low-level API to build and train models. That approach is more flexible and exposes the underlying details of the model, but the code is long and cumbersome, and many details have to be implemented by hand. For this reason, I tried the high-level API in TensorFlow 1.14, the latest version at the time of writing, to see whether it is really easy to use and whether it achieves the same performance as the low-level API.

I test on the ImageNet image-classification data; the data preparation is described in my previous blog post. The full code follows. It includes two models: Darknet-53, the pre-trained backbone used in YOLOv3, and AlexNet.

import tensorflow as tf
import horovod.tensorflow as hvd
import os
import random
import time
import numpy as np
from absl import app as absl_app

# Geometry of the image tensors fed to the networks, the mini-batch size,
# and the length the shorter image side is resized to before cropping.
imageWidth = 224
imageHeight = 224
imageDepth = 3
batch_size = 32
resize_min = 256

# Full paths of every file found in the training and validation TFRecord
# directories. The directories are listed once, at import time.
_train_dir = '/data/AI/train_tf/'
train_files_names = os.listdir(_train_dir)
train_files = [_train_dir + name for name in train_files_names]

_valid_dir = '/data/AI/valid_tf/'
valid_files_names = os.listdir(_valid_dir)
valid_files = [_valid_dir + name for name in valid_files_names]

# Parse TFRECORD and distort the image for train
def _parse_function(example_proto):
    """Decode one serialized training example into an augmented image and label.

    The JPEG is decoded, resized so that its shorter side equals
    ``resize_min`` (aspect ratio preserved), randomly cropped to
    ``imageHeight x imageWidth``, randomly flipped left/right and finally
    standardized per image.

    Args:
        example_proto: one serialized ``tf.Example`` string, as produced by
            a TFRecord dataset.

    Returns:
        A tuple ``(image, label)``: the standardized float image crop and the
        one-hot encoding (depth 1000) of the example's ``label`` feature.
    """
    # The closing brace of this dict was missing in the original, which made
    # the whole module a syntax error.
    features = {
        "image": tf.FixedLenFeature([], tf.string, default_value=""),
        "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
        "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
        "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
        "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
        "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
        "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
        "bbox_xmin": tf.VarLenFeature(tf.float32),
        "bbox_xmax": tf.VarLenFeature(tf.float32),
        "bbox_ymin": tf.VarLenFeature(tf.float32),
        "bbox_ymax": tf.VarLenFeature(tf.float32),
        "text": tf.FixedLenFeature([], tf.string, default_value=""),
        "filename": tf.FixedLenFeature([], tf.string, default_value=""),
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    # Resize so that the shorter side becomes resize_min while keeping the
    # aspect ratio. This step is deterministic; the randomness comes from the
    # crop and the flip below.
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height<width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64),tf.divide(resize_min,height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),tf.divide(resize_min,width)), tf.int32), resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Random crop from the resized image
    cropped = tf.random_crop(resized, [imageHeight, imageWidth, 3])
    # Flip to add a little more random distortion in.
    flipped = tf.image.random_flip_left_right(cropped)
    # Standardize the image (zero mean, unit variance per image)
    image_train = tf.image.per_image_standardization(flipped)
    return image_train, tf.one_hot(parsed_features["label"][0], 1000)

def train_input_fn():
    """Build the training input pipeline handed to ``Estimator.train``.

    Returns:
        A ``tf.data.Dataset`` of ``(image, one_hot_label)`` batches of
        ``batch_size`` examples; the training files are repeated 10 times.
    """
    # NOTE(review): the first two statements were truncated in the source
    # ("dataset_train =" / "dataset_train =, num_parallel_calls=4)"). They are
    # reconstructed here as a TFRecord read followed by a parallel map with
    # the training parser -- confirm against the original post.
    dataset_train = tf.data.TFRecordDataset(train_files)
    dataset_train = dataset_train.map(_parse_function, num_parallel_calls=4)
    # NOTE(review): no shuffle step is visible in the source; presumably the
    # TFRecord files were shuffled when they were written -- verify.
    dataset_train = dataset_train.repeat(10)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(batch_size)
    return dataset_train

def _parse_test_function(example_proto):
    """Decode one serialized validation example into a center-cropped image and label.

    The JPEG is decoded, resized so that its shorter side equals
    ``resize_min`` (aspect ratio preserved), center-cropped to
    ``imageHeight x imageWidth`` and standardized per image. No random
    distortion is applied.

    Args:
        example_proto: one serialized ``tf.Example`` string, as produced by
            a TFRecord dataset.

    Returns:
        A tuple ``(image, label)``: the standardized float image crop and the
        one-hot encoding (depth 1000) of the example's ``label`` feature.
    """
    # The closing brace of this dict was missing in the original, which made
    # the whole module a syntax error.
    features = {
        "image": tf.FixedLenFeature([], tf.string, default_value=""),
        "height": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
        "width": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
        "channels": tf.FixedLenFeature([1], tf.int64, default_value=[3]),
        "colorspace": tf.FixedLenFeature([], tf.string, default_value=""),
        "img_format": tf.FixedLenFeature([], tf.string, default_value=""),
        "label": tf.FixedLenFeature([1], tf.int64, default_value=[0]),
        "bbox_xmin": tf.VarLenFeature(tf.float32),
        "bbox_xmax": tf.VarLenFeature(tf.float32),
        "bbox_ymin": tf.VarLenFeature(tf.float32),
        "bbox_ymax": tf.VarLenFeature(tf.float32),
        "text": tf.FixedLenFeature([], tf.string, default_value=""),
        "filename": tf.FixedLenFeature([], tf.string, default_value=""),
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    image_decoded = tf.image.decode_jpeg(parsed_features["image"], channels=3)
    # Resize so that the shorter side becomes resize_min while keeping the
    # aspect ratio.
    shape = tf.shape(image_decoded)
    height, width = shape[0], shape[1]
    resized_height, resized_width = tf.cond(height<width,
        lambda: (resize_min, tf.cast(tf.multiply(tf.cast(width, tf.float64),tf.divide(resize_min,height)), tf.int32)),
        lambda: (tf.cast(tf.multiply(tf.cast(height, tf.float64),tf.divide(resize_min,width)), tf.int32), resize_min))
    image_float = tf.image.convert_image_dtype(image_decoded, tf.float32)
    image_resized = tf.image.resize_images(image_float, [resized_height, resized_width])
    # Work out the top-left corner of the central crop window from the
    # resized image's actual size.
    shape = tf.shape(image_resized)
    height, width = shape[0], shape[1]
    amount_to_be_cropped_h = (height - imageHeight)
    crop_top = amount_to_be_cropped_h // 2
    amount_to_be_cropped_w = (width - imageWidth)
    crop_left = amount_to_be_cropped_w // 2
    # -1 in the size keeps every channel.
    image_cropped = tf.slice(image_resized, [crop_top, crop_left, 0], [imageHeight, imageWidth, -1])
    image_valid = tf.image.per_image_standardization(image_cropped)
    return image_valid, tf.one_hot(parsed_features["label"][0], 1000)

def val_input_fn():
    """Build the validation input pipeline handed to ``Estimator.evaluate``.

    Returns:
        A ``tf.data.Dataset`` of ``(image, one_hot_label)`` batches of
        ``batch_size`` examples; the validation files are read once.
    """
    # NOTE(review): the first two statements were truncated in the source
    # ("dataset_valid =" / "dataset_valid =, num_parallel_calls=4)"). They are
    # reconstructed here as a TFRecord read followed by a parallel map with
    # the validation parser -- confirm against the original post.
    dataset_valid = tf.data.TFRecordDataset(valid_files)
    dataset_valid = dataset_valid.map(_parse_test_function, num_parallel_calls=4)
    dataset_valid = dataset_valid.batch(batch_size)
    dataset_valid = dataset_valid.prefetch(batch_size)
    return dataset_valid

def darknet_53(weight_decay=5e-4):
    """Build the Darknet-53 classifier as a Keras functional model.

    The network is a stack of strided 3x3 convolutions, each followed by a
    group of residual units (1, 2, 8, 8 and 4 units), then global average
    pooling and a 1000-way softmax.

    Args:
        weight_decay: L2 coefficient applied to every convolution kernel.
            NOTE(review): the regularizer argument was truncated in the
            source; the accompanying text says a ``kernel_regularizer`` was
            specified, but the coefficient shown here is an assumption --
            confirm against the original post.

    Returns:
        A ``tf.keras.Model`` mapping ``(imageHeight, imageWidth, 3)`` images
        to 1000 class probabilities.
    """
    image = tf.keras.Input(shape=(imageHeight,imageWidth,3))
    l = tf.keras.layers

    def _conv(inputs, filters, kernel_size, strides, padding, bias=False, normalize=True, activation='leaky_relu'):
        """Convolution with optional explicit zero padding, batch norm and LeakyReLU."""
        output = inputs
        padding_str = 'same'
        if padding>0:
            # Pad explicitly and switch the convolution to 'valid' so the
            # padding amount is controlled here rather than by Keras.
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        # The original call was cut off after the kernel_initializer line,
        # leaving the parenthesis unclosed (syntax error).
        output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias,
                          kernel_initializer='he_normal',
                          kernel_regularizer=tf.keras.regularizers.l2(l=weight_decay))(output)
        if normalize:
            # training=True is passed unconditionally, so batch statistics are
            # used at evaluation time as well; the accompanying text states
            # this setting was required.
            output = l.BatchNormalization(axis=3)(output, training=True)
        if activation=='leaky_relu':
            output = l.LeakyReLU(alpha=0.1)(output)
        return output

    def _residual(inputs, filters):
        """Residual unit: 1x1 conv to `filters`, 3x3 conv to `filters*2`, plus skip."""
        output = _conv(inputs, filters, 1, (1,1), 0)
        output = _conv(output, filters*2, 3, (1,1), 1)
        output = tf.add(inputs, output)
        return output

    net = _conv(image, 32, 3, (1,1), 1)
    net = _conv(net, 64, 3, (2,2), 1)
    net = _residual(net, 32)
    net = _conv(net, 128, 3, (2,2), 1)
    for _ in range(2):
        net = _residual(net, 64)
    net = _conv(net, 256, 3, (2,2), 1)
    for _ in range(8):
        net = _residual(net, 128)
    #add route1
    net = _conv(net, 512, 3, (2,2), 1)
    for _ in range(8):
        net = _residual(net, 256)
    #add route2
    net = _conv(net, 1024, 3, (2,2), 1)
    for _ in range(4):
        net = _residual(net, 512)
    #add route3
    net = l.GlobalAveragePooling2D()(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    return model

def alexnet(weight_decay=5e-4):
    """Build an AlexNet-style classifier as a Keras functional model.

    Five convolution + batch-norm + ReLU stages (with max pooling after the
    first, second and fifth), followed by three dense layers and a 1000-way
    softmax.

    Args:
        weight_decay: L2 coefficient applied to every convolution kernel.
            NOTE(review): the regularizer argument was truncated in the
            source; the accompanying text says a ``kernel_regularizer`` was
            specified, but the coefficient shown here is an assumption --
            confirm against the original post.

    Returns:
        A ``tf.keras.Model`` mapping ``(imageHeight, imageWidth, 3)`` images
        to 1000 class probabilities.
    """
    image = tf.keras.Input(shape=(imageHeight,imageWidth,3))
    l = tf.keras.layers

    def _conv(inputs, filters, kernel_size, strides, padding, bias=True):
        """Convolution with optional explicit zero padding, then batch norm and ReLU."""
        output = inputs
        padding_str = 'same'
        if padding>0:
            output = l.ZeroPadding2D(padding=(padding, padding))(output)
            padding_str = 'valid'
        # The original call was cut off after the kernel_initializer line,
        # leaving the parenthesis unclosed (syntax error).
        output = l.Conv2D(filters, kernel_size, strides, padding_str, use_bias=bias,
                          kernel_initializer=tf.initializers.truncated_normal(stddev=1e-1),
                          kernel_regularizer=tf.keras.regularizers.l2(l=weight_decay))(output)
        # training=True is passed unconditionally, so batch statistics are
        # used at evaluation time as well; the accompanying text states this
        # setting was required.
        output = l.BatchNormalization(axis=3)(output, training=True)
        output = l.ReLU()(output)
        return output

    net = _conv(image, 96, 11, 4, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 256, 5, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 384, 3, 1, 0)
    net = _conv(net, 256, 3, 1, 0)
    net = l.MaxPool2D(3, 2)(net)
    net = l.Flatten()(net)
    # Float literals keep the stddev non-zero under Python 2, where 1/4096
    # would be integer division and evaluate to 0.
    # NOTE(review): the dense layers have no activation, so the three of them
    # compose into a single linear map -- confirm this is intended.
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1.0/4096))(net)
    net = l.Dense(4096, kernel_initializer=tf.initializers.truncated_normal(stddev=1.0/4096))(net)
    net = l.Dense(1000, kernel_initializer=tf.initializers.truncated_normal(stddev=1.0/1000))(net)
    net = tf.keras.activations.softmax(net)
    model = tf.keras.Model(inputs=image, outputs=net)
    return model
def my_loss(y_true, y_pred):
    """Categorical cross-entropy plus collected regularization losses.

    Args:
        y_true: one-hot ground-truth labels.
        y_pred: predicted class probabilities.

    Returns:
        The cross-entropy with the summed ``REGULARIZATION_LOSSES`` collection
        added to it, divided by ``batch_size``.
    """
    # Sum every regularization term registered in the default graph.
    regularization_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    regularization = tf.reduce_sum(regularization_terms)
    cross_entropy = tf.keras.losses.categorical_crossentropy(y_true=y_true, y_pred=y_pred)
    return tf.add(regularization, cross_entropy)/batch_size

def main(_):
    """Train Darknet-53 through an Estimator, evaluating after every 5000 steps.

    Runs five rounds of "train 5000 steps, then evaluate on the validation
    set", printing the evaluation metrics after each round. The single
    positional argument is ignored.
    """
    # Steps per epoch; 1281167 is presumably the number of ImageNet training
    # images -- verify against the dataset.
    steps_per_epoch = 1281167/batch_size
    # Piecewise-constant schedule: 0.01 until epoch 5, 0.001 until epoch 8,
    # then 0.0001.
    lr_boundaries = [steps_per_epoch*5, steps_per_epoch*8]
    lr_values = [0.01, 0.001, 0.0001]
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(lr_boundaries, lr_values)

    network = darknet_53()
    sgd = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
    network.compile(
        optimizer=sgd,
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy', 'top_k_categorical_accuracy'],
    )
    estimator = tf.keras.estimator.model_to_estimator(network, model_dir='imagenet_model_darknet53/')

    for _round in range(5):
        estimator.train(input_fn=train_input_fn, steps=5000)
        metrics = estimator.evaluate(input_fn=val_input_fn)
        print('\nEvaluation:\n\t%s\n' % metrics)
if __name__ == "__main__":
    # NOTE(review): the body of this guard was missing in the source (a
    # syntax error). It is reconstructed as the absl entry point because
    # absl_app is imported and main() takes the single argument absl passes
    # -- confirm against the original post.
    absl_app.run(main)

As the code above shows, building and training models with Estimator is quite convenient. During testing, however, I found that if the labels in the input data are not one-hot encoded (that is, if each image category is represented by an integer from 0 to 999 and 'sparse_categorical_crossentropy' is chosen as the loss), training does not seem to work. In addition, when kernel_regularizer is specified on a Keras Conv2D layer, the regularization term appears to be added to the loss automatically. Finally, if you use Keras BatchNormalization, specify training=True.

