Beyond RAG: Scaling long context

Alternative architectures: This is the way?

In our previous series on Advanced RAG, we looked at various ways to build efficient pipelines that work around the limited context windows of LLMs. But is this limitation insurmountable?

In today’s world dominated by Transformer architectures, it is tough to imagine an alternative reality where context scales efficiently (subquadratically). This post is hence dedicated to the “what if” of alt architectures: an alternative reality where context length is no longer the limiting factor.

Attention is “also” what you need!

Transformer Architecture:

Let’s start by recapping why Transformers are the default architecture for today’s LLMs. It all began with Google’s landmark 2017 paper, “Attention Is All You Need”. Its revolutionary contribution was the ability to process language in parallel rather than sequentially (as RNNs did previously; more on that below).

One of its biggest strengths, the general-purpose, multimodal nature of this architecture, is well summarized by one of the paper’s co-authors, Ashish Vaswani:

This parallelization via the “self-attention module” brought gains in computational efficiency, scalability and accuracy, a big leap from the previously popular sequential RNN architectures, and its arrival coincided with the rise of GPUs! Hence began the race to train on larger datasets with more parameters for more powerful LLMs.
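To make the “parallel” point concrete, here is a minimal single-head, causally masked self-attention sketch in NumPy. It is an illustration rather than the paper’s multi-head implementation, and the weight names `w_q`, `w_k`, `w_v` are our own. Every token is scored against every other token in one matrix product, which is what lets a GPU process the whole sequence at once, and it is also why the score matrix has shape `(n, n)`.

```python
import numpy as np

def causal_self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention with a causal mask.

    x: (n, d) token embeddings; w_q, w_k, w_v: (d, d_head) projection weights.
    Returns the (n, d_head) outputs and the (n, n) attention weights.
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])           # (n, n): all token pairs at once
    n = x.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True where a token would see the future
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
n, d, d_head = 6, 8, 4
x = rng.standard_normal((n, d))
out, attn = causal_self_attention(x, *(rng.standard_normal((d, d_head)) for _ in range(3)))
print(out.shape, attn.shape)  # (6, 4) (6, 6)
```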

Why are we discussing alternative architectures then? Seems like attention is all we need! Well, maybe not. One of the biggest criticisms of the transformer architecture is that its compute and memory scale quadratically with sequence length, i.e. more $$$ for training (attention) and generation (full lookback).

This has led to the excitement around the sub-quadratic LLMs we discuss below. I will qualify the criticism of quadratic scaling by noting that while it holds for pre-training, at inference time attention is quite a small part of the compute at typical context lengths, so take the hype with a big pinch of salt.
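A rough per-token FLOP count shows where that caveat comes from, and where it stops holding. The sketch below uses common back-of-the-envelope approximations (about 2 FLOPs per weight per token, Q/K/V/O projections, a 4x-wide MLP, and roughly 4·n·d FLOPs to attend over n cached tokens); the model width of 4096 is an illustrative assumption, and layer norms, softmax and embeddings are ignored.

```python
def attention_flop_fraction(context_len: int, d_model: int) -> float:
    """Rough share of per-token forward FLOPs spent on attention scores and values.

    Assumes a standard decoder layer: Q/K/V/O projections (8*d^2 FLOPs),
    a 4x-wide two-matrix MLP (16*d^2 FLOPs), and attention over
    `context_len` cached tokens (4*n*d FLOPs).
    """
    dense = 24 * d_model**2
    attn = 4 * context_len * d_model
    return attn / (dense + attn)

for n in (2_048, 32_768, 131_072):
    print(n, round(attention_flop_fraction(n, d_model=4_096), 3))
# 2048 0.077
# 32768 0.571
# 131072 0.842
```

Under these assumptions attention is a small slice of the work at a few thousand tokens but dominates at 128k, so the pinch of salt gets smaller as context grows.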

RNN Architecture:

The RNN architecture was the most widely used before the transformer era. Here, a hidden state vector (memory) is maintained and updated with each additional token, which is then used to predict the next token. This can be applied to various scenarios by using different “modes”, as outlined in the image below by Andrej Karpathy in his blog.
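A minimal sketch of that update loop, assuming a vanilla (Elman) RNN with a tanh activation and hypothetical weight names. The hidden state `h` has a fixed size no matter how long the input is, and each step has to wait for the previous one, which is exactly the sequential bottleneck Transformers removed.

```python
import numpy as np

def rnn_forward(xs, w_xh, w_hh, b_h):
    """Vanilla RNN: one fixed-size hidden-state update per token, strictly in order.

    xs: (T, input_dim); w_xh: (hidden, input_dim); w_hh: (hidden, hidden); b_h: (hidden,).
    Returns the (T, hidden) sequence of hidden states.
    """
    h = np.zeros(w_hh.shape[0])
    states = []
    for x in xs:  # sequential: step t needs h from step t-1
        h = np.tanh(w_xh @ x + w_hh @ h + b_h)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
xs = rng.standard_normal((5, 3))
w_xh, w_hh, b_h = rng.standard_normal((4, 3)), 0.5 * rng.standard_normal((4, 4)), np.zeros(4)
print(rnn_forward(xs, w_xh, w_hh, b_h).shape)  # (5, 4)
```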

The biggest drawback of RNNs is that they struggle to process long sequences due to what is referred to as “the vanishing gradient problem”. Architectures such as Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU) were designed to mitigate this issue, but so far the transformer architecture has addressed it most effectively, hence becoming the de facto solution for generative models.
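The vanishing gradient problem can be seen in a toy scalar RNN, h_t = tanh(w·h_{t-1} + x_t), with a hypothetical weight `w` and constant inputs. By the chain rule, the gradient of the last state with respect to the first is a product of one factor w·(1 − h_t²) per time step; when those factors are below 1 in magnitude, the product shrinks exponentially with sequence length, so early tokens barely influence learning.

```python
import numpy as np

def gradient_through_time(w, xs):
    """d h_T / d h_0 for the scalar RNN h_t = tanh(w * h_{t-1} + x_t), starting from h_0 = 0.

    Each step contributes the local derivative w * (1 - h_t**2).
    """
    h, grad = 0.0, 1.0
    for x in xs:
        h = np.tanh(w * h + x)
        grad *= w * (1.0 - h**2)
    return grad

xs = np.full(50, 0.5)
for T in (5, 20, 50):
    print(T, gradient_through_time(0.9, xs[:T]))  # shrinks rapidly as T grows
```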

RWKV:

Let’s reinvent the RNN, you say? Enter RWKV, an open-source, non-profit project under the Linux Foundation. Claiming to combine the best of both RNNs and Transformers, RWKV builds on Apple’s Attention Free Transformer.

This means it scales better than transformers in both training and inference (linearly rather than quadratically with sequence length), i.e. lower cost and longer context. The Eagle 7B model, based on this architecture, reportedly beats all other 7B-parameter models on multilingual benchmarks and trades blows with the Llama family and Mistral at a fraction of the inference cost and latency! It is also claimed to be the most efficient per token.
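The linear scaling comes from RWKV’s “WKV” operator, which can be written either as an attention-like weighted average over all past tokens or as an RNN with a constant-size state. Below is a simplified single-channel sketch in the style of RWKV-4; the decay `w`, current-token bonus `u` and inputs are hypothetical, and real RWKV layers add receptance gating, token shift and numerically stabilized kernels.

```python
import numpy as np

def wkv_parallel(k, v, w, u):
    """Attention-like form: output t is a weighted average of v[0..t].

    Past token i gets weight exp(-(t-1-i)*w + k[i]) (decaying with distance, w > 0);
    the current token gets exp(u + k[t]). Written as a loop for clarity: O(T^2) overall.
    """
    T = len(k)
    out = np.empty(T)
    for t in range(T):
        i = np.arange(t)
        past_w = np.exp(-(t - 1 - i) * w + k[:t])
        cur_w = np.exp(u + k[t])
        out[t] = (past_w @ v[:t] + cur_w * v[t]) / (past_w.sum() + cur_w)
    return out

def wkv_recurrent(k, v, w, u):
    """RNN form of the same quantity: two running sums replace the full lookback. O(T)."""
    a, b = 0.0, 0.0  # decayed sums of exp(k)*v and exp(k) over past tokens
    out = np.empty(len(k))
    for t in range(len(k)):
        cur_w = np.exp(u + k[t])
        out[t] = (a + cur_w * v[t]) / (b + cur_w)
        a = np.exp(-w) * a + np.exp(k[t]) * v[t]
        b = np.exp(-w) * b + np.exp(k[t])
    return out

rng = np.random.default_rng(0)
k, v = rng.standard_normal(16), rng.standard_normal(16)
print(np.allclose(wkv_parallel(k, v, 0.5, 0.3), wkv_recurrent(k, v, 0.5, 0.3)))  # True
```

The two functions agree numerically, but the recurrent one only carries `a` and `b` between steps, so its per-token cost does not grow with context.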

This is not without its drawbacks, however. As per the RWKV docs, this architecture is “sensitive to prompt formatting” and “weaker at tasks requiring lookbacks”. This is a good place to plug Andrej Karpathy’s thread on why Transformers are awesome, for a like-for-like comparison.

Still, the fact that an RNN-based model is competitive on benchmarks makes RWKV one of the strongest candidates to dethrone the incumbent transformer architecture!

MonarchMixer:

MonarchMixer is another exciting prospect, given the powerhouse team behind it at Stanford’s Hazy Research. The architecture aims to overcome the limitations of quadratic scaling along two axes: “sequence length” (attention) and “model dimension” (MLP layers).

The solution? Use “Monarch matrices” as the “hardware efficient and expressive subquadratic-primitive” to replace attention and MLPs. More details here for the technically inclined. As the authors note in the paper, this is still very much a research area, and since the M2 layer is not optimized for inference, it is too early to compare it with Transformers or SSMs (more on those below). But it is definitely worth tracking as it evolves.
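A Monarch matrix, as defined in the Monarch paper that M2 builds on, is a product of block-diagonal matrices interleaved with fixed permutations, M = P L P^T R. The sketch below covers only the simplest square case (n = b*b, b blocks of size b x b) with random blocks standing in for learned weights. Each block-diagonal multiply costs b^3 = n^1.5 operations instead of n^2 for a dense matrix, and the reshape-transpose permutation lets information mix across blocks.

```python
import numpy as np

def block_diag_matmul(blocks, x):
    """Multiply a block-diagonal matrix (b blocks of shape b x b) by a length b*b vector."""
    b = blocks.shape[0]
    return np.einsum("ijk,ik->ij", blocks, x.reshape(b, b)).reshape(-1)

def transpose_perm(x):
    """Reshape-transpose permutation P. For a square b x b grid it is its own inverse (P^T = P)."""
    b = int(round(np.sqrt(x.size)))
    return x.reshape(b, b).T.reshape(-1)

def monarch_matvec(L_blocks, R_blocks, x):
    """y = P L P^T R x using O(n^1.5) work instead of O(n^2)."""
    y = block_diag_matmul(R_blocks, x)   # R: mix within blocks
    y = transpose_perm(y)                # P^T: regroup so each block sees every old block
    y = block_diag_matmul(L_blocks, y)   # L: mix within the new blocks
    return transpose_perm(y)             # P: restore the original ordering

rng = np.random.default_rng(0)
b = 4
L_blocks, R_blocks = rng.standard_normal((b, b, b)), rng.standard_normal((b, b, b))
x = rng.standard_normal(b * b)
print(monarch_matvec(L_blocks, R_blocks, x).shape)  # (16,)
```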

State Space Models (SSMs):

Defining how SSMs work is a post in itself, but at a high level they are dynamic systems based on the evolution of input, output and state variables over time. The state here is the representation of the information needed to describe the system at any point in time. This is relevant because past information can be summarized into a state that is used to predict the next state. SSMs can be described by the following two equations.
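In the continuous-time notation commonly used in the S4 and Mamba papers, with input $x(t)$, state $h(t)$ and output $y(t)$, and with $A$ and $B$ shaping the state while $C$ and $D$ shape the output, they can be written as:

```latex
\begin{aligned}
h'(t) &= A\,h(t) + B\,x(t) && \text{(state equation)} \\
y(t)  &= C\,h(t) + D\,x(t) && \text{(output equation)}
\end{aligned}
```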

The state equation describes a) how the current state evolves over time (forgets irrelevant stuff) and b) how the input influences the state (what to remember from the input).

The output equation describes a) how the current state translates into the output (i.e. the prediction) and b) how the input directly influences the output (how to use the input to make the prediction).

The big issue with SSMs? They treat all inputs equally. This makes them very efficient (small state), but not very effective (they forget a lot, like RNNs). Transformers are the opposite, with their full lookback thanks to attention.
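To run on token sequences, the continuous equations are discretized. The sketch below assumes a diagonal state matrix and zero-order-hold discretization (a common choice, used for example in Mamba), with hypothetical parameter values. Because the same `a_bar`, `b_bar` and `c` are applied at every step, the model really does treat every input the same way, and that time invariance is also what lets the recurrence be rewritten as one long convolution with kernel K_k = C A^k B.

```python
import numpy as np

def discretize_zoh(a, b, delta):
    """Zero-order hold for a diagonal SSM: a is (N,) with negative entries for stability."""
    a_bar = np.exp(delta * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def ssm_recurrent(x, a_bar, b_bar, c, d):
    """Step-by-step (RNN-like) evaluation with a fixed-size state h."""
    h = np.zeros_like(a_bar)
    ys = []
    for x_t in x:                    # identical parameters at every step
        h = a_bar * h + b_bar * x_t  # state equation
        ys.append(c @ h + d * x_t)   # output equation
    return np.array(ys)

def ssm_convolution(x, a_bar, b_bar, c, d):
    """Same outputs as one causal convolution with kernel K_k = c . (a_bar**k * b_bar)."""
    T = len(x)
    kernel = np.array([c @ (a_bar**k * b_bar) for k in range(T)])
    y = np.array([kernel[: t + 1][::-1] @ x[: t + 1] for t in range(T)])
    return y + d * x

rng = np.random.default_rng(0)
a = -np.linspace(0.5, 2.0, 4)
a_bar, b_bar = discretize_zoh(a, np.ones(4), delta=0.1)
c, x = rng.standard_normal(4), rng.standard_normal(32)
print(np.allclose(ssm_recurrent(x, a_bar, b_bar, c, 0.5), ssm_convolution(x, a_bar, b_bar, c, 0.5)))  # True
```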

Mamba:

This is what makes Mamba interesting: it claims to be a good compromise between the two by being a “Selective” State Space Model.

Compare the above to the transformer architecture. What do you notice? Attention (“Multi-Head Attention”) has been replaced with Selection (SSM). So instead of looking back over the entire history at inference time, the model decides what to retain in memory when it builds the “compressed state”, which reportedly makes it both more powerful and more efficient than RNNs and Transformers. Also noteworthy are the tricks used to make it hardware-efficient on GPUs (ref. Tri Dao’s FlashAttention paper).
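The difference from a time-invariant SSM can be sketched as a simplified selective scan in the spirit of Mamba’s S6 layer: the step size `delta` and the matrices `B` and `C` are computed from each token, so the model can decide per token how much to write into or keep in its state. This is a sketch under our own assumptions (a single layer with no convolution, gating or hardware-aware parallel scan, and hypothetical weight names); it uses zero-order hold for `A` and a simplified Euler step for `B`, as the Mamba paper does.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)  # log(1 + exp(z)) without overflow

def selective_scan(x, A, W_delta, b_delta, W_B, W_C, D):
    """Simplified Mamba-style selective scan.

    x: (T, Dm) inputs; A: (Dm, N) negative entries (diagonal state per channel);
    W_delta: (Dm, Dm), b_delta: (Dm,); W_B, W_C: (Dm, N); D: (Dm,) skip weights.
    """
    T, Dm = x.shape
    h = np.zeros((Dm, A.shape[1]))
    ys = np.empty_like(x)
    for t in range(T):
        delta = softplus(x[t] @ W_delta + b_delta)  # (Dm,) per-token step size
        B_t, C_t = x[t] @ W_B, x[t] @ W_C           # (N,) input-dependent, unlike an LTI SSM
        A_bar = np.exp(delta[:, None] * A)          # zero-order hold for A
        B_bar = delta[:, None] * B_t[None, :]       # simplified (Euler) step for B
        h = A_bar * h + B_bar * x[t][:, None]       # selective state update
        ys[t] = h @ C_t + D * x[t]                  # input-dependent readout
    return ys

rng = np.random.default_rng(0)
T, Dm, N = 20, 3, 4
params = dict(A=-np.exp(rng.standard_normal((Dm, N))),
              W_delta=0.1 * rng.standard_normal((Dm, Dm)), b_delta=np.zeros(Dm),
              W_B=rng.standard_normal((Dm, N)), W_C=rng.standard_normal((Dm, N)), D=np.ones(Dm))
print(selective_scan(rng.standard_normal((T, Dm)), **params).shape)  # (20, 3)
```

Because `delta` now changes per token, the convolution shortcut of the time-invariant case no longer applies, which is why Mamba relies on a parallel scan instead. A large `delta` pushes `A_bar` towards 0 (reset the state and focus on the current token), while a small one keeps `A_bar` near 1 (carry the state forward and largely ignore the token).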

Because this built-in context acts as long-term memory, the state can be referenced at very low cost and latency. This makes it useful for scaling both context and storage, hence the excitement. The jury is still out, but we already have Jamba, which takes this a step further by combining Mamba layers with Transformer layers and mixture-of-experts (MoE). More on that in a future post.
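A back-of-the-envelope comparison illustrates the “low cost” point: a Transformer’s KV cache grows with every token, while an SSM’s recurrent state does not. The shapes below are assumptions chosen to be roughly Llama-2-7B-like (32 layers, 32 KV heads of size 128) and Mamba-2.8B-like (64 layers, inner width 5120, state size 16), stored in 16-bit precision; the small convolution state Mamba also keeps is ignored.

```python
def kv_cache_bytes(tokens, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """Keys + values for every past token in every layer: linear in `tokens`."""
    return 2 * n_layers * n_kv_heads * head_dim * tokens * bytes_per_value

def ssm_state_bytes(n_layers, d_inner, d_state, bytes_per_value=2):
    """One fixed-size recurrent state per layer: independent of sequence length."""
    return n_layers * d_inner * d_state * bytes_per_value

MiB = 2**20
for tokens in (4_096, 131_072):
    print(tokens, kv_cache_bytes(tokens, 32, 32, 128) / MiB)  # 2048.0, then 65536.0
print(ssm_state_bytes(64, 5120, 16) / MiB)  # 10.0 at any sequence length
```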

Striped Hyena:

Quick mention here of Striped Hyena from Together AI, an OSS project based on the SSM architecture. As per the authors, “it is a hybrid architecture composed of rotary (grouped) attention and gated convolutions arranged in Hyena blocks.” The result? It is reportedly >30%, >50% and >100% faster in training for 32k, 64k and 128k sequence lengths respectively, making it well suited to fine-tuning for long-context use cases.
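The “gated convolutions” in Hyena blocks are long convolutions, which can be computed with FFTs in O(L log L) rather than the O(L²) of attention. Below is a toy sketch; the scalar projections `w_v` and `w_gate` and the explicit filter `h` are simplifications of ours (in Hyena the filter is generated implicitly by a small network and the projections are learned), and StripedHyena additionally interleaves such blocks with attention.

```python
import numpy as np

def causal_fft_conv(u, h):
    """Causal long convolution y[t] = sum_{j<=t} h[t-j] * u[j], computed in O(L log L)."""
    L = len(u)
    n = 2 * L  # zero-padding makes the FFT's circular convolution equal a linear one
    y = np.fft.irfft(np.fft.rfft(u, n) * np.fft.rfft(h, n), n)
    return y[:L]

def gated_long_conv(x, w_v, w_gate, h):
    """Toy Hyena-style operator on a 1-D signal: y = gate(x) * (h conv v(x))."""
    v = w_v * x
    gate = w_gate * x
    return gate * causal_fft_conv(v, h)

rng = np.random.default_rng(0)
u, h = rng.standard_normal(64), rng.standard_normal(64)
print(np.allclose(causal_fft_conv(u, h), np.convolve(u, h)[:64]))  # True
```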

With Striped Hyena being OSS and focused on real-world use cases, launching shortly after Mamba, and claiming to beat both Llama-2 and Yi-7B, it immediately turned heads. The team did share a promising roadmap, so hopefully we will get more details on training data and further validation of the promise that these hybrid architectures may hold.

Diffusion Models:

I would be remiss not to mention diffusion models as an exciting part of the alt-architecture landscape, especially in the realm of text-to-image (for now). This free short course by DeepLearning.AI is a good starting point for the uninitiated to get familiar with diffusion models.

There are a few variations in the backbone architecture of diffusion models, such as CNN-based, transformer-based and state-space-model-based designs. There are also variations in how these architectures scale. Of these, latent diffusion has been the dominant way to scale to high-resolution image synthesis (Stable Diffusion, for example, is based on this architecture).

The breakthrough of Latent Diffusion Models (LDMs) is the latent space: a computationally efficient, low-dimensional space in which high-frequency, imperceptible details are abstracted away. LDMs are predominantly built on a convolutional backbone (a U-Net CNN) and were later improved upon by the introduction of Diffusion Transformers (DiT).
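Some quick arithmetic shows why working in latent space helps. The numbers assume the Stable Diffusion v1 autoencoder setting (8x spatial downsampling into 4 latent channels) and, for the token count, a DiT-style patch size of 2; both are illustrative choices rather than the only options.

```python
def num_elements(height: int, width: int, channels: int) -> int:
    return height * width * channels

# A 512x512 RGB image vs. its latent (8x downsampled per side, 4 channels).
pixels = num_elements(512, 512, 3)
latent = num_elements(512 // 8, 512 // 8, 4)
print(pixels, latent, pixels // latent)  # 786432 16384 48

# Tokens for a transformer using 2x2 patches, and how many more attention pairs pixel space needs.
tokens_pixel = (512 // 2) ** 2
tokens_latent = (512 // 8 // 2) ** 2
ratio = tokens_pixel // tokens_latent
print(tokens_pixel, tokens_latent, ratio, ratio**2)  # 65536 1024 64 4096
```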

Hourglass Diffusion:

Enter Hourglass Diffusion! SOTA is a shifting goalpost in AI, and the Hourglass Diffusion Transformer (HDiT) architecture aims to shift it once more for diffusion models. LDMs were shown to struggle with representing finer detail, which led the authors of this paper to introduce a pure transformer architecture for high-resolution pixel-space image generation whose compute cost scales subquadratically with resolution. Goodbye CNN/LDM!

But aren’t Transformers computationally inefficient? Not when you do this:

Instead of treating images the same regardless of resolution, this architecture adapts to the target resolution: it processes local phenomena locally at the high-resolution levels and handles global phenomena separately in the low-resolution levels of the hierarchy.
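A toy cost model shows why this hierarchy helps. The assumptions are our own: local attention over a 7x7 window at the high-resolution levels, 2x downsampling per level, and full global attention only once the token grid is 16x16 or smaller. Costs are counted as token pairs scored, not real FLOPs.

```python
def global_attention_cost(side_tokens: int) -> int:
    """Pairwise scores for full attention over a side x side token grid."""
    n = side_tokens * side_tokens
    return n * n

def hourglass_attention_cost(side_tokens: int, window: int = 7, global_side: int = 16) -> int:
    """Local window attention at high-res levels, global attention at the coarsest level."""
    cost, side = 0, side_tokens
    while side > global_side:
        cost += side * side * window * window  # each token scores a window x window neighbourhood
        side //= 2                             # next level of the hourglass: half the side
    return cost + global_attention_cost(side)

for side in (64, 128):
    print(side, hourglass_attention_cost(side), global_attention_cost(side))
# 64 316416 16777216
# 128 1119232 268435456
```

Doubling the grid side (4x the pixels) multiplies the toy hourglass cost by about 3.5 but full attention by 16, i.e. close to linear versus quadratic in pixel count.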

The paper claims to bridge the gap between the scaling properties of transformers and the efficiency of U-Nets, and given that it comes from Katherine Crowson (ref. Stable Diffusion), it is not to be taken lightly.

Ring Attention (Gemini 1.5?):

And finally, while we are discussing scaling long context, how can we not mention the elephant in the room: Gemini 1.5 Pro, tested with context windows of up to 10M tokens, with a claimed (and, to an extent, community-verified) near-perfect multimodal needle-in-a-haystack recall up to 1M tokens.

How is that even possible? Is RAG dead? Gemini’s release has once again shifted the goalposts, with people now wondering what use cases an infinite context could unlock. More importantly for the subject of this post: apart from being a mixture-of-experts (MoE) transformer-based model, what fundamental architectural changes has Google DeepMind employed to reach this scale without severely compromising on compute or efficiency?

One plausible theory is based on the Ring Attention paper, which uses blockwise computation of self-attention and feedforward layers to distribute long sequences across multiple devices, while fully overlapping the communication of key-value blocks with the computation of blockwise attention.

Confused? The following slide from Andreas Kopf’s presentation explains it better:

Ring Attention matters because it allows context length to scale linearly with the number of devices while maintaining strong recall, which makes a near-infinite context a real possibility. The caveat? It is for the GPU-rich, which might be our constraint, not Google’s.
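The core blockwise idea can be simulated on one machine. In this sketch (non-causal, single head, with the number of “devices” as a hypothetical parameter), each device keeps its query block and, over `n_devices` steps, sees every key/value block as the blocks travel around the ring. The softmax is accumulated online with a running max and normalizer, so no device ever materializes the full `(n, n)` score matrix, yet the result is exact attention. The real method also overlaps the block transfers with compute, which a sequential loop cannot show.

```python
import numpy as np

def full_attention(q, k, v):
    """Reference: ordinary softmax attention over the whole sequence."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def ring_attention(q, k, v, n_devices):
    """Simulated Ring Attention: per-device memory scales with block size, not sequence length."""
    q_blocks = np.array_split(q, n_devices)
    kv_blocks = list(zip(np.array_split(k, n_devices), np.array_split(v, n_devices)))
    outputs = []
    for dev, q_blk in enumerate(q_blocks):
        m = np.full(q_blk.shape[0], -np.inf)      # running max of scores per query
        l = np.zeros(q_blk.shape[0])              # running softmax normalizer
        acc = np.zeros((q_blk.shape[0], v.shape[1]))
        for step in range(n_devices):
            k_blk, v_blk = kv_blocks[(dev + step) % n_devices]  # block arriving at this step
            s = q_blk @ k_blk.T / np.sqrt(q.shape[-1])
            m_new = np.maximum(m, s.max(axis=-1))
            scale = np.exp(m - m_new)             # rescale partial sums to the new max
            p = np.exp(s - m_new[:, None])
            l = l * scale + p.sum(axis=-1)
            acc = acc * scale[:, None] + p @ v_blk
            m = m_new
        outputs.append(acc / l[:, None])
    return np.concatenate(outputs)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((24, 8)) for _ in range(3))
print(np.allclose(ring_attention(q, k, v, n_devices=4), full_attention(q, k, v)))  # True
```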

Only time will tell whether Transformers continue to dominate, but there is value in understanding alt architectures as we move from AI demos to AI in production. Memory efficiency will be crucial as we scale context, and while there are leaders, there is no clear winner…yet! Fingers crossed for multiple winners across different architectures.

Until next time, happy reading!
