RNNs tutorial

[1]:
# we assume that the dynet module is in your path.
import dynet as dy

An LSTM/RNN overview:

A (1-layer) RNN can be thought of as a sequence of cells, \(h_1,...,h_k\), where the index \(i\) indicates the time dimension.

Each cell \(h_i\) has an input \(x_i\) and an output \(r_i\). In addition to \(x_i\), cell \(h_i\) receives as input also \(r_{i-1}\).

In a deep (multi-layer) RNN, we don’t have a sequence, but a grid. That is, we have several layers of sequences:

  • \(h_1^3,...,h_k^3\)
  • \(h_1^2,...,h_k^2\)
  • \(h_1^1,...,h_k^1\)

Let \(r_i^j\) be the output of cell \(h_i^j\). Then:

The input to \(h_i^1\) is \(x_i\) and \(r_{i-1}^1\).

The input to \(h_i^2\) is \(r_i^1\) and \(r_{i-1}^2\), and so on.
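
To make the wiring concrete, here is a small schematic sketch in plain numpy (an illustration only, not DyNet code and not an LSTM; the tanh cell, the random seed and the dimensions are arbitrary choices) of how the outputs \(r_i^j\) flow through a 2-layer RNN:

import numpy as np

np.random.seed(0)
IN_DIM, HID_DIM, LAYERS = 4, 3, 2

# one weight matrix per layer; layer j reads [its input ; its own previous output]
Ws = [np.random.randn(HID_DIM, (IN_DIM if j == 0 else HID_DIM) + HID_DIM)
      for j in range(LAYERS)]

def step(prev_r, x):
    # prev_r[j] is r_{i-1}^{j+1}; returns the new outputs r_i^{j+1} for all layers
    new_r = []
    inp = x                                        # layer 1 reads x_i
    for j, W in enumerate(Ws):
        r = np.tanh(W @ np.concatenate([inp, prev_r[j]]))
        new_r.append(r)
        inp = r                                    # layer j+1 reads r_i^j
    return new_r

r = [np.zeros(HID_DIM) for _ in range(LAYERS)]     # initial state (zeros)
for x in [np.random.randn(IN_DIM) for _ in range(5)]:   # x_1 ... x_5
    r = step(r, x)
print(r[-1])                                       # the top layer's output at the last step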

The LSTM (RNN) Interface

RNN / LSTM / GRU follow the same interface. We have a “builder” which is in charge of creating and defining the parameters for the sequence.

[2]:
pc = dy.ParameterCollection()
NUM_LAYERS=2
INPUT_DIM=50
HIDDEN_DIM=10
builder = dy.LSTMBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
# or:
# builder = dy.SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)

Note that when we create the builder, it adds the internal RNN parameters to the ParameterCollection. We do not need to care about them, but they will be optimized together with the rest of the network’s parameters.
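
For instance (a minimal sketch, not part of the original notebook, using only the objects defined above): optimizing any loss that depends on the builder’s output will also update the RNN’s internal parameters, simply because they live in pc.

trainer = dy.SimpleSGDTrainer(pc)    # any trainer over pc also covers the LSTM's weights

x = dy.vecInput(INPUT_DIM)           # a dummy input vector
x.set([0.5] * INPUT_DIM)
y = builder.initial_state().add_input(x).output()
loss = dy.dot_product(y, y)          # a toy scalar loss on the LSTM output
loss.value()
loss.backward()
trainer.update()                     # the LSTM's internal parameters are updated too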

[3]:
s0 = builder.initial_state()
[4]:
x1 = dy.vecInput(INPUT_DIM)
[5]:
s1=s0.add_input(x1)
y1 = s1.output()
# here, we add x1 to the RNN, and the output we get from the top is y1 (a HIDDEN_DIM-dim vector)
[6]:
y1.npvalue().shape
[6]:
(10,)
[7]:
s2=s1.add_input(x1) # we can add another input
y2=s2.output()

If our LSTM/RNN were one layer deep, y2 would be equal to the hidden state. However, since it is 2 layers deep, y2 is only the hidden state (= output) of the last layer.

If we want access to all the hidden states (the outputs of both the first and the last layers), we can use the .h() method, which returns a list of expressions, one for each layer:

[8]:
print(s2.h())
(expression 47/0, expression 62/0)

The same interface that we saw so far for the LSTM also holds for the Simple RNN:

[9]:
# create a simple rnn builder
rnnbuilder=dy.SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)

# initialize a new graph, and a new sequence
rs0 = rnnbuilder.initial_state()

# add inputs
rs1 = rs0.add_input(x1)
ry1 = rs1.output()
print("all layers:", s1.h())
all layers: (expression 19/0, expression 32/0)
[10]:
print(s1.s())
(expression 17/0, expression 30/0, expression 19/0, expression 32/0)

To summarize, when calling .add_input(x) on an RNNState, the state creates a new RNN/LSTM column, passing it:

  1. the state of the current RNN column
  2. the input x

The state is then returned, and we can call its output() method to get the output y, which is the output at the top of the column. We can access the outputs of all the layers (not only the last one) using the .h() method of the state.

``.s()`` The internal state of the RNN may be more involved than just the outputs \(h\). This is the case for the LSTM, which keeps an extra “memory” cell that is used when calculating \(h\), and which is also passed to the next column. To access the entire hidden state, we use the .s() method.

The output of .s() differs by the type of RNN being used. For the simple-RNN, it is the same as .h(). For the LSTM, it is more involved.

[11]:
rnn_h  = rs1.h()
rnn_s  = rs1.s()
print("RNN h:", rnn_h)
print("RNN s:", rnn_s)


lstm_h = s1.h()
lstm_s = s1.s()
print("LSTM h:", lstm_h)
print("LSTM s:", lstm_s
)
RNN h: (expression 70/0, expression 72/0)
RNN s: (expression 70/0, expression 72/0)
LSTM h: (expression 19/0, expression 32/0)
LSTM s: (expression 17/0, expression 30/0, expression 19/0, expression 32/0)

As we can see, the LSTM has two extra state expressions (one for each hidden layer) before the outputs h.

Extra options in the RNN/LSTM interface

Stack LSTM The RNNs are shaped as a stack: we can remove the top and continue from the previous state. This is done either by remembering the previous state and continuing it with a new .add_input(), or by accessing the previous state of a given state using its .prev() method.

Initializing a new sequence with a given state When we call builder.initial_state(), we are assuming the state has a default (zero) initialization. If we want, we can specify a list of expressions that will serve as the initial state. The expected format is the same as the result of a call to .final_s(). TODO: this is not supported yet.

[12]:
s2=s1.add_input(x1)
s3=s2.add_input(x1)
s4=s3.add_input(x1)

# let's continue s3 with a new input.
s5=s3.add_input(x1)

# we now have two different sequences:
# s0,s1,s2,s3,s4
# s0,s1,s2,s3,s5
# the two sequences share parameters.

assert(s5.prev() == s3)
assert(s4.prev() == s3)

s6=s3.prev().add_input(x1)
# we now have an additional sequence:
# s0,s1,s2,s6
[13]:
s6.h()
[13]:
(expression 207/0, expression 222/0)
[14]:
s6.s()
[14]:
(expression 205/0, expression 220/0, expression 207/0, expression 222/0)

Aside: memory efficient transduction

The RNNState interface is convenient, and allows for incremental input construction. However, sometimes we know the sequence of inputs in advance, and care only about the sequence of output expressions. In this case, we can use the add_inputs(xs) method, where xs is a list of Expression.

[15]:
state = rnnbuilder.initial_state()
xs = [x1,x1,x1]
states = state.add_inputs(xs)
outputs = [s.output() for s in states]
hs =      [s.h() for s in states]
print(outputs, hs)
[expression 226/0, expression 230/0, expression 234/0] [(expression 224/0, expression 226/0), (expression 228/0, expression 230/0), (expression 232/0, expression 234/0)]

This is convenient.

What if we do not care about .s() and .h(), and do not need to access the previous vectors? In such cases we can use the transduce(xs) method instead of add_inputs(xs). transduce takes in a sequence of Expressions, and returns a sequence of Expressions. As a consequence of not returning RNNStates, transduce is much more memory efficient than add_inputs or a series of calls to add_input.

[16]:
state = rnnbuilder.initial_state()
xs = [x1,x1,x1]
outputs = state.transduce(xs)
print(outputs)
[expression 238/0, expression 242/0, expression 246/0]

Character-level LSTM

Now that we know the basics of RNNs, let’s build a character-level LSTM language model. We have a sequence LSTM that, at each step, gets a character as input and needs to predict the next character.

[17]:
import random
from collections import defaultdict
from itertools import count
import sys

LAYERS = 1
INPUT_DIM = 50
HIDDEN_DIM = 50

characters = list("abcdefghijklmnopqrstuvwxyz ")
characters.append("<EOS>")

int2char = list(characters)
char2int = {c:i for i,c in enumerate(characters)}

VOCAB_SIZE = len(characters)


[18]:
pc = dy.ParameterCollection()


srnn = dy.SimpleRNNBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
lstm = dy.LSTMBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)

# add parameters for the hidden->output part for both lstm and srnn
params_lstm = {}
params_srnn = {}
for params in [params_lstm, params_srnn]:
    params["lookup"] = pc.add_lookup_parameters((VOCAB_SIZE, INPUT_DIM))
    params["R"] = pc.add_parameters((VOCAB_SIZE, HIDDEN_DIM))
    params["bias"] = pc.add_parameters((VOCAB_SIZE))

# compute the loss of the RNN for one sentence
def do_one_sentence(rnn, params, sentence):
    # setup the sentence
    dy.renew_cg()
    s0 = rnn.initial_state()


    R = params["R"]
    bias = params["bias"]
    lookup = params["lookup"]
    sentence = ["<EOS>"] + list(sentence) + ["<EOS>"]
    sentence = [char2int[c] for c in sentence]
    s = s0
    loss = []
    for char,next_char in zip(sentence,sentence[1:]):
        s = s.add_input(lookup[char])
        probs = dy.softmax(R*s.output() + bias)
        loss.append( -dy.log(dy.pick(probs,next_char)) )
    loss = dy.esum(loss)
    return loss


# generate from model:
def generate(rnn, params):
    def sample(probs):
        rnd = random.random()
        for i,p in enumerate(probs):
            rnd -= p
            if rnd <= 0: break
        return i

    # setup the sentence
    dy.renew_cg()
    s0 = rnn.initial_state()

    R = params["R"]
    bias = params["bias"]
    lookup = params["lookup"]

    s = s0.add_input(lookup[char2int["<EOS>"]])
    out=[]
    while True:
        probs = dy.softmax(R*s.output() + bias)
        probs = probs.vec_value()
        next_char = sample(probs)
        out.append(int2char[next_char])
        if out[-1] == "<EOS>": break
        s = s.add_input(lookup[next_char])
    return "".join(out[:-1]) # strip the <EOS>


# train, and generate every 5 training iterations
def train(rnn, params, sentence):
    trainer = dy.SimpleSGDTrainer(pc)
    for i in range(200):
        loss = do_one_sentence(rnn, params, sentence)
        loss_value = loss.value()
        loss.backward()
        trainer.update()
        if i % 5 == 0:
            print("%.10f" % loss_value, end="\t")
            print(generate(rnn, params))

Notice that:

  1. We pass the same rnn-builder to do_one_sentence over and over again. We must re-use the same rnn-builder, as this is where the shared parameters are kept.
  2. We call dy.renew_cg() before each sentence, because we want a new graph (a new network) for this sentence. The parameters will be shared through the model and the shared rnn-builder.
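
For example (a small sketch, assuming a hypothetical list corpus of training sentences made of the allowed characters), the same pattern extends beyond a single sentence: the builder and trainer persist across sentences, while the graph is renewed inside do_one_sentence for each one.

corpus = ["a quick brown fox jumped over the lazy dog",
          "these pretzels are making me thirsty"]      # hypothetical toy corpus

trainer = dy.SimpleSGDTrainer(pc)
for epoch in range(10):
    random.shuffle(corpus)
    for sent in corpus:
        loss = do_one_sentence(lstm, params_lstm, sent)  # renews the graph internally
        loss.value()
        loss.backward()
        trainer.update()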

[19]:
sentence = "a quick brown fox jumped over the lazy dog"
train(srnn, params_srnn, sentence)
149.1241455078  czdz
106.8910675049  vvaleamkxnmgiehjqziydnqw wvdyqvocv
70.4573440552
45.8460693359   k upcel romw ufec dbo tzhw
23.2288417816   p g
10.3553504944   d slick boxwn xox
3.4879603386    a quick browncfox jumped oher dhe lazy dog
1.2615398169    a quick brown fox jumped over the lazy dog
0.8019771576    q quick brown fox jumped kver the lazy dog
0.5865874290    a quick brown fox jumped over the lazy dog
0.4615180790    a quick brown fox jumped over the lazy dog
0.3798972070    a quick brown fox jumpgd over the lazy dog
0.3224956095    a quick brown fox jumped over the lazy dog
0.2799601257    a quick brown fox jumped ovec the lazy dog
0.2471984029    a quick brown fox jumped over the lazy dog
0.2212039977    a quick brown fox jumped over ehe lazy dog
0.2000847459    a quick brown fox jumped over the lazy dog
0.1825926602    a quick brown fox jumped over the lazy dog
0.1678716689    a quick brown fox jumped over the lazy dog
0.1553139985    a quick brown fox jumped over the lazy dog
0.1444785446    a quick brown fox jumped over the lazy dog
0.1350349337    a quick brewn fox jumped over the lazy dog
0.1267340034    a quick brown fox jumped over the lazy dog
0.1193781197    a quick brown fox jumped over the lazy dog
0.1128196493    a quick brown fox jumped over the lazy dog
0.1069323123    a quick brown fox jumped over the lazy dog
0.1016217321    a quick brown fox jumped over the lazy dog
0.0968040079    a quick brown fox jumped over the lazy dog
0.0924172848    a quick brown fox jumped over the lazy dog
0.0884065777    a quick brown fox jumped over the lazy dog
0.0847217217    a quick brown fox jumped over the lazy dog
0.0813286975    a quick brown fox jumped over the lazy dog
0.0781916752    a quick brown fox vumped over the lazy dog
0.0752858222    a quick brown fox jumped over the lazy dog
0.0725848153    a quick brown fox jumped over the lazy dog
0.0700684935    a quick brown fox jumped over the lazy dog
0.0677180886    a quick brown fox jumped over the lazy dog
0.0655169189    a quick brown fox jumped over the lazy dog
0.0634531230    a quick brown fox jumped over the lazy dog
0.0615142360    a quick brown fox jumped over the lazy dog
[20]:
sentence = "a quick brown fox jumped over the lazy dog"
train(lstm, params_lstm, sentence)

143.8308105469  rsfmqxt ozqsiiaqa
128.6366729736  v ypvoobknwepfeply
121.5253829956  ymquqr wfmocoe  ovwuwfmhdm ueod yewe
115.6775283813  q thluk  nwzwz eoodzod
109.5644912720  kxni bpuch xj enu mr ung omj dp eevem r vyyd t p lt  oyqbr
102.8272857666  bqkbhfmnb o mppoeoegoyt ddl rusay l  da
98.4713058472   qosqhn poafr  of uhexedoo pe h etavopyd pyiy d o yee al slghh
90.9259567261   a a qakmbn bm qcoayper efoyeroldpddm
85.2885818481   jku u upoowwbj jvdemmdfdeduree ood  oogdogpq dlto  y agzog i g gdlzac fokn  ux po opu uvlrr e eer rae ed ogy oel olzz
76.7979202271
71.5208969116   q bauqn kkowcon ffojjpfp ox ouvtt e lzuy dv hoty dgggo oqjgkgo  oonx oxm om vee  eeo o ad
62.6548461914   dc zbrb oqn xomper joehpee eztlazd lqau
56.1585731506   howunqm oofw ojpder re rezt ogavyy  dogdcwo
49.3944778442   t bc qouwr rw o bo xm ojumer r ree ele azlad do
41.5289344788   h uuikr ob wox jumepd rr loy ulz do
36.3642997742   dw ucown oox mx op jadmee ther loh
30.3189773560   a qqucakb onw fn jumjee oe tee taza dzo
24.9423580170   a quackk bborn fox juumpedde ove verr azy dogg
19.8645935059    aauk qbrrr oorf fmomed oee the layy oys
15.7765054703    aukc irr brow fox jumedd oveer aauzy dog
12.4693098068   a quikc brown foo jope dve ovr lay dogy dbog
9.8480806351    ucc brown fox juxmpe ooer the llazy dog
7.8634152412    a quiic brronn fox jumed over the lazy dog
5.9515495300    a quick brown x jxx jjumed over the lzyy log
4.5509667397    a quic bronn fox jumpdd over the lazy dog
3.4462499619    a quic brown fox jumped oper the lazy dog
2.2497565746    a qgiick brown fox jumped over the lazy dog
1.7854881287    a quicr brown fox jumped over the lazy dog
1.4716305733    a quick brown n fumpped over the lazy dog
1.2449830770    a quick brown fox jumped over th lazy dog
1.0744248629    a quick brown fox jumped over the lazy dog
0.9419770241    a qick brown fox jumped over the lazy dog
0.8365104198    a quik brown fox jumped over the lazy dog
0.7507840395    a quick brown fox jumped over the lazy dog
0.6798918843    a quick brown fox jumped over the lazy dog
0.6204041839    a quick brown fox jumped over the lazy dog
0.5698559284    a quick brown fox jumped over the lazy dog
0.5264287591    a quick brown fox jumped over the lazy ddog
0.4887620807    a quick brown fox jumped over the lazy dog
0.4558122456    a quick brown fomx jumped over the lazy dog

The model seems to learn the sentence quite well.

Somewhat surprisingly, the Simple-RNN model learns quicker than the LSTM! (If you increase the number of layers, this difference will be even more pronounced.)

How can that be?

The answer is that we are cheating a bit. The sentence we are trying to learn has each letter-bigram exactly once. This means a simple trigram model can memorize it very well.
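
We can verify this with a quick check (a small sketch using collections.Counter, not part of the original notebook): no character bigram occurs more than once in the training sentence.

from collections import Counter

txt = "a quick brown fox jumped over the lazy dog"
bigrams = Counter(zip(txt, txt[1:]))
print(max(bigrams.values()))   # prints 1: every bigram occurs exactly once

Since every bigram is unique, the previous two characters uniquely determine the next one, which is exactly what a trigram model memorizes.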

Try it out with more complex sequences.

[21]:
train(srnn, params_srnn, "these pretzels are making me thirsty")
239.8676452637  a quick brown fox jumped over the lazy dog
160.0521697998  a quick lrojny la tumped over the lazy driwn lrlazthe lazmzover the lazy dvgrn zrldnthd qheed gex thed bmne the lazy dbqre lazd yrd nrkwn fox jumped over the lazy dhe lazy dog
95.4504623413   a quick brzy e lazy e greaze erg
63.7105903625   ahqne lretze n mve lbmed tizd mare a
40.0441474915   aaquicpr the laz
21.2409267426   aze ecpretztli arirltzixtoped bt
7.8475003242    these pritzels pre maling me thirsty
2.9435257912    these pretz ls whe taking me thirsty
1.0504565239    these pretzels are making me thirsty
0.4977459908    these pretzels are making me the mty
0.3605296910    these pretz
0.2839964628    these pretzels are making me ghirsty
0.2345837206    these pretzels are making me thirsty
0.1999202222    these pretzels are making me thirsty
0.1742218733    these pretzels are making me thirsty
0.1543948352    these pretzels are making me thirsty
0.1386296600    these pretzels are making me thirsty
0.1257905215    these pretzels are making me thirsty
0.1151323095    these pretzels are making me thirsty
0.1061392874    tmese pretzels are making me thirsty
0.0984523371    these pretzels are making me thirsty
0.0918025821    these pretzels are making me thirsty
0.0859967992    these pretzels are making me thirsty
0.0808787867    these pretzels are making me thirsty
0.0763375387    these pretzels are making me thirsty
0.0722793266    these pretzels are making me thirsty
0.0686300844    these pretzels are making me thirsty
0.0653314814    these pretzels are making me thirsty
0.0623353273    these pretzels are making me thirsty
0.0596007779    these pretzels are making me thirsty
0.0570969619    these pretzels are making me thirsty
0.0547946990    these pretzels are making me thirsty
0.0526688434    these pretzels are making me thirsty
0.0507033207    these pretzels are making me thirsty
0.0488782115    these pretzels are making me thirsty
0.0471798219    these pretzels are making me thirsty
0.0455954187    these pretzels are making me thirsty
0.0441129319    these pretzels are making me thirsty
0.0427242145    these pretzels are making me thirsty
0.0414196625    these pretzels are making me thirsty
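
As a possible follow-up (the choice of sentence here is only an illustration), you can also train the LSTM on the second sentence, or on a longer string in which bigrams do repeat, where memorizing letter-bigrams is no longer enough:

train(lstm, params_lstm, "these pretzels are making me thirsty")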