RNN, LSTM in TensorFlow for NLP in Python

We previously applied an RNN to the MNIST data; RNNs are actually even better suited to NLP projects.
You can find more details in the book Python Deep Learning by Valentino Zocca, Gianmario Spacagna and Daniel Slater.

from __future__ import print_function, division
# -*- coding: utf-8 -*-
###"War and peace" contains more than 500,000 words, making it the perfect 
###candidate for our small example. Since it's in the public domain, "War and 
###peace" can be downloaded as plain text for free from Project Gutenberg.

###1. Load Data
import re
import codecs

filepath = 'war_and_peace.txt'
out_file = 'wap.txt'
with codecs.open(filepath, encoding='utf-8', mode='r') as f_input:
    book_str = f_input.read()

###2. Preprocess
###We strip out newlines in the middle of sentences and reduce the maximum
###number of consecutive newlines allowed to two.
# NOTE(review): removing the Gutenberg license, book information and table of
# contents is NOT done by this code; trim the input file beforehand if needed.

# codecs.open() always opens the underlying file in binary mode, so "\r\n"
# (the usual Gutenberg line ending) is not translated to "\n". Normalise line
# endings first; otherwise the "\r" (whitespace) sits between "\S" and "\n"
# and neither pattern below ever matches.
book_str = book_str.replace('\r\n', '\n').replace('\r', '\n')

# A single newline between two non-space characters is a line wrap inside a
# paragraph. Lookarounds (instead of capturing the neighbouring characters)
# let consecutive wraps such as "a\nI\nb" share the middle character; the
# capturing form consumed it and left every second wrap in place.
New_Para = re.compile(r'(?<=\S)\n(?=\S)')
# Two or more consecutive newlines collapse to exactly one blank line.
Multi_New_Line = re.compile(r'(\n)(\n)+')

book_str = New_Para.sub(' ', book_str)
book_str = Multi_New_Line.sub('\n\n', book_str)

# Persist the cleaned text; DataReader below reads it back from out_file.
with codecs.open(out_file, encoding='utf-8', mode='w') as f_output:
    f_output.write(book_str)

###To feed the data into the network, we have to convert it into a 
###numerical format: each character is associated with an integer index.

###3. Load the processed data again with functions
###3.1 Define the class of data reader
from six.moves import range
import numpy as np
class DataReader(object):
    """Data reader used for training a character-level language model.

    The whole UTF-8 text file is read into memory and every distinct
    character is mapped to an integer class index (its position in the
    sorted ``char_list``). Iterating over the reader yields an endless stream
    of ``(input_batch, target_batch)`` int32 arrays of shape
    ``(batch_size, batch_length)``; each target row is the matching input row
    shifted one character ahead.

    Args:
        filepath: path of a UTF-8 encoded text file.
        batch_length: number of time steps per sequence.
        batch_size: number of sequences per minibatch.

    Raises:
        ValueError: if the file is empty.
    """
    def __init__(self, filepath, batch_length, batch_size):
        self.batch_length = batch_length
        self.batch_size = batch_size
        # Keep the data as decoded text (not bytes): each element must be a
        # single, possibly non-ASCII, character so that unicode characters
        # (e.g. prime strings used for sampling) can be looked up directly.
        with codecs.open(filepath, encoding='utf-8', mode='r') as f:
            self.data_str = f.read()
        self.data_length = len(self.data_str)
        print('data_length: ', self.data_length)
        if self.data_length == 0:
            # Every index below is taken modulo data_length.
            raise ValueError('{} contains no data'.format(filepath))
        # Create a list of characters, indices are class indices for softmax
        self.char_list = sorted(set(self.data_str))
        print('char_list: ', len(self.char_list), self.char_list)
        # Create reverse mapping to look up the index based on the character
        self.char_dict = {val: idx for idx, val in enumerate(self.char_list)}
        print('char_dict: ', self.char_dict)
        # Initialise random start indices
        self.reset_indices()

    def reset_indices(self):
        """Draw a fresh random start offset for every sequence in the batch."""
        # randint's upper bound is exclusive, so every offset is a valid index
        # (the deprecated random_integers was inclusive).
        self.start_idxs = np.random.randint(
            0, self.data_length, self.batch_size)

    def get_sample(self, start_idx, length):
        """Return ``length`` class indices starting at ``start_idx``.

        Indices wrap around to the beginning of the data string.
        """
        return [self.char_dict[self.data_str[i % self.data_length]]
                for i in range(start_idx, start_idx + length)]

    def get_input_target_sample(self, start_idx):
        """Return (input, target) index lists; target is input shifted by one."""
        sample = self.get_sample(start_idx, self.batch_length + 1)
        inpt = sample[0:self.batch_length]
        trgt = sample[1:self.batch_length + 1]
        return inpt, trgt

    def get_batch(self, start_idxs):
        """Build int32 input/target arrays, one row per start offset."""
        # int32 matches the dtype of the model's input/target placeholders.
        input_batch = np.zeros((self.batch_size, self.batch_length),
                               dtype=np.int32)
        target_batch = np.zeros((self.batch_size, self.batch_length),
                                dtype=np.int32)
        for i, start_idx in enumerate(start_idxs):
            inpt, trgt = self.get_input_target_sample(start_idx)
            input_batch[i, :] = inpt
            target_batch[i, :] = trgt
        return input_batch, target_batch

    def __iter__(self):
        """Yield minibatches forever.

        After each batch, every start offset advances by ``batch_length``,
        wrapping at the end of the data.
        """
        while True:
            input_batch, target_batch = self.get_batch(self.start_idxs)
            self.start_idxs = (
                self.start_idxs + self.batch_length) % self.data_length
            yield input_batch, target_batch

###3.2 Define and run the main function
def main():
    """Smoke-test DataReader: print the class indices of a sample sentence."""
    filepath = 'wap.txt'
    batch_length = 10
    batch_size = 2
    reader = DataReader(filepath, batch_length, batch_size)
    s = 'As in the question of astronomy then, so in the question of history now,'
    # Materialise the indices in a list; printing a bare generator expression
    # would only show "<generator object ...>".
    print([reader.char_dict[c] for c in s])


if __name__ == "__main__":
    main()

###4. Build the RNN Model
import time
import codecs
import locale
import sys
import tensorflow as tf
from tensorflow.python.util import nest

class Model(object):
    """Multi-layer LSTM character-level language model (TensorFlow 1.x graph).

    Usage: construct, call ``init_graph()`` to build the placeholders and
    network, optionally ``init_train_op(optimizer)`` for training, then run
    the ops inside a ``tf.Session``.

    Args:
        batch_size: number of sequences per run (1 for sampling).
        sequence_length: time steps per sequence, or None for variable length.
        lstm_sizes: number of units of each stacked LSTM layer.
        dropout: stored on the instance. NOTE(review): it is not applied
            anywhere in the graph below; callers pass 0.8 for training and
            1.0 for sampling, which suggests a keep probability was intended.
        labels: list of characters; the position of a character is its
            class index.
        save_path: checkpoint path used by ``save`` / ``restore``.
    """
    def __init__(self, batch_size, sequence_length, lstm_sizes, dropout,
                 labels, save_path):
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.lstm_sizes = lstm_sizes
        self.labels = labels
        self.label_map = {val: idx for idx, val in enumerate(labels)}
        self.number_of_characters = len(labels)
        self.save_path = save_path
        self.dropout = dropout

    def init_graph(self):
        """Create the input/target placeholders, the network and the saver."""
        # Variable sequence length
        self.inputs = tf.placeholder(
            tf.int32, [self.batch_size, self.sequence_length])
        self.targets = tf.placeholder(
            tf.int32, [self.batch_size, self.sequence_length])
        self.init_architecture()
        # The saver must be created after the architecture exists; otherwise
        # tf.trainable_variables() is empty and nothing would be checkpointed.
        self.saver = tf.train.Saver(tf.trainable_variables())

    def init_architecture(self):
        """Build the stacked LSTM, its persistent state and the softmax layer."""
        # One-hot encode the character indices for the LSTM input.
        self.one_hot_inputs = tf.one_hot(
            self.inputs, depth=self.number_of_characters)
        # Define a multilayer LSTM cell
        cell_list = [tf.contrib.rnn.LSTMCell(lstm_size, state_is_tuple=True)
                     for lstm_size in self.lstm_sizes]
        self.multi_cell_lstm = tf.contrib.rnn.MultiRNNCell(
            cell_list, state_is_tuple=True)
        # Initial state of the LSTM memory.
        self.initial_state = self.multi_cell_lstm.zero_state(
            self.batch_size, tf.float32)
        # Keep the state in (non-trainable) variables so it carries over
        # between batches. The LSTM state is a nested tuple of tensors, so the
        # flat list of variables is packed back into that same structure.
        self.state_variables = nest.pack_sequence_as(
            self.initial_state,
            [tf.Variable(var, trainable=False)
             for var in nest.flatten(self.initial_state)])
        # Define the rnn through time, starting from the stored state.
        lstm_output, final_state = tf.nn.dynamic_rnn(
            cell=self.multi_cell_lstm, inputs=self.one_hot_inputs,
            initial_state=self.state_variables)
        # Store the final state into the state variables for the next batch,
        # and make the output depend on that assignment so it always runs.
        store_states = [
            state_variable.assign(new_state)
            for (state_variable, new_state) in zip(
                nest.flatten(self.state_variables),
                nest.flatten(final_state))]
        with tf.control_dependencies(store_states):
            lstm_output = tf.identity(lstm_output)
        # Reshape so that we can apply the linear transformation to all outputs
        output_flat = tf.reshape(lstm_output, (-1, self.lstm_sizes[-1]))
        # Define output layer
        self.logit_weights = tf.Variable(
            tf.truncated_normal(
                (self.lstm_sizes[-1], self.number_of_characters), stddev=0.01),
            name='logit_weights')
        self.logit_bias = tf.Variable(
            tf.zeros((self.number_of_characters,)), name='logit_bias')
        # Apply last layer transformation
        self.logits_flat = tf.matmul(
            output_flat, self.logit_weights) + self.logit_bias
        probabilities_flat = tf.nn.softmax(self.logits_flat)
        self.probabilities = tf.reshape(
            probabilities_flat,
            (self.batch_size, -1, self.number_of_characters))

    def init_train_op(self, optimizer, max_grad_norm=5):
        """Create ``self.loss`` and a gradient-clipped ``self.train_op``.

        Args:
            optimizer: a ``tf.train.Optimizer`` instance.
            max_grad_norm: global-norm threshold for gradient clipping.
        """
        # Flatten the targets to be compatible with the flattened logits
        targets_flat = tf.reshape(self.targets, (-1, ))
        # Get the loss over all outputs
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.logits_flat, labels=targets_flat)
        self.loss = tf.reduce_mean(loss)
        trainable_variables = tf.trainable_variables()
        # Differentiate the same scalar mean loss that is reported; the
        # gradient of the unreduced per-character tensor is that of its SUM,
        # whose norm scales with batch_size * sequence_length and makes the
        # clipping threshold meaningless.
        gradients = tf.gradients(self.loss, trainable_variables)
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm)
        self.train_op = optimizer.apply_gradients(
            zip(gradients, trainable_variables))

    def _sample_label(self, out):
        """Draw one label from the distribution ``out[0, 0]``."""
        # The softmax output is float32; renormalise a float64 copy so numpy's
        # "probabilities must sum to 1" check cannot trip on rounding error.
        probs = np.array(out[0, 0], dtype=np.float64)
        probs /= probs.sum()
        return np.random.choice(self.labels, p=probs)

    def sample(self, session, prime_string, sample_length):
        """Generate text: feed ``prime_string``, then sample new characters.

        Feeds one character per ``session.run``, so the model is expected to
        be built with batch_size 1.

        Args:
            session: session holding this model's (restored) variables.
            prime_string: non-empty text to condition on; every character
                must be in ``labels``.
            sample_length: number of characters to generate.

        Returns:
            ``prime_string`` followed by the generated characters.

        Raises:
            ValueError: if ``prime_string`` is empty.
        """
        if not prime_string:
            raise ValueError('prime_string must not be empty')
        # Start from the zero state so earlier samples do not leak into this one.
        self.reset_state(session)
        # Prime state
        print('prime_string: ', prime_string)
        for character in prime_string:
            character_idx = self.label_map[character]
            out = session.run(
                self.probabilities,
                feed_dict={self.inputs: np.asarray([[character_idx]])})
        output_sample = prime_string
        print('start sampling')
        # Sample for sample_length steps, feeding each draw back in.
        for _ in range(sample_length):
            sample_label = self._sample_label(out)
            output_sample += sample_label
            sample_idx = self.label_map[sample_label]
            out = session.run(
                self.probabilities,
                feed_dict={self.inputs: np.asarray([[sample_idx]])})
        return output_sample

    def reset_state(self, session):
        """Re-initialise the stored LSTM state variables to the zero state."""
        for state in nest.flatten(self.state_variables):
            session.run(state.initializer)

    def save(self, sess):
        """Checkpoint the trainable variables to ``save_path``."""
        self.saver.save(sess, self.save_path)

    def restore(self, sess):
        """Restore the trainable variables from ``save_path``."""
        self.saver.restore(sess, self.save_path)

def train_and_sample(minibatch_iterations, restore):
    """Train the character model, checkpoint it, then print text samples.

    Args:
        minibatch_iterations: number of training minibatches to run.
        restore: if True, continue from the checkpoint at ``save_path``
            instead of freshly initialised weights.
    """
    # Start from an empty graph: this function is called repeatedly, and the
    # training model would otherwise accumulate duplicate variables.
    tf.reset_default_graph()
    batch_size = 64
    lstm_sizes = [512, 512]
    batch_len = 100
    learning_rate = 2e-3

    filepath = './wap.txt'

    data_feed = DataReader(
        filepath, batch_len, batch_size)
    labels = data_feed.char_list
    print('labels: ', labels)

    save_path = './model.tf'
    model = Model(
        batch_size, batch_len, lstm_sizes, 0.8, labels, save_path)
    model.init_graph()
    optimizer = tf.train.AdamOptimizer(learning_rate)
    model.init_train_op(optimizer)

    # Created after init_train_op so the optimizer's own variables are
    # initialised as well.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        if restore:
            print('Restoring model')
            model.restore(sess)
        model.reset_state(sess)
        batches = iter(data_feed)
        start_time = time.time()
        for i in range(minibatch_iterations):
            input_batch, target_batch = next(batches)
            loss, _ = sess.run(
                [model.loss, model.train_op],
                feed_dict={
                    model.inputs: input_batch, model.targets: target_batch})
            if i % 50 == 0 and i != 0:
                print('i: ', i)
                duration = time.time() - start_time
                print('loss: {} ({} sec.)'.format(loss, duration))
                start_time = time.time()
            if i % 1000 == 0 and i != 0:
                model.save(sess)
            if i % 100 == 0 and i != 0:
                print('Reset initial state')
                model.reset_state(sess)
            if i % 1000 == 0 and i != 0:
                print('Reset minibatch feeder')
                data_feed.reset_indices()
        # Always checkpoint at the end so the sampling model below (and the
        # next call with restore=True) sees the latest weights.
        model.save(sess)

    print('\n''sampling after {} iterations'.format(minibatch_iterations))
    # Fresh graph so the sampling model's variable names match the
    # checkpoint (no "_1" suffixes from the training model).
    tf.reset_default_graph()
    model = Model(
        1, None, lstm_sizes, 1.0, labels, save_path)
    model.init_graph()
    init_op = tf.global_variables_initializer()
    prime_strings = [
        u'\n''This feeling was ',
        u'She was born in the year ',
        u'The meaning of this all is ',
        u'In the midst of a conversation on political matters Anna Pávlovna burst out:,',
        u'\n\nCHAPTER X\n\n',
        u'"If only you knew,"',
    ]
    with tf.Session() as sess:
        sess.run(init_op)
        model.restore(sess)
        for number, prime in enumerate(prime_strings, start=1):
            print('\n''Sample {}:'.format(number))
            sample = model.sample(
                sess, prime_string=prime, sample_length=500)
            print(u'sample: \n{}'.format(sample))

def main():
    """Train in growing increments, sampling text after each increment."""
    first_run = 500
    total_iterations = first_run
    print('\n''Train for {}'.format(first_run))
    print('Total iters: {}'.format(total_iterations))
    train_and_sample(first_run, restore=False)
    # Each later run continues from the checkpoint written by the previous one.
    for i in [500, 1000, 3000, 5000, 10000, 30000, 50000, 100000, 300000]:
        total_iterations += i
        print('\n''Train for {}'.format(i))
        print('Total iters: {}'.format(total_iterations))
        train_and_sample(i, restore=True)


if __name__ == "__main__":
    main()


Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. ( Log Out / Change )

Google+ photo

You are commenting using your Google+ account. ( Log Out / Change )

Twitter picture

You are commenting using your Twitter account. ( Log Out / Change )

Facebook photo

You are commenting using your Facebook account. ( Log Out / Change )


Connecting to %s