Most AI-related papers these days oversell their findings with catchy titles and dubious connections to human brains, when the explanation for their results is actually just simple statistics. It is refreshing to see that the bug has not infected everyone. Here is a gem I came across this week, titled “Vision Transformers Need Registers”. Meta’s AI division just can’t stop delivering hits and seems like the absolute best place to do AI research right now. This paper is just one banger plot after another, so I have added many of them here directly. Let me lay out why this result is actually even more important than the authors make it out to be.
Interpretable Attention!
The casual observer might be excused for thinking that attention and transformers can be used interchangeably, but attention has actually been around longer. Before 2018, attention was often used with CNNs/RNNs, but it rarely improved model performance. In fact it would sometimes even decrease accuracy, and it was mostly used as a way to add interpretability to supervised models. Given an output (e.g. a word of a text translation or the label in image classification), the attention map was a way to visualize which parts of the input were used by the model to produce the output. In other words, the attention map learns a conditional relationship between each part of the input and each part of the output.
The transformers paper in 2017 made attention the core of the architecture rather than an add-on, and it quickly became the workhorse of unsupervised learning. By 2019, everyone was doing “self-attention”, which maps the input to itself, i.e. learns a conditional relationship between each part of the input and every other part of the input. Transformers were really powerful, blew the tops off all the SOTAs, GPTs and ViTs are the best models we have, great!
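If “attention map” sounds abstract, here is a minimal toy sketch (my own code, not from any paper) of how a single self-attention head turns token features into the weight matrix that gets visualized as those maps:

```python
import torch
import torch.nn.functional as F

def attention_map(x, w_q, w_k):
    """Single-head self-attention weights over a sequence of token features.

    x   : (num_tokens, dim) e.g. ViT patch embeddings
    w_q : (dim, dim) query projection
    w_k : (dim, dim) key projection
    Returns (num_tokens, num_tokens): row i says how much token i attends to
    every other token; reshaped onto the image grid, that row is the "map".
    """
    q = x @ w_q
    k = x @ w_k
    scores = q @ k.T / (k.shape[-1] ** 0.5)   # scaled dot-product
    return F.softmax(scores, dim=-1)          # each row sums to 1

# toy usage: 196 patch tokens (a 14x14 grid) with 64-dim features
x = torch.randn(196, 64)
w_q, w_k = torch.randn(64, 64), torch.randn(64, 64)
attn = attention_map(x, w_q, w_k)             # (196, 196)
one_token_map = attn[0].reshape(14, 14)       # one token's attention over the grid
```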
Losing its meaning!
But here is the downside of transformers: the attention maps make less sense. Notice how in older classification models (below), attention maps look smooth and make intuitive human sense. The model is paying attention to where the information is. You can see some pixels lighting up outside the obvious region, but they look like the tail of a distribution falling off, so maybe that’s expected.
As usual, when you do unsupervised learning, the picture [pun intended] becomes murkier. As transformers have gotten bigger and pre-trained on larger datasets, the attention maps get weird (image below) and start showing a few very loud pixels in places where nothing interesting is going on. Why is the model paying attention to these pixels? There are other problems too, not mentioned in the paper, that I know from personal experience training these models. The features learnt by CNNs during unsupervised learning often have some structure; for example, they produce nice interpretable clusters in t-SNE plots (see the sketch below). But transformers make much noisier clusters despite having much better performance on all metrics. So the features hold more information but are somehow less interpretable. This has led some people to claim that perhaps this tradeoff is insurmountable: maybe more complex and powerful models should be harder to understand?
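This kind of sanity check is easy to reproduce yourself. Here is a minimal sketch, assuming you already have a matrix of pooled image features and their class labels; the arrays below are random placeholders:

```python
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Stand-ins for real data: `features` would be (N, D) pooled image embeddings
# from your CNN or ViT, `labels` the (N,) class ids you color the plot by.
features = np.random.randn(500, 384)
labels = np.random.randint(0, 10, size=500)

# Project to 2D; perplexity is the main knob worth sweeping.
emb_2d = TSNE(n_components=2, perplexity=30, init="pca",
              random_state=0).fit_transform(features)

plt.scatter(emb_2d[:, 0], emb_2d[:, 1], c=labels, s=4, cmap="tab10")
plt.title("t-SNE of image features")
plt.show()
```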
Under the rug
But as is often the case in research, digging into inconsequential discrepancies can lead to transformational discoveries. The FAIR team dug, and found that:
About 2% of the features are orders of magnitude larger than the rest. These artifacts are super outliers, and the sketch after this list shows one quick way to flag them.
They only appear deep inside large models after training for a while. Small networks, early layers and early training stages have no outliers.
The outliers occur on redundant pixels, i.e. pixels which have very little information to add beyond their neighbors.
The outlier features do not have local spatial information, but hold global information. Outlier features are worse than normal features (table above) at predicting the position/intensity of the pixels they map to but are far better at image classification (table below). So these pixels have info about the whole image, but not about their location in the image.
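As promised above, here is a rough sketch of how you could flag these high-norm tokens in your own features. The 10x-the-median threshold is mine, picked purely for illustration, not the paper's criterion:

```python
import torch

def find_outlier_tokens(patch_tokens, ratio=10.0):
    """Flag patch tokens whose feature norm dwarfs the rest.

    patch_tokens : (num_patches, dim) features from one layer of a ViT.
    ratio        : how many times the median norm counts as an outlier
                   (an illustrative threshold, not the paper's).
    Returns a boolean mask over patches.
    """
    norms = patch_tokens.norm(dim=-1)          # per-token L2 norm
    return norms > ratio * norms.median()

# toy usage: 196 patch tokens, three of them artificially blown up
tokens = torch.randn(196, 768)
tokens[[3, 77, 150]] *= 50                     # simulate high-norm artifacts
mask = find_outlier_tokens(tokens)
print(mask.sum().item(), "outlier tokens out of", mask.numel())
```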
Intuition: Memory and Symbols
The Hypothesis
The model learns high-level information about the image during training, but has no place to put this global information since every feature corresponds to a location in the image. So it ‘forgets’ the local information in the most redundant pixels and stores global information in those features, which then show up as outliers.
The Test
When a few extra blank features (or registers) are added to the model, features that do not correspond to any pixel in the input or to any performance metric in the output, the outliers completely disappear from the features corresponding to the pixels. The model instead spontaneously learns to put that global information in the register features. These memory registers show the same behavior as the outliers: not correlated with any individual pixel, but carrying global information about the image.
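To make the trick concrete, here is a minimal PyTorch sketch of what adding registers looks like. The class and names are my own illustration, not the paper's code; `encoder` stands in for any stack of transformer blocks, and the incoming patch embeddings are assumed to already include positional encodings:

```python
import torch
import torch.nn as nn

class ViTWithRegisters(nn.Module):
    """A few learnable register tokens are appended to the patch sequence,
    flow through every transformer block like normal tokens, and are simply
    thrown away at the end."""

    def __init__(self, encoder, dim=768, num_registers=4):
        super().__init__()
        self.encoder = encoder
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.registers = nn.Parameter(torch.zeros(1, num_registers, dim))
        self.num_registers = num_registers

    def forward(self, patch_tokens):                     # (B, N, dim) patch embeddings
        b = patch_tokens.shape[0]
        cls = self.cls_token.expand(b, -1, -1)
        reg = self.registers.expand(b, -1, -1)
        x = torch.cat([cls, reg, patch_tokens], dim=1)   # [CLS, reg_0..reg_k, patches]
        x = self.encoder(x)
        cls_out = x[:, 0]                                # used by the loss / classifier
        patch_out = x[:, 1 + self.num_registers:]        # dense features, artifact-free
        return cls_out, patch_out                        # register outputs are discarded

# toy usage with a tiny stand-in encoder
layer = nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
model = ViTWithRegisters(nn.TransformerEncoder(layer, num_layers=2))
cls_out, patch_out = model(torch.randn(2, 196, 768))
```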
The Fallout
The authors find that adding just 1 register leads to a (small) increase in accuracy on most tasks, and the effect saturates as you add more registers.
Corollary
The authors are modest in the paper: they acknowledge that adding registers is not a new idea, but theirs is the most principled and defensible study of it. They refrain from making any hype claims in the paper, so let me make some for them.
In the short term, image models will get better. Not just the cool consumer models like Dall-E/MJ; I am talking about models doing hard, unsexy work like medical image segmentation, quality detection in factories, satellite imaging, and physics and climate modelling, where these artifacts cause a substantial loss in performance.
In the long term, this is fantastic progress on interpretability. Look at this mind-blowing image below. Not all of the 12 registers could be this easily understood, but 4 of them map back to the input in an intuitive way. The model is being trained to segment the full outline of the image on the left. The CLS token is being trained by the loss function, so it makes sense that it attends to all the foreground objects in the image. But as it turns out, the registers also map back to the individual objects separately: the spoon, even going so far as to separate the cup [reg0] from the tea [reg12]. They are not being trained to do so; the model just learns that these pixels map to an important concept that is distinct from the other pixels and decides to write it down in the blank space we gave it. Incredible! Also, what information do the other registers hold?
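If you want to make plots like this yourself, the recipe is straightforward once you have the attention tensor from the last block. A sketch, assuming the token layout is [CLS, registers, patches] (that layout is my assumption here, so check your model's ordering):

```python
import torch

def register_attention_map(attn, reg_idx, num_registers, grid=(14, 14)):
    """Turn one register's attention over the patches into an image-shaped map.

    attn          : (heads, tokens, tokens) attention weights from the last block,
                    with tokens ordered [CLS, reg_0..reg_{k-1}, patch_0..patch_{N-1}].
    reg_idx       : which register to visualize.
    num_registers : k.
    Returns a (grid_h, grid_w) map you can upsample and overlay on the image.
    """
    query_pos = 1 + reg_idx                       # skip the CLS token
    patch_start = 1 + num_registers
    per_head = attn[:, query_pos, patch_start:]   # (heads, N) register-to-patch weights
    return per_head.mean(dim=0).reshape(grid)     # average heads, fold onto the grid

# toy usage: 8 heads, 1 CLS + 4 registers + 196 patches = 201 tokens
attn = torch.softmax(torch.randn(8, 201, 201), dim=-1)
heatmap = register_attention_map(attn, reg_idx=0, num_registers=4)   # (14, 14)
```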
Another really significant realization is that this behavior does not happen for smaller or undertrained models. Showing that loss decreases for larger models (I talk about scaling laws in more detail in Is My LLM Too Large) is one thing, but this is intuitive, interpretable evidence for the ‘scale is all you need’ hype. Large models don’t just do more of what the smaller models do; there are non-trivial capabilities that large models pick up that smaller models just don’t have. And it seems these capabilities are not as inscrutable as previously thought; maybe we can understand what these giant matrices are thinking after all.
And there you have it: my thoughts on the latest work from MetaAI on how vision transformers can use blank registers to store high-level semantic concepts, improving both accuracy and interpretability at the same time. For more intuitions on AI/ML, subscribe below and follow me on Twitter. You can also check out my other blog and projects on nirsd.com.