Memory Networks and Why They're a Good Idea

Memory Networks are a relatively new class of models designed to alleviate the problem of learning long-term dependencies in sequential data by providing an explicit memory representation for each token in the sequence. In this entry I'll explain the model architecture and why this is a good idea, providing use-case scenarios along the way.

Memory Networks: Under the Hood

The original concept of Memory Networks comes from Weston et al.1. However, we will base this description on the work of Sukhbaatar et al.2, since it requires far less supervision while performing on par, which makes it more realistic for common scenarios.

The input data for a Memory Network consists of a history (the list of inputs so far), a query (the current input) and, during supervised training, a label (what the model should output given that history and query). Each element in the history is embedded as a fixed-size vector, and so is the query (quite commonly using the same embedding matrix). Let us denote the history as (x_1, x_2, \dots, x_M), represented as vectors of size e, which we embed into memory vectors (m_1, m_2, \dots, m_M) of size d, usually by means of an embedding matrix A with dimensions (e \times d). These vectors represent the memory that the model will consult to decide on an output. Before that, though, we construct yet another embedding of the inputs, called the output embedding (c_1, c_2, \dots, c_M), normally by means of a matrix C with dimensions identical to A's.

Now, let's look at the query. It represents the current input and is treated specially, not just as one more history element: we embed it into a vector u of size d so that we can compute:

    $$p_i = Softmax(u^Tm_i)$$

This way, the vector p scores each memory embedding m with respect to the query. We use this vector to compute a weighted average of all the output embeddings, one that takes their relation to the query into account:

    $$o = \sum_ip_ic_i$$

And that is the single vector we can use to decide on an output. At this point there is a wide variety of possibilities; the one used by the authors is to embed this vector o into the space of all possible answers and then apply a softmax, that is:

    $$\hat a = Softmax(W(o + u))$$

Note how, once again, the query embedding u is considered. We'll explain why right away, but first, here is the authors' diagram of how all these equations fit together:

[Figure: Memory Networks (Weston et al.)]
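
To make the equations concrete, here is a minimal single-hop sketch in NumPy. The matrix names (A, B, C, W) follow the text above, but the input representation and all dimensions are illustrative assumptions, not the exact setup of the papers:

```python
import numpy as np

def softmax(z):
    z = z - z.max()                  # subtract max for numerical stability
    exp = np.exp(z)
    return exp / exp.sum()

def single_hop(history, query, A, B, C, W):
    """One Memory Network hop (illustrative sketch, row-vector convention).

    history: (M, e) raw history vectors x_1..x_M
    query:   (e,)   raw query vector
    A, B, C: (e, d) embedding matrices
    W:       (d, V) answer matrix, V = number of possible answers
    """
    m = history @ A                  # memory embeddings m_i, shape (M, d)
    c = history @ C                  # output embeddings c_i, shape (M, d)
    u = query @ B                    # query embedding u, shape (d,)
    p = softmax(m @ u)               # p_i = Softmax(u^T m_i), shape (M,)
    o = p @ c                        # o = sum_i p_i c_i, shape (d,)
    return softmax((o + u) @ W)      # \hat{a} = Softmax(W(o + u)), shape (V,)
```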


Now, where's the recursivity? The reason the input query embedding u is considered at the end is precisely so that we can make a module out of this architecture, one that can be repeated to compute more and more abstract output representations. To do that, instead of classifying the output vector o, we use (o + u) as the new u for a new round of the process, a fixed but in principle unbounded number of times (in practice, fewer than 10). Each such recursive call is called a hop, and can be regarded as the model reconsidering its answer over and over until it arrives at a 'good enough' one (i.e. one that can be better classified). The final answer (\hat a) is then computed only on the last hop. That would look like this:

[Figure: multi-hop Memory Network]
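
Continuing the sketch above (and reusing its softmax helper), the multi-hop version simply feeds (o + u) back in as the new query embedding. Whether the matrices are shared across hops is a design choice discussed next; here A and C are shared for simplicity:

```python
def memory_network(history, query, A, B, C, W, hops=3):
    """Multi-hop forward pass (sketch); A and C shared across all hops."""
    m = history @ A                  # memories stay fixed across hops here
    c = history @ C
    u = query @ B                    # initial query embedding
    for _ in range(hops):
        p = softmax(m @ u)           # re-score memories against the current u
        o = p @ c
        u = o + u                    # (o + u) becomes the new u for the next hop
    return softmax(u @ W)            # classify only after the last hop
```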

So how many parameters do we have there? So far: an (e \times d) matrix A; matrices C and B of the same dimensions, to encode the output embeddings and the query respectively; and a (d \times V) matrix W to compute the final answer, V being the total number of possible answers. Then multiply that by the number of hops. This can result in a large number of parameters, but that is just the basic prototype of the model. In practice, it can be constrained so that each hop uses the same A, B and C, or with more creative layouts, such as making the C of hop i equal to the A of hop i+1 for all i but the last. That is the trial-and-error part and we'll leave it there; just keep in mind that state-of-the-art results have been achieved on many relevant problems with just a few parameters compared to all the matrices mentioned here.
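
As a back-of-the-envelope illustration of those counts (every number here is hypothetical, chosen only to show the arithmetic):

```python
# Parameter count for an untied 3-hop model vs a fully tied one.
e, d, V, hops = 10_000, 50, 4_212, 3   # vocab-sized inputs, d=50, V answers

per_hop = 3 * e * d                    # one A, B and C per hop if nothing is shared
untied  = hops * per_hop + d * V       # plus the single answer matrix W
tied    = 3 * e * d + d * V            # sharing A, B, C across all hops

print(untied, tied)                    # 4710600 vs 1710600 parameters
```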

Why Is This a Good Idea?

Right from the start, this is a novel way to look at sequential data: instead of analyzing it piece by piece while updating an internal fixed-size memory representation (one that forgets more of the past with every new input), Memory Networks consider the entire history so far explicitly, with a dedicated vector representation for each history element, effectively removing the chance to “forget”. The limit on memory size now becomes a hyper-parameter to tune, rather than an intrinsic limitation of the model itself.

Taking an entire sequence and representing it as a single input offers the memory advantages mentioned above, but of course it also brings challenges. One of the main ones is that the order of the sequence is lost, along with the local proximity of sequence elements, which is crucial for most sequential inputs. Think about language modeling: if you want a model that predicts the next word based on the words seen so far, order obviously matters; otherwise “hers” would be just as good a prediction as “his” for the observed sequence “he likes her work but she prefers”. A second challenge is that fitting the sequence into a fixed-size input either loses information (e.g. averaging word embeddings or using a Bag of Words to represent a sentence) or limits the sequence lengths you can deal with (e.g. concatenating all the elements up to a fixed length, padding shorter sequences). Memory Networks deal with this as follows:

  • To deal with the order of the inputs in the sequence, time has to be explicitly modeled as an extra feature. A straightforward way to achieve this is to add a feature that just tracks the index of the element within the sequence, but Sukhbaatar et al.2 take a more involved approach, adding a ‘time encoding’ vector to the embedding of each input and learning these vectors during training (see the sketch after this list). All in all, the way traditional RNNs deal with time does seem more elegant and natural, but even so, they fall short of Memory Networks in results, as we’ll see below.
  • Memory Networks do not force a limit on sequence length: the exact same architecture can deal with sequences of different lengths, as each element gets a memory-cell representation that is eventually averaged into a single fixed-size vector to classify on. Nevertheless, it is possible, and in fact quite common, to force a memory limit, as this makes classification easier and, in most real problems, only a few recent elements are needed to make a good decision. The ability to consider all the tokens in the story at once when making a decision is reminiscent of attention mechanisms in Sequence to Sequence models, but unlike those, in a Memory Network the embeddings of the tokens are independent of one another (except for being encoded with the same parameters), whereas in Sequence to Sequence models each latent representation keeps information about all the previous tokens in the sequence.
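
Here is a hedged sketch of the time-encoding idea from the first bullet: one learnable vector per memory slot, added to the corresponding memory embedding. The variable names and initialization are mine, not the paper's exact recipe:

```python
import numpy as np

e, d, M_max = 100, 50, 20        # illustrative sizes; M_max is the memory capacity

# One learnable time vector per memory slot; in the paper's scheme these
# are trained jointly with the rest of the parameters.
T_A = np.random.randn(M_max, d) * 0.1

def embed_memory_with_time(history, A):
    """history: (M, e) raw inputs with M <= M_max; A: (e, d) embedding matrix."""
    m = history @ A              # content embeddings, shape (M, d)
    return m + T_A[:len(m)]      # add each slot's time encoding, restoring order info
```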

Memory Networks are also recurrent models, since intermediate operations are repeated recursively, each taking its previous value as input (the so-called hops). That is, hidden layers are a function of their previous value. But unlike in traditional RNNs, each recursive step has access to the same input information (namely, the entire sequence history so far); the steps differ only in the level of abstraction they operate at (i.e. the whole point of having more layers in a deep learning model). Note that even though a Memory Network is recursive, this recursivity has nothing to do with the inputs or outputs: in traditional RNNs the recursive operations occur once per input in the sequence, potentially producing one output per step as well, but in a Memory Network the recursive operations are decoupled from the sequential nature of the input, and in fact only one output is produced after all of them. A Memory Network thus effectively disentangles recursivity from the nature of the inputs and outputs.

Moreover, the use of these recursive operations, or memory hops, offers very promising empirical results2. Sukhbaatar et al. try it on the Q&A tasks3 from Weston et al. These are 20 types of question-answering problems where the bot is given a set of facts and then a question that can be answered by reasoning over some of those facts (others are potential distractors). Below you can see how every hop makes the model focus more and more on the supporting fact that answers the question (taken from 2):

 

[Figure: Memory Networks (Sukhbaatar et al.): per-hop attention over the supporting facts]

 

The number of hops proved beneficial on a language modeling task as well, an area long dominated by recurrent models such as vanilla RNNs and LSTMs. In this domain, Sukhbaatar et al. showed not only that the Memory Network achieved better performance (lower perplexity) than any other model tried, but also that the benefit was disproportionately high with respect to the number of parameters (almost 3 times fewer than an LSTM). Incidentally, these experiments also confirmed the value of adding more hops.

All in all, this is a promising architecture, and indeed an entire family of models. There are many reasonable ways to model and embed the inputs, to decide how many embeddings to use and how to compute the output before the recursive hops, and there are many well-known non-linearities that could be exploited (in this vanilla form, the only non-linearity used is the softmax). As long as a model has an explicit memory for each input in the history and computes an answer that considers them all, potentially repeating the process recursively, it fits the general idea of a Memory Network.

Memory Networks for Task-Oriented Chatbots

Task-oriented dialogue agents, or chatbots, are ever more popular models that are meant to engage in a conversation with a human in order to serve an often narrow-domain task for that user (such as booking a hotel). These models need to understand natural language from the user and make decisions that consider both the most recent user input and the entire conversation history. This seems like a scenario designed for a Memory Network. Most such models are rule-based, so as to be more reliable, at the expense of being very expensive to develop and maintain. The alternative that is gaining momentum is a machine learning approach: using lots of conversations to train an action classifier. RNNs and LSTMs are the default go-to models in this case. So, how would a Memory Network perform here? We don’t have to guess: there is a baseline provided by the authors of the Memory Network themselves4.

In this work, they propose a new dataset for this type of agent, called the bAbI dialog tasks (of which they designed 5). These tasks come in the form of simulated conversations between a human and a bot. The human wants to book a restaurant based on filters such as type of cuisine, location and price range. The bot can then either check a database to suggest a matching option, ask for more information, or announce that no matching option is available. The task consists of training a model on these conversations so that it learns to predict the bot’s action at each step. They propose several baseline models, of which a Memory Network offers the best results. Their Memory Network, however, has quite a task to deal with, since it has to learn to understand the human’s text (relying mostly on pre-trained word embeddings) and then choose one answer out of 4,212 possibilities. Even with such a classification task, the Memory Network performs quite well: 96.1% accuracy on all turns (77.7% when the test data contains out-of-vocabulary words), and 49.4% of conversations matched perfectly (0% with out-of-vocabulary terms).

This opens the door to a very reasonable question: why not use a well-performing Natural Language Understanding (NLU) component, optimized for classifying the user’s text by intent and identifying its entities, so that instead of “hi, I would like to book an Italian restaurant in the north” the Memory Network deals simply with book_restaurant(cuisine=italian, location=north)? And why classify over 4,212 output phrases when you can classify each one as an “action class” and fill in the exact text with the help of the NLU? Then, instead of considering answers like “The hot pot is a cheap restaurant in the center” and “La cucina is a moderately priced restaurant in the west”, the model simply considers “[restaurant-name] is a [price] restaurant in the [location]” and, if that class is chosen, uses the NLU-provided information to fill in the blanks. There are plenty of well-performing commercial and open-source NLU modules out there that are machine-learning based and can easily be swapped in. I did just that: I trained a Memory Network with this advantage and got 100% on all the metrics above, confirming the power of a Memory Network, especially when it doesn’t have to learn to do natural language understanding and is reserved for what it does best: predicting actions.
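
To illustrate that pre-processing step, here is a toy sketch of the template round trip. The slot names, template syntax and helper functions are hypothetical, for illustration only; they are not the bAbI format or any particular NLU's output:

```python
# Toy delexicalization sketch: collapse concrete bot answers into
# action-class templates using NLU-style entity annotations.

def delexicalize(utterance: str, entities: dict) -> str:
    """Replace entity values with slot placeholders, turning a concrete
    answer into the action-class template the classifier sees."""
    for slot, value in entities.items():
        utterance = utterance.replace(value, f"[{slot}]")
    return utterance

def lexicalize(template: str, entities: dict) -> str:
    """Inverse step: fill the chosen template with NLU-provided values."""
    for slot, value in entities.items():
        template = template.replace(f"[{slot}]", value)
    return template

ents = {"restaurant-name": "La cucina",
        "price": "moderately priced",
        "location": "west"}
tmpl = delexicalize("La cucina is a moderately priced restaurant in the west", ents)
print(tmpl)   # [restaurant-name] is a [price] restaurant in the [location]
print(lexicalize(tmpl, ents))   # recovers the original sentence
```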

So what next? The bAbI tasks are a sound baseline, but they are a bit too safe compared with a real-life scenario. Real conversations are not as clean; they are full of noise, misunderstandings and mistakes. I said above that there are 5 bAbI tasks, and that is not completely true: there is a 6th task, taken from the 2nd Dialog State Tracking Challenge5, a dataset consisting of actual humans talking to actual bots. The conversations are transcripts from audio with different levels of degradation. As a result, they are noisy to the point that even perfectly overfitting the training and development data results in low accuracy on unseen data. The same Memory Network that got 96.1% accuracy on bAbI task 5 got 41.1% on this one. So that is what comes next, and the outcome will be presented in a soon-to-follow entry.

References

  1. Weston J, Chopra S and Bordes A. Memory Networks, ICLR 2015
  2. Sukhbaatar S, Szlam A, Weston J and Fergus R. End-to-end Memory Networks, NIPS 2015
  3. Weston J, Bordes A, Chopra S and Mikolov T. Towards AI-complete question answering: A set of prerequisite toy tasks, arXiv 2015
  4. Bordes A, Boureau Y and Weston J. Learning End-to-End Goal Oriented Dialog, ICLR 2017
  5. Henderson M, Thomson B and Williams J. The Second Dialog State Tracking Challenge, 15th Annual Meeting of the Special Interest Group on Discourse and Dialogue 2014
