Beyond Causal Language Modeling


Jan 27, 2025 - 20:38

A deep dive into “Not All Tokens Are What You Need for Pretraining”

Introduction

A few days ago, I had the chance to present at a local reading group that focused on some of the most exciting and insightful papers from NeurIPS 2024. As a presenter, I selected a paper titled “Not All Tokens Are What You Need for Pretraining”. It addresses a super simple but reasonable question: do we really need to apply the next-token prediction loss to all tokens during language model pretraining?

Most of us have gotten used to the standard approach: feed in a huge web-scraped corpus, apply causal language modeling (CLM) to every token, and trust that bigger is better. The authors of this paper push back on that assumption, arguing that some tokens might actually be detrimental to the learning process. Their analysis suggests that focusing your training budget on particularly “useful” tokens can yield significant gains in both data efficiency and downstream task performance.

In this post, I’ll summarize the core ideas behind the paper’s proposed method, Selective Language Modeling (SLM), and share the experimental insights that impressed me the most.

Background

Token-Level Noise in Large Web Corpora

When you scrape huge amounts of text from the web, it’s no surprise to find a fair amount of noise. Researchers have tried refining their corpora with document-level filters, which remove entire documents that look suspicious or low quality. However, the authors note that noise can exist inside documents as well: a good article might still contain a handful of nonsense or extremely unpredictable tokens. If your model is forced to learn from everything, these noisy tokens can waste computation or even confuse the model.

Token-Level Learning Dynamics

By examining training checkpoints at multiple stages, the authors categorized tokens based on whether their cross-entropy loss was high or low over time:

  • L→L (low to low): These tokens are learned early and remain “easy” for the model, giving no further significant gradient updates later on.
  • H→L (high to low): A subset that starts off hard and eventually gets learned. This means that those tokens still have significant room for improvement in learning.
  • H→H (high to high): Tokens that remain difficult and fluctuate heavily, often due to their inherent unpredictability (i.e., aleatoric uncertainty).
  • L→H (low to high): Tokens that were initially learned but become confusing later, possibly due to context shifts or noise.

The big takeaway is that only a fraction of tokens truly contributes a meaningful learning signal. Many tokens get mastered early (L→L) and then stop being beneficial. Meanwhile, tokens that stay consistently hard (H→H) can be extremely noisy and unproductive throughout the entire training run.
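The four categories above can be sketched as a small labeling function. This is only an illustration: the function names, the fixed `threshold`, and the rule of comparing just the first and last checkpoints are my own assumptions, not the exact procedure used in the paper.

```python
def categorize_token(losses, threshold=1.0):
    """Label one token's loss trajectory as L→L, H→L, H→H, or L→H.

    `losses` holds the token's cross-entropy loss at successive checkpoints.
    Only the first and last checkpoints are compared against `threshold`;
    both the threshold value and this two-point rule are illustrative
    assumptions, not the paper's criteria.
    """
    if len(losses) < 2:
        raise ValueError("need losses from at least two checkpoints")
    start = "H" if losses[0] > threshold else "L"
    end = "H" if losses[-1] > threshold else "L"
    return f"{start}→{end}"


def category_counts(trajectories, threshold=1.0):
    """Count how many token trajectories fall into each category."""
    counts = {"L→L": 0, "H→L": 0, "H→H": 0, "L→H": 0}
    for losses in trajectories:
        counts[categorize_token(losses, threshold)] += 1
    return counts
```

For example, a token whose loss goes from 3.0 to 0.4 across checkpoints would be labeled H→L, which is the kind of token that still carries a useful learning signal.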

The Proposed Approach: Selective Language Modeling (SLM)

Figure 1. Created by the author based on the figure presented in the original paper, with additional explanations and interpretations

The authors propose Selective Language Modeling (SLM) as a more nuanced approach. Here’s how it works:

Step 1: Train a Reference Model (RM)

They first take a small but high-quality dataset that reflects their “ideal” data distribution. Starting from an already partially pre-trained base model, they fine-tune it on this dataset to create an RM. This model essentially becomes the judge that decides which tokens are worth training on later.

Step 2: Score Tokens with Excess Loss

For each token in the large-scale corpus, they compute the RM loss (how well the RM predicts that token) and compare it with the current training model’s loss. The difference (the training model’s loss minus the RM loss), called the excess loss, indicates how much improvement is still possible on a token that “should” be predictable under the reference distribution.
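In symbols, the excess loss from Step 2 can be written as below. The notation is my own shorthand (θ for the training model, RM for the reference model, and x_i for the i-th token given its preceding context), and it assumes the difference is taken as training loss minus reference loss, so treat it as a sketch rather than a quotation from the paper.

```latex
\mathcal{L}_{\theta}(x_i) = -\log P_{\theta}(x_i \mid x_{<i}), \qquad
\mathcal{L}_{\mathrm{RM}}(x_i) = -\log P_{\mathrm{RM}}(x_i \mid x_{<i})

\mathcal{L}_{\Delta}(x_i) = \mathcal{L}_{\theta}(x_i) - \mathcal{L}_{\mathrm{RM}}(x_i)
```

A large positive excess loss means the training model is still much worse than the reference on that token. A value near zero or below means there is little left to gain, either because the token is already learned or because even the reference model finds it unpredictable.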

Step 3: Select Top-k% Tokens for Back-Propagation

During the main pretraining, they still run the full forward pass across all tokens, but only back-propagate the loss on the top k% of tokens (those with the highest excess loss). This means that the model devotes most of its capacity to the tokens that are both learnable and relevant, while ignoring tokens deemed less helpful. This dynamic selection happens at each step or batch, so it adapts as the training model itself changes.
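Steps 2 and 3 can be sketched in a few lines of plain Python. The function names `select_top_k` and `slm_loss`, the `keep_ratio` parameter, and the toy loss values are all hypothetical, and the excess loss is assumed to be training loss minus reference loss. In a real training loop the per-token losses would be tensors from the framework, and the mask would be applied before the backward pass.

```python
import math


def select_top_k(train_losses, ref_losses, keep_ratio):
    """Return a 0/1 mask marking the tokens with the highest excess loss.

    Excess loss is assumed to be (training-model loss - reference-model loss).
    `keep_ratio` is the fraction of tokens kept in the batch, e.g. 0.6.
    """
    if not train_losses or len(train_losses) != len(ref_losses):
        raise ValueError("loss lists must be non-empty and of equal length")
    if not 0 < keep_ratio <= 1:
        raise ValueError("keep_ratio must be in (0, 1]")
    excess = [t - r for t, r in zip(train_losses, ref_losses)]
    n_keep = max(1, math.ceil(keep_ratio * len(excess)))
    # Rank token positions from highest to lowest excess loss.
    ranked = sorted(range(len(excess)), key=lambda i: excess[i], reverse=True)
    keep = set(ranked[:n_keep])
    return [1 if i in keep else 0 for i in range(len(excess))]


def slm_loss(train_losses, mask):
    """Average the training loss over the selected tokens only."""
    selected = [loss for loss, m in zip(train_losses, mask) if m]
    return sum(selected) / len(selected)


# Toy batch of four tokens (made-up numbers for illustration).
train = [2.0, 0.1, 3.0, 1.5]   # current training-model loss per token
ref = [0.5, 0.1, 2.9, 0.2]     # reference-model loss per token
mask = select_top_k(train, ref, keep_ratio=0.5)
batch_loss = slm_loss(train, mask)
```

In this toy batch the third token has the highest raw loss, but the reference model also finds it hard, so its excess loss is small and it is left out. That is the H→H kind of token the method is designed to skip.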

Experiments

This paper demonstrates SLM’s benefits across several experimental setups:

Figure 2. Created by the author to summarize the experimental setup

Math Domain Results

When the authors continued pretraining smaller 1B models on OpenWebMath, the improvements were striking: gains of 10% or more on GSM8K and MATH over the same model trained with CLM. They also highlight that SLM can reach baseline performance 5–10 times faster, requiring fewer tokens and less computation.

In one remarkable case, their 7B model matched the accuracy of a prior SOTA approach while consuming only 3% of the training tokens that the prior approach had used.

On top of that, a simple fine-tuning step pushed MATH scores above 40% for a 1B model, a level that smaller open-source models usually struggle to reach without huge training budgets.

General Domain Results

What if you already have a model that has seen a ton of general text? The authors show that SLM still helps.

Even a strong base model improved by around 5.8 percentage points on average across 15 benchmarks, especially in tougher domains like code and math. In other words, even after large-scale training, there is still a subset of tokens that can further improve performance.

Self-Referencing

A question arises: what if you don’t have a curated dataset to train your reference model in the first place? The authors show a creative workaround: you can train a quick-and-dirty reference model on the same raw corpus for just a few epochs.

Even though that reference model might not be perfect, it still does an okay job identifying the noisier tokens that hamper training. The result is a 2–3% downstream accuracy boost and a 30–40% reduction in tokens used.

Conclusion and Future Directions

Contributions of This Work

This paper provides both an illuminating analysis of token-level training dynamics and a new technique called SLM:

Token Loss Analysis:
They demonstrate that a majority of tokens contribute little beyond the initial training phase, while a small subset remains persistently high-loss.

SLM for Focused Learning:
By leveraging a reference model to gauge how “useful” each token is, they manage to reduce training tokens drastically without sacrificing quality, and in many cases even boost downstream performance.

Broad Demonstration of Effectiveness:
SLM works not only on math-specific tasks but also in more general domains, with either a meticulously curated reference dataset or a reference model drawn from the same large corpus.

Where Could This Go Next?

SLM opens up several potential directions for future research. For example:

Scaling Up Further:
Though the paper primarily focuses on models around 1B to 7B parameters, there remains the open question of how SLM performs at the 30B, 70B, or 100B+ scale. If the token-level approach generalizes well, the cost savings could be enormous for truly massive LLMs.

Reference Models via API:
If you can’t gather curated data, maybe you could use an API-based language model as your reference. That might make SLM more practical for smaller research teams who lack the resources to train their own reference model.

Reinforcement Learning Extensions:
Imagine coupling SLM with reinforcement learning. The reference model could act as a “reward model,” and token selection might then be optimized through something akin to policy gradients.

Multiple Reference Models:
Instead of a single RM, you could train or gather several, each focusing on a different domain or style. Then, combine their token scores to produce a more robust multi-domain filtering system.

Alignment and Safety:
There’s a growing trend toward factoring in alignment or truthfulness. One might train a reference model to give higher scores to well-supported statements and zero out tokens that look factually incorrect or harmful.

Thanks for reading, and I hope this breakdown helps you understand this NeurIPS 2024 best-paper work.


Beyond Causal Language Modeling was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.