TL;DR
Cache the key and value matrices of previous segments so that long-range context can be referenced. Then use a kNN lookup (e.g. Faiss, ScaNN) to retrieve the cached keys and values most relevant to the current query, and concatenate them with the local key and value matrices when computing attention. The cached memories are not back-propagated through.
background
long document
A common approach to long sequences in Transformers is to truncate the input to the maximum sequence length that fits in memory.
In this case, when a single document is split by length, each segment has no access to the information that came before it; this is called the "context fragmentation" problem.

This matters especially when you need to reference distant context, as in novels or code.
Transformer-XL, Longformer, Reformer, etc. were proposed to address this.
The main idea behind Transformer-XL is:
cache the hidden vectors of the n-th layer from previous segments and concatenate them with the hidden vectors of the current segment when performing the attention operation.
The cached hidden vectors are not back-propagated through.
kNN lookup
A kNN lookup finds and retrieves the k data points closest to a given query. For example, given a trained word2vec model, compute vector(queen) - vector(female) + vector(male), then find which of all the trained word vectors is closest to the result. Efficient implementations include 1) Faiss and 2) ScaNN.
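A brute-force version of this lookup can be sketched in a few lines of NumPy (Faiss and ScaNN compute roughly the same thing, approximately and much faster at scale):

```python
import numpy as np

def knn_lookup(query, keys, k):
    """Return indices and scores of the k keys most similar to the query.

    Brute-force inner-product search; Faiss/ScaNN replace this with
    approximate search over millions of vectors.
    """
    scores = keys @ query            # similarity of the query to every key
    topk = np.argsort(-scores)[:k]   # indices of the k largest scores
    return topk, scores[topk]

# toy example: 1000 random 64-d keys, find the 5 nearest to a query
rng = np.random.default_rng(0)
keys = rng.standard_normal((1000, 64))
query = rng.standard_normal(64)
idx, sims = knn_lookup(query, keys, k=5)
```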
retrieval with transformer
Performing a kNN lookup is a kind of retrieval: the vectors produced by the Transformer are used to perform a search. Approaches that apply this to NLP tasks include REALM and RAG.
REALM trains end-to-end a model that, given a query, retrieves documents, together with an MRC model that reads the retrieved documents to answer the question.

Memorizing Transformer
As explained in the background, the Memorizing Transformer is an approach for handling long documents efficiently: it uses a kNN lookup to select the cached keys and values most similar to the current query and incorporates them into the attention operation.
First, the document is split into segments and processed in order.
In the lower layers, everything proceeds like a normal Transformer decoder. The key and value vectors of each segment are cached.
They are queued until memory is full; once it is full, the oldest entries are evicted and the keys and values of the latest segment are inserted.
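The caching step above can be sketched as a fixed-size FIFO buffer (the class and method names here are hypothetical, not from the paper's code):

```python
from collections import deque
import numpy as np

class SegmentMemory:
    """FIFO cache of (key, value) arrays from past segments.

    When the buffer is full, the oldest segment is evicted before the
    newest one is appended. In the paper, gradients never flow into
    this memory; here we just store plain arrays.
    """
    def __init__(self, max_segments):
        self.keys = deque(maxlen=max_segments)
        self.values = deque(maxlen=max_segments)

    def add_segment(self, k, v):
        # a deque with maxlen drops the oldest segment automatically
        self.keys.append(k)
        self.values.append(v)

    def all_kv(self):
        # concatenate every cached segment into one key/value matrix
        return np.concatenate(list(self.keys)), np.concatenate(list(self.values))

mem = SegmentMemory(max_segments=4)
for _ in range(6):  # add 6 segments of 512 tokens; only the last 4 survive
    mem.add_segment(np.zeros((512, 64)), np.zeros((512, 64)))
K, V = mem.all_kv()
```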

Now, given a query, the memory layer 1) attends to the usual local context and 2) performs a kNN lookup over the memory with the query, pulls out the top-k keys and values, and computes attention over those k pairs. (Think of it as a Transformer decoder attention over just those k keys and values.)

The two results are then combined with a weighted sum, using a learned scalar gate per head.
In the paper's experiments, almost all heads attend to the external memory in most cases.
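Putting the pieces together, a minimal single-head sketch of the gated combination might look like this (the names and the sigmoid gate parameterization are assumptions; the paper's JAX implementation differs in detail):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def memorizing_attention(q, local_k, local_v, mem_k, mem_v, gate, k=32):
    """Single query, single head.

    1) standard attention over the local context
    2) attention over the top-k keys retrieved from memory
    then a gated weighted sum of the two results.
    """
    d = q.shape[-1]
    # 1) local attention
    local_out = softmax(local_k @ q / np.sqrt(d)) @ local_v
    # 2) kNN retrieval from memory, then attention over just those k pairs
    top = np.argsort(-(mem_k @ q))[:k]
    mem_out = softmax(mem_k[top] @ q / np.sqrt(d)) @ mem_v[top]
    # per-head scalar gate in (0, 1) blends the two streams
    g = 1.0 / (1.0 + np.exp(-gate))  # sigmoid of a learned scalar
    return g * mem_out + (1 - g) * local_out

rng = np.random.default_rng(1)
q = rng.standard_normal(64)
out = memorizing_attention(q,
                           rng.standard_normal((128, 64)),   # local keys
                           rng.standard_normal((128, 64)),   # local values
                           rng.standard_normal((4096, 64)),  # memory keys
                           rng.standard_normal((4096, 64)),  # memory values
                           gate=0.0)
```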
Position bias
Added T5-style position bias.
This seems to be a slightly simplified version of the usual relative position embedding.
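Roughly, such a bias adds a learned scalar per bucketed relative distance to the attention logits, rather than adding position embeddings to the token vectors. A simplified sketch (T5 itself uses log-spaced buckets for long distances, omitted here):

```python
import numpy as np

def relative_position_bias(seq_len, num_buckets, bias_table):
    """Build a (seq_len, seq_len) additive bias for the attention logits.

    Each entry is a learned scalar looked up by the (clipped) relative
    distance i - j. In a causal decoder only j <= i matters, so negative
    distances are clipped to bucket 0.
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    bucket = np.clip(i - j, 0, num_buckets - 1)
    return bias_table[bucket]

bias_table = np.zeros(16)  # one learned scalar per bucket (trained in practice)
bias = relative_position_bias(8, 16, bias_table)
# usage: logits = q @ k.T / np.sqrt(d) + bias
```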
Batching
Memory is isolated per batch element because each one holds a different document, and when a document ends its memory is cleared (by design, memory never refers to other documents).
Experiment
Dataset
- GitHub code; math papers on arXiv; Isabelle, a corpus of formal math proofs; C4 restricted to documents longer than 4096 tokens; and PG-19, a corpus of English-language books.
Parameter
- 12-layer Transformer, hidden dim 1024, 8 heads, FFN dim 4096
- In the kNN lookup, k is 32; memory is attached at the 9th of the 12 layers
- SentencePiece tokenizer (vocab size 32K)
- Adafactor optimizer, linear warmup then square-root decay schedule, 32 TPUs
- JAX implementation
Result
Scaling to Larger Model
A model with 8K tokens of memory performs similarly to a vanilla Transformer roughly 5x its size.
Effect of External Memory
In the results, "XL cache" refers to the Transformer-XL-style cache.
External memory improves perplexity for both the vanilla Transformer and Transformer-XL.
In the vanilla Transformer, each segment is truncated, so early tokens lack context; the XL cache fills in the local short-range context while external memory supplies the longer-range context.
Performance with context length 512 and memory 8192 (arXiv perplexity 2.49) is similar to context 2048 with an XL cache of 2048 (arXiv 2.42).
Since the memory is non-differentiable while the XL context is differentiable and affects all layers, yet performance is similar, long-range context is apparently not strictly needed in the lower layers of the Transformer.
Finetuning a non-memory model to use memory
Pre-training with memory from scratch is pretty expensive, so the authors tried adding memory only at fine-tuning time, and it worked well.
Information Retrieval Patterns
The model often looked up rare tokens such as definitions and people's names.
Example of context retrieved from the Isabelle dataset
conclusion
- the idea is simple and intuitive
- in our domain, I'm not sure; it would only pay off when the relevant context extends beyond what a Transformer-XL segment can already cover
- could we truncate seq_len and use a much larger batch_size to train faster?
- since it only needs to be applied at fine-tuning, it should be easy to adopt (once implemented)
etc
papers
- relative PE https://arxiv.org/pdf/1803.02155.pdf
- a different style of relative PE https://arxiv.org/pdf/2006.15595.pdf