Learning to (Learn at Test Time): RNNs with Expressive Hidden States

id:

2407.04620

Authors:

Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin

Published:

2024-07-05

arXiv:

https://arxiv.org/abs/2407.04620

PDF:

https://arxiv.org/pdf/2407.04620

DOI:

N/A

Journal Reference:

N/A

Primary Category:

cs.LG

Categories:

cs.LG, cs.AI, cs.CL

Comment:

N/A

github_url:

_

abstract

Self-attention performs well in long context but has quadratic complexity. Existing RNN layers have linear complexity, but their performance in long context is limited by the expressive power of their hidden state. We propose a new class of sequence modeling layers with linear complexity and an expressive hidden state. The key idea is to make the hidden state a machine learning model itself, and the update rule a step of self-supervised learning. Since the hidden state is updated by training even on test sequences, our layers are called Test-Time Training (TTT) layers. We consider two instantiations: TTT-Linear and TTT-MLP, whose hidden state is a linear model and a two-layer MLP respectively. We evaluate our instantiations at the scale of 125M to 1.3B parameters, comparing with a strong Transformer and Mamba, a modern RNN. Both TTT-Linear and TTT-MLP match or exceed the baselines. Similar to Transformer, they can keep reducing perplexity by conditioning on more tokens, while Mamba cannot after 16k context. With preliminary systems optimization, TTT-Linear is already faster than Transformer at 8k context and matches Mamba in wall-clock time. TTT-MLP still faces challenges in memory I/O, but shows larger potential in long context, pointing to a promising direction for future research.

premise

outline

quotes

notes

summary

1. Brief Overview

This paper introduces a new class of sequence modeling layers called Test-Time Training (TTT) layers. Unlike traditional RNNs, which are limited by the expressiveness of their fixed-size hidden state, TTT layers make the hidden state a machine learning model itself, updated by a step of self-supervised learning even at test time. The authors propose two instantiations, TTT-Linear and TTT-MLP, which have linear complexity and match or exceed strong Transformer and modern-RNN (Mamba) baselines in long-context scenarios. Two key optimizations, mini-batch TTT and a dual form for the inner-loop computation, substantially improve wall-clock time.
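As a concrete picture of the update rule, here is a minimal NumPy sketch of a TTT-Linear forward pass. The projection matrices, shapes, and learning rate are illustrative assumptions, not the paper's code: the hidden state is a weight matrix W, each incoming token triggers one gradient step on a self-supervised reconstruction loss, and the layer's output applies the updated state to a test view of that token.

```python
import numpy as np

def ttt_linear_forward(xs, theta_K, theta_V, theta_Q, eta=0.1):
    """Minimal sketch of a TTT-Linear layer (illustrative names/shapes).

    The hidden state is a linear model W.  Each token triggers one step
    of online gradient descent on a self-supervised reconstruction
    loss; the output applies the updated W to a test view of the token.
    """
    W = np.zeros((theta_V.shape[0], theta_K.shape[0]))  # hidden state
    outputs = []
    for x in xs:                        # xs: sequence of token vectors
        train_view = theta_K @ x        # training view of the token
        label_view = theta_V @ x        # reconstruction target
        # loss l(W) = ||W @ train_view - label_view||^2; one GD step:
        err = W @ train_view - label_view
        W = W - eta * 2.0 * np.outer(err, train_view)
        outputs.append(W @ (theta_Q @ x))  # read out via the test view
    return np.stack(outputs)
```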

2. Key Points

  • TTT layers address the long-context limitations of RNNs by making the hidden state itself a machine learning model that is trained as the sequence is processed.

  • TTT-Linear and TTT-MLP are two instantiations of TTT layers, whose hidden states are a linear model and a two-layer MLP, respectively.

  • Experiments show that TTT layers match or exceed the performance of Transformers and Mamba (a state-of-the-art RNN) across various model sizes and context lengths.

  • Mini-batch TTT and the dual form significantly improve the hardware efficiency of TTT layers (see the sketch after this list).

  • TTT-Linear achieves comparable speed to Mamba at 8k context and is faster than Transformer at the same context length.

  • The paper explores several self-supervised tasks for TTT layers.
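To make the mini-batch and dual-form point concrete, here is a hedged sketch of mini-batch TTT for the linear case, again with our own names and shapes. Inside a mini-batch of b tokens, every gradient is taken at the batch-start state W0, so the b gradients reduce to batched matrix operations instead of a sequential scan; this naive version still materializes the per-token states, which the paper's dual form additionally avoids.

```python
import numpy as np

def ttt_linear_minibatch(W0, train_views, label_views, test_views, eta=0.1):
    """Sketch of mini-batch TTT for a linear hidden state (our notation).

    All b gradients are taken at the batch-start state W0, so they can
    be computed with batched matmuls rather than one token at a time.
    """
    # W0: (d, k); train_views, test_views: (b, k); label_views: (b, d)
    errs = train_views @ W0.T - label_views                   # all at W0
    grads = 2.0 * errs[:, :, None] * train_views[:, None, :]  # (b, d, k)
    W_t = W0[None] - eta * np.cumsum(grads, axis=0)           # per-token states
    outputs = np.einsum('bdk,bk->bd', W_t, test_views)        # z_t = W_t q_t
    return outputs, W_t[-1]        # outputs and batch-end state
```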

3. Notable Quotes

No specific quotes were identified as particularly notable for future reference.

4. Primary Themes

  • Improving RNN performance in long-context tasks: This is the central problem addressed by the paper. The limitations of traditional RNNs in handling long sequences are highlighted, and the TTT approach is presented as a solution.

  • Test-time training (TTT): This novel technique is the core contribution of the paper. The hidden-state model keeps training during inference, which improves performance on long sequences.

  • Hardware efficiency: The paper emphasizes wall-clock time, not just theoretical complexity, and presents two key optimizations for TTT layers: mini-batch TTT and the dual form.

  • Self-supervised learning: The choice of self-supervised objective used to update the hidden state is crucial. The paper explores several options and shows that the task itself can be optimized end-to-end with the rest of the network (sketched below).
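As one sketch of what "optimized end-to-end" can mean here, the PyTorch snippet below (hypothetical names, not the paper's code) makes the view projections and the inner-loop step size outer-loop parameters: they are trained with the rest of the network on the usual next-token objective, while only the hidden state W is updated by the inner loop, including at test time. Because the inner gradient step is differentiable, the outer loss can backpropagate through it.

```python
import torch

class TTTLinear(torch.nn.Module):
    """Sketch: the self-supervised task itself is learnable.

    theta_K/V/Q and the step size are outer-loop parameters; only the
    hidden state W is updated by the inner loop (even on test data).
    """
    def __init__(self, d_in, d_hidden):
        super().__init__()
        self.theta_K = torch.nn.Linear(d_in, d_hidden, bias=False)
        self.theta_V = torch.nn.Linear(d_in, d_hidden, bias=False)
        self.theta_Q = torch.nn.Linear(d_in, d_hidden, bias=False)
        self.log_eta = torch.nn.Parameter(torch.tensor(-2.0))

    def forward(self, xs):                  # xs: (T, d_in)
        W = xs.new_zeros(self.theta_V.out_features,
                         self.theta_K.out_features)
        eta = self.log_eta.exp()
        outs = []
        for x in xs:                        # inner loop over tokens
            k, v = self.theta_K(x), self.theta_V(x)
            err = W @ k - v                 # reconstruction error
            W = W - eta * 2.0 * torch.outer(err, k)  # differentiable step
            outs.append(W @ self.theta_Q(x))
        return torch.stack(outs)
```

Training this module inside a language model lets the outer loop pick the self-supervised task that most helps next-token prediction.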