RNN in TensorFlow in Python & R, with MNIST

Though it is more convenient to use the TensorFlow framework in Python, we have also discussed how to apply TensorFlow in R here: https://charleshsliao.wordpress.com/tag/tensorflow/

We will show how to apply a recurrent neural network (RNN) in TensorFlow in both Python and R. An RNN may not be the best algorithm for MNIST, but it makes a nice example of an RNN application: each 28x28 image can be read as a sequence of 28 rows, with 28 pixels per row.
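To make the sequence view concrete, here is a minimal NumPy sketch (the pixel values are random stand-ins, not real MNIST data) of how one flattened image becomes 28 time steps of 28 features:

import numpy as np

# A fake flattened MNIST image: 784 pixel values in one row, exactly as
# mnist.train.next_batch() returns them (values here are random stand-ins).
flat_image = np.random.rand(784).astype(np.float32)

# Reinterpret it as a sequence: 28 time steps, each a row of 28 pixels.
sequence = flat_image.reshape(28, 28)
print(sequence.shape)  # (28, 28) -> (n_steps, n_input)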

In R:

#1. We load the data
library(tensorflow)
mnist<-tf$contrib$learn$datasets$mnist$load_mnist(train_dir = "MNIST-data")

#2. Identify essential parameters
Input<-28L
Steps<-28L
Hidden<-128L
Classes<-10L
batchSize<-128L

#3. Set up TensorFlow placeholders
x<-tf$placeholder(tf$float32,shape(NULL,Steps,Input))
y<-tf$placeholder(tf$float32,shape(NULL,Classes))
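# The NULL (None) dimension leaves the batch size flexible at run time.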

#4. Reshape the input data for the RNN
xX<-tf$transpose(x,shape(1,0,2))
xX<-tf$reshape(xX,shape(-1,Input))
xX<-tf$split(xX,Steps,0L)
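# After transpose/reshape/split, xX is a list of 28 tensors (one per time
# step), each of shape (batch_size, 28), which is what static_rnn expects.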

#5. Set up weights and bias for the output layer
weights<-tf$Variable(tf$random_normal(shape(Hidden,Classes)))
bias<-tf$Variable(tf$random_normal(shape(Classes)))
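# A single fully connected layer maps the last hidden state to the 10 class scores.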

#6. Prepare the LSTM cell and the RNN (the rnn module now lives in contrib)
lstmCell<-tf$contrib$rnn$BasicLSTMCell(Hidden,forget_bias=1.0,state_is_tuple=TRUE)
result<-tf$contrib$rnn$static_rnn(lstmCell,xX,dtype=tf$float32)
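# static_rnn returns a list of two elements: result[[1]] is the list of
# per-step outputs and result[[2]] is the final cell state.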

#7. Construct the prediction, cost and optimizer
outputs<-result[[1]]
lastCell<-length(outputs)
pred<-tf$matmul(outputs[[lastCell]],weights)+bias # output at the last time step
cost<-tf$reduce_mean(tf$nn$softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer<-tf$train$AdamOptimizer(learning_rate = 0.001)$minimize(cost)

#8. Evaluate
correct_pred<-tf$equal(tf$argmax(pred,1L),tf$argmax(y,1L))
accuracy<-tf$reduce_mean(tf$cast(correct_pred,tf$float32))
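# correct_pred compares the predicted class (argmax of the logits) with the
# true one-hot label; accuracy is the mean over the batch.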

#9. Initialize the session and run the training loop
sess<-tf$Session()
sess$run(tf$global_variables_initializer())
for (step in 1:500){
  batches<-mnist$train$next_batch(batchSize)
  batch_xs<-sess$run(tf$reshape(batches[[1]],shape(batchSize,Steps,Input)))
  batch_ys<-batches[[2]]
  sess$run(optimizer,feed_dict=dict(x=batch_xs,y=batch_ys))
  if(step%%50==0){
    acc<-sess$run(accuracy,feed_dict=dict(x=batch_xs,y=batch_ys))
    loss<-sess$run(cost,feed_dict=dict(x=batch_xs,y=batch_ys))
    print(paste("Accracy: ",round(acc,4),"loss: ",round(loss,4),"step(",step,")"))
  }
}

We take a similar approach in Python:


import tensorflow as tf 
from tensorflow.contrib import rnn
import numpy as np 

from tensorflow.examples.tutorials.mnist import input_data

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 10

# Network Parameters
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)

x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}

def RNN(x, weights, biases):
    # Unstack along the time axis to get a list of n_steps tensors,
    # each of shape (batch_size, n_input), as static_rnn expects
    x = tf.unstack(x, n_steps, 1)
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    # Linear classifier on the output of the last time step
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
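# Note: softmax_cross_entropy_with_logits applies the softmax internally,
# so pred must be raw (unnormalized) logits rather than probabilities.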
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
                  "{:.6f}".format(loss) + ", Training Accuracy= " + \
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")
    # Calculate accuracy for 128 MNIST test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

#Iter 1280, Minibatch Loss= 1.944456, Training Accuracy= 0.30469
#Iter 2560, Minibatch Loss= 1.516782, Training Accuracy= 0.52344
#Iter 3840, Minibatch Loss= 1.266199, Training Accuracy= 0.57031
#Iter 5120, Minibatch Loss= 0.961907, Training Accuracy= 0.67188
#Iter 6400, Minibatch Loss= 0.973848, Training Accuracy= 0.61719
#Iter 7680, Minibatch Loss= 0.609262, Training Accuracy= 0.75781
#Iter 8960, Minibatch Loss= 0.678220, Training Accuracy= 0.75000
#Iter 10240, Minibatch Loss= 0.533957, Training Accuracy= 0.80469
#Iter 11520, Minibatch Loss= 0.490650, Training Accuracy= 0.81250
#Iter 12800, Minibatch Loss= 0.431325, Training Accuracy= 0.87500
#Iter 14080, Minibatch Loss= 0.445506, Training Accuracy= 0.89062
#Iter 15360, Minibatch Loss= 0.354876, Training Accuracy= 0.85938
#Iter 16640, Minibatch Loss= 0.271417, Training Accuracy= 0.90625
#Iter 17920, Minibatch Loss= 0.296218, Training Accuracy= 0.89844
#Iter 19200, Minibatch Loss= 0.346756, Training Accuracy= 0.89062
#Iter 20480, Minibatch Loss= 0.204771, Training Accuracy= 0.92969
#Iter 21760, Minibatch Loss= 0.401685, Training Accuracy= 0.85156
#Iter 23040, Minibatch Loss= 0.328447, Training Accuracy= 0.92188
#Iter 24320, Minibatch Loss= 0.194539, Training Accuracy= 0.91406
#Iter 25600, Minibatch Loss= 0.249565, Training Accuracy= 0.96094
#Iter 26880, Minibatch Loss= 0.205035, Training Accuracy= 0.93750
#Iter 28160, Minibatch Loss= 0.146734, Training Accuracy= 0.93750
#Iter 29440, Minibatch Loss= 0.251819, Training Accuracy= 0.89844
#Iter 30720, Minibatch Loss= 0.360024, Training Accuracy= 0.88281
#Iter 32000, Minibatch Loss= 0.213888, Training Accuracy= 0.92969
#Iter 33280, Minibatch Loss= 0.163629, Training Accuracy= 0.92188
#Iter 34560, Minibatch Loss= 0.267468, Training Accuracy= 0.91406
#Iter 35840, Minibatch Loss= 0.268270, Training Accuracy= 0.92969
#Iter 37120, Minibatch Loss= 0.236008, Training Accuracy= 0.91406
#Iter 38400, Minibatch Loss= 0.145467, Training Accuracy= 0.96094
#Iter 39680, Minibatch Loss= 0.171247, Training Accuracy= 0.93750
#Iter 40960, Minibatch Loss= 0.158264, Training Accuracy= 0.94531
#Iter 42240, Minibatch Loss= 0.161707, Training Accuracy= 0.94531
#Iter 43520, Minibatch Loss= 0.153343, Training Accuracy= 0.96094
#Iter 44800, Minibatch Loss= 0.136208, Training Accuracy= 0.95312
#Iter 46080, Minibatch Loss= 0.135175, Training Accuracy= 0.95312
#Iter 47360, Minibatch Loss= 0.147518, Training Accuracy= 0.94531
#Iter 48640, Minibatch Loss= 0.050638, Training Accuracy= 0.99219
#Iter 49920, Minibatch Loss= 0.107003, Training Accuracy= 0.96094
#Iter 51200, Minibatch Loss= 0.316539, Training Accuracy= 0.90625
#Iter 52480, Minibatch Loss= 0.141136, Training Accuracy= 0.95312
#Iter 53760, Minibatch Loss= 0.158108, Training Accuracy= 0.95312
#Iter 55040, Minibatch Loss= 0.185566, Training Accuracy= 0.93750
#Iter 56320, Minibatch Loss= 0.099082, Training Accuracy= 0.96875
#Iter 57600, Minibatch Loss= 0.122914, Training Accuracy= 0.96094
#Iter 58880, Minibatch Loss= 0.244967, Training Accuracy= 0.90625
#Iter 60160, Minibatch Loss= 0.108733, Training Accuracy= 0.96875
#Iter 61440, Minibatch Loss= 0.074805, Training Accuracy= 0.97656
#Iter 62720, Minibatch Loss= 0.114873, Training Accuracy= 0.96094
#Iter 64000, Minibatch Loss= 0.097373, Training Accuracy= 0.96094
#Iter 65280, Minibatch Loss= 0.117917, Training Accuracy= 0.96094
#Iter 66560, Minibatch Loss= 0.140607, Training Accuracy= 0.96094
#Iter 67840, Minibatch Loss= 0.170446, Training Accuracy= 0.94531
#Iter 69120, Minibatch Loss= 0.052542, Training Accuracy= 0.97656
#Iter 70400, Minibatch Loss= 0.072579, Training Accuracy= 0.96875
#Iter 71680, Minibatch Loss= 0.154582, Training Accuracy= 0.96094
#Iter 72960, Minibatch Loss= 0.137373, Training Accuracy= 0.95312
#Iter 74240, Minibatch Loss= 0.172332, Training Accuracy= 0.94531
#Iter 75520, Minibatch Loss= 0.127297, Training Accuracy= 0.96094
#Iter 76800, Minibatch Loss= 0.225298, Training Accuracy= 0.95312
#Iter 78080, Minibatch Loss= 0.112549, Training Accuracy= 0.96094
#Iter 79360, Minibatch Loss= 0.024480, Training Accuracy= 1.00000
#Iter 80640, Minibatch Loss= 0.070614, Training Accuracy= 0.96875
#Iter 81920, Minibatch Loss= 0.057122, Training Accuracy= 0.97656
#Iter 83200, Minibatch Loss= 0.205139, Training Accuracy= 0.93750
#Iter 84480, Minibatch Loss= 0.106761, Training Accuracy= 0.96875
#Iter 85760, Minibatch Loss= 0.103553, Training Accuracy= 0.96875
#Iter 87040, Minibatch Loss= 0.111675, Training Accuracy= 0.96875
#Iter 88320, Minibatch Loss= 0.067976, Training Accuracy= 0.97656
#Iter 89600, Minibatch Loss= 0.068248, Training Accuracy= 0.96875
#Iter 90880, Minibatch Loss= 0.168272, Training Accuracy= 0.92969
#Iter 92160, Minibatch Loss= 0.095124, Training Accuracy= 0.96094
#Iter 93440, Minibatch Loss= 0.152670, Training Accuracy= 0.96875
#Iter 94720, Minibatch Loss= 0.070328, Training Accuracy= 0.97656
#Iter 96000, Minibatch Loss= 0.142944, Training Accuracy= 0.96094
#Iter 97280, Minibatch Loss= 0.073793, Training Accuracy= 0.99219
#Iter 98560, Minibatch Loss= 0.065288, Training Accuracy= 0.98438
#Iter 99840, Minibatch Loss= 0.063933, Training Accuracy= 0.97656
#Optimization Finished!
#Testing Accuracy: 0.984375

Most of the Python code comes from https://github.com/aymericdamien.
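As a side note, the same model can be written with tf.nn.dynamic_rnn, which consumes the (batch, steps, input) tensor directly and avoids the unstack/split step. The following is only a minimal sketch assuming TensorFlow 1.x (depending on the exact version, the LSTM cell may live in tf.contrib.rnn instead of tf.nn.rnn_cell); it is not the code benchmarked above:

import tensorflow as tf

n_input, n_steps, n_hidden, n_classes = 28, 28, 128, 10

x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# dynamic_rnn takes the 3-D tensor as-is; outputs has shape
# (batch_size, n_steps, n_hidden)
cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

# Classify the output of the last time step
logits = tf.layers.dense(outputs[:, -1, :], n_classes)
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
train_op = tf.train.AdamOptimizer(0.001).minimize(cost)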
