Keras: A Simple Convolutional Neural Network

Simple CNN Implementation

Brief introduction

  • As in the earlier PyTorch post, the MNIST handwritten digit dataset is used to test the feature extraction ability of a convolutional neural network.

Steps

  • Getting the dataset
    • Keras ships with the MNIST dataset, so it can be loaded directly.
    • Code
      •   from keras.datasets import mnist
          import matplotlib.pyplot as plt
          # Jupyter magic; remove this line when running as a plain script
          %matplotlib inline
          
          # The second split returned by MNIST is its test set; it serves as validation data here
          (x_train, y_train), (x_valid, y_valid) = mnist.load_data()
          print(x_train.shape, y_train.shape, x_valid.shape, y_valid.shape)
          # Scale pixel values from [0, 255] to [0, 1]
          x_train = x_train.reshape(-1, 28, 28) / 255
          x_valid = x_valid.reshape(-1, 28, 28) / 255
          # Visualize the first ten training images
          plt.figure(figsize=(12, 8))
          for i in range(10):
              plt.subplot(2, 5, i+1)
              plt.title("label:{}".format(y_train[i]))
              plt.imshow(x_train[i], cmap='gray')
          plt.show()
        
    • Output
  • Modeling
    • Two convolutional layers are used to extract features. Each pooling step halves the height and width of the image, while the number of channels (the depth) grows, so the features end up stored along the channel axis.
    • Code
      •   # Building models
          from keras.models import Sequential
          from keras.layers import Convolution2D, MaxPooling2D, Activation, Flatten, Dense
          from keras.optimizers import Adam
          
          model = Sequential()
          
          model.add(Convolution2D(
              batch_input_shape=(None, 28, 28, 1),  # Input shape: (batch, height, width, channels)
              filters=32,  # Number of convolution kernels
              kernel_size=3,  # Kernel size (3x3)
              strides=1,  # Stride
              padding='same',  # Pads (3-1)/2 = 1 pixel per side, so height and width are unchanged
              data_format='channels_last'  # Channels are the last axis; note that PyTorch puts channels first
          ))  # Convolution layer, output (28, 28, 32)
          model.add(Activation('relu'))  # Activation function
          model.add(MaxPooling2D(pool_size=2, strides=2, padding='same', data_format='channels_last'))  # Output (14, 14, 32)
          
          model.add(Convolution2D(64, 3, strides=1, padding='same', data_format='channels_last'))
          model.add(Activation('relu'))
          model.add(MaxPooling2D(2, 2, 'same', data_format='channels_last'))  # Output (7, 7, 64)
          
          model.add(Flatten())
          model.add(Dense(1024))
          model.add(Activation('relu'))  # (1024)
          
          model.add(Dense(10))
          model.add(Activation('softmax'))  # (10) Class probabilities
          
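          # Labels from mnist.load_data() are integers (not one-hot), so use the sparse loss.
          # learning_rate was called lr in older Keras releases.
          model.compile(optimizer=Adam(learning_rate=1e-4), loss='sparse_categorical_crossentropy', metrics=['accuracy'])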
          
          model.summary()
        
    • Model structure
      • As the summary shows, the model still has a large number of trainable parameters, most of them in the first fully connected layer.
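The output shapes noted in the code comments and the parameter count printed by `model.summary()` can be checked by hand. The following is a minimal sketch in plain Python and is not part of the original code; the helper names `same_out`, `conv_params`, and `dense_params` are made up for illustration. It assumes that `padding='same'` gives an output size of ceil(input / stride).

```python
import math

def same_out(size, stride):
    """Spatial output size of a layer with padding='same'."""
    return math.ceil(size / stride)

def conv_params(kernel, in_ch, out_ch):
    """Weights plus one bias per filter for a square 2D convolution."""
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    """Weights plus one bias per unit for a fully connected layer."""
    return n_in * n_out + n_out

size = 28
size = same_out(size, 1)   # conv 1 (stride 1) keeps 28
size = same_out(size, 2)   # pool 1 -> 14
size = same_out(size, 1)   # conv 2 (stride 1) keeps 14
size = same_out(size, 2)   # pool 2 -> 7
flat = size * size * 64    # Flatten -> 3136 values

layer_params = {
    "conv1": conv_params(3, 1, 32),
    "conv2": conv_params(3, 32, 64),
    "dense1": dense_params(flat, 1024),
    "dense2": dense_params(1024, 10),
}
total = sum(layer_params.values())
print(size, flat, layer_params, total)
```

Under these assumptions the first Dense layer accounts for about 3.2 million of the roughly 3.24 million parameters, which is why the flattened 7 x 7 x 64 feature map dominates the model size.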
  • Training process
    • Only a few epochs are trained here, which helps limit over-fitting.
    • Code
      •   # Training model
          history = model.fit(x_train.reshape(-1, 28, 28, 1), y_train, batch_size=64, epochs=10, validation_split=0.2, shuffle=True, verbose=True)
        
    • Training Visualization
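`model.fit` returns a `History` object whose `history` attribute is a dict mapping each metric name to a list with one value per epoch; with `validation_split` set, that dict includes `val_loss`. Below is a small sketch of reading such a dict to find the epoch with the lowest validation loss, one simple way to see where over-fitting starts. The helper `best_epoch` and the numbers in `sample_history` are invented for illustration and are not results from this model.

```python
def best_epoch(history_dict, metric="val_loss"):
    """Return (1-based epoch, value) where the given metric is lowest."""
    values = history_dict[metric]
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Made-up values standing in for history.history
sample_history = {
    "loss": [0.40, 0.12, 0.08, 0.05, 0.03],
    "val_loss": [0.15, 0.09, 0.07, 0.08, 0.10],
}

epoch, value = best_epoch(sample_history)
print("lowest val_loss {:.2f} at epoch {}".format(value, epoch))
```

With a real run the call would be `best_epoch(history.history)`. If the best epoch falls well before the last one, the training loss is still dropping while the validation loss rises, which is the usual sign of over-fitting.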
  • Validation Set Evaluation
    • Code
      •   import numpy as np
          
          loss, accuracy = model.evaluate(x_valid.reshape(-1, 28, 28, 1), y_valid)
          print(loss, accuracy)
          
          # Predict the first ten images and compare with the true labels
          result = model.predict(x_valid[:10].reshape(-1, 28, 28, 1))
          plt.figure(figsize=(12, 8))
          for i in range(10):
              plt.subplot(2, 5, i+1)
              plt.imshow(x_valid[i], cmap='gray')
              # y_valid holds integer labels; argmax picks the most probable class
              plt.title("true:{} pred:{}".format(y_valid[i], np.argmax(result[i])))
          plt.show()
        
    • Visualization results
      • On the held-out validation data (the MNIST test split), the accuracy is 0.9905, which is a good result.
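The accuracy figure can also be reproduced from raw predictions: each row returned by `model.predict` holds the softmax probabilities for one image, and the predicted digit is the index of the largest value. Here is a minimal sketch with NumPy, using a small made-up probability array in place of real model output (only three classes, to keep it short).

```python
import numpy as np

# Made-up softmax outputs for 4 samples and 3 classes (each row sums to 1)
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
# Integer class labels, the same format mnist.load_data() returns
labels = np.array([0, 1, 2, 1])

pred = np.argmax(probs, axis=1)            # most probable class per row
accuracy = float(np.mean(pred == labels))  # fraction of correct predictions
print(pred, accuracy)
```

Note that `axis=1` takes the argmax across classes for every sample at once, whereas the plotting loop above applies it to one row at a time.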

Supplementary Notes

  • This example uses the Keras framework; see the official documentation for API details.
  • Framework walkthroughs like this one let the code and its results speak for themselves; the principles of neural networks are covered in my other blog posts. The full code is on my GitHub; stars and forks are welcome.
  • This post is also published on my personal blog site; you are welcome to visit.

Posted by RobReid on Wed, 09 Oct 2019 09:47:54 -0700