Keras: MNIST classification

Keras implementation for MNIST classification with batch normalization and leaky ReLU.


import numpy as np
import time
import pandas as pd
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, LeakyReLU, Activation, Flatten
from keras.optimizers import Adam
from keras.layers.normalization import BatchNormalization
from keras.utils import np_utils
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import plot_model
from keras.backend import set_learning_phase, learning_phase

def build_classifier(input_shape, num_classes, num_pooling):
    """Assemble a Sequential CNN classifier.

    The network consists of ``num_pooling`` convolutional stages followed by
    ``num_pooling`` fully connected blocks and a softmax output layer.

    Each convolutional stage is
    (Conv2D -> BatchNormalization -> LeakyReLU) x 2 -> MaxPooling2D,
    starting with 32 feature maps and doubling the count at every
    subsequent stage.

    Each fully connected block is
    Dense(1024) -> LeakyReLU -> BatchNormalization -> Dropout(0.5).

    Args:
        input_shape: (image_width, image_height, 1), passed to the first
            convolution layer.
        num_classes: number of units in the final Dense (softmax) layer.
        num_pooling: number of convolutional stages; the same number of
            fully connected blocks is appended after flattening.

    Returns:
        The (uncompiled) Sequential model.
    """
    # hyper-parameters
    kernel = (3, 3)
    pool = (2, 2)
    base_filters = 32
    dense_units = 1024

    net = Sequential()

    def add_conv_unit(filters, **extra):
        # One linear 3x3 convolution followed by batch norm and leaky ReLU.
        net.add(Conv2D(filters, kernel_size=kernel, padding='same',
                       activation='linear', **extra))
        net.add(BatchNormalization())
        net.add(LeakyReLU(0.2))

    # feature extraction: num_pooling stages, filters double per stage
    for stage in range(num_pooling):
        filters = base_filters * 2 ** stage
        # only the very first layer of the model declares the input shape
        first_layer_args = {'input_shape': input_shape} if stage == 0 else {}
        add_conv_unit(filters, **first_layer_args)
        add_conv_unit(filters)
        net.add(MaxPooling2D(pool_size=pool))

    net.add(Flatten())

    # classification head: num_pooling fully connected blocks
    for _ in range(num_pooling):
        net.add(Dense(dense_units, activation='linear'))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization())
        net.add(Dropout(0.5))

    # class probabilities
    net.add(Dense(num_classes))
    net.add(Activation('softmax'))

    return net

 

def demo_load_model():
    """Build the classifier for several input sizes and plot each one.

    For every input shape in (64, 64, 1), (128, 128, 1) and (256, 256, 1)
    a 100-class model with three pooling stages is built and its layer
    diagram (with shapes) is written to
    ``pointset_classifier_<width>.png`` in the current directory.
    """
    num_pooling = 3
    num_classes = 100
    input_shapes = [(64, 64, 1), (128, 128, 1), (256, 256, 1)]

    for input_shape in input_shapes:
        # The original called build_pointset_classifier, which is not
        # defined in this file; build_classifier is the builder defined here.
        model = build_classifier(input_shape, num_classes, num_pooling)

        filename = 'pointset_classifier_%s.png' % str(input_shape[0])
        plot_model(model, to_file=filename, show_shapes=True)

 

if __name__ == '__main__':
    # Train the CNN classifier on MNIST for a few epochs and report the
    # loss and accuracy on the test set.

    # ---- settings ----
    img_width = 28
    img_height = 28
    input_shape = (img_width, img_height, 1)
    num_classes = 10
    num_epoch = 3
    size_batch = 100

    # ---- load data ----
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # add a trailing channel axis and scale pixel values to [0, 1]
    x_train = x_train.reshape(x_train.shape[0], img_width, img_height, 1)
    x_train = x_train.astype('float32')
    x_train /= 255.0
    # one-hot encode the labels for categorical cross-entropy
    y_train = np_utils.to_categorical(y_train, num_classes)

    x_test = x_test.reshape(x_test.shape[0], img_width, img_height, 1)
    x_test = x_test.astype('float32')
    x_test /= 255.0
    y_test = np_utils.to_categorical(y_test, num_classes)

    # ---- build neural network ----
    # NOTE: the learning phase is deliberately not forced with
    # set_learning_phase(). Fixing it to 1 before the model is built keeps
    # Dropout/BatchNormalization in training mode for validation and
    # evaluation as well; fit() and evaluate() select the phase themselves.
    num_pooling = 2
    model = build_classifier(input_shape,
                             num_classes,
                             num_pooling)

    # save a diagram of the architecture
    filename = 'classifier_%s.png' % str(input_shape[0])
    plot_model(model,
               to_file=filename,
               show_shapes=True)

    model.compile(loss='categorical_crossentropy',
                  optimizer=Adam(),
                  metrics=['accuracy'])
    model.summary()

    # ---- train the model ----
    # time.clock() was removed in Python 3.8; perf_counter() measures
    # elapsed wall-clock time in seconds.
    time_begin = time.perf_counter()
    history = model.fit(x_train, y_train,
                        validation_data=(x_test, y_test),
                        epochs=num_epoch,
                        batch_size=size_batch,
                        verbose=1)
    print('Time elapsed: %.0f' % (time.perf_counter() - time_begin))

    # ---- validate the model ----
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
Advertisements

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. ( Log Out / Change )

Twitter picture

You are commenting using your Twitter account. ( Log Out / Change )

Facebook photo

You are commenting using your Facebook account. ( Log Out / Change )

Google+ photo

You are commenting using your Google+ account. ( Log Out / Change )

Connecting to %s