Inference-time Scaling
1. Inference-time Scaling
Inference-first perspective:
- Pre-training algorithms for generative AI should have inference-time scalability in sequence length and refinement steps.
- These algorithms should also scale efficiently, e.g., with as few model steps as possible.
- Before developing a training method, verify that the model has enough capacity to represent the target distribution at inference time.
2. Two axes of inference-time scaling
- Sequence length defines the number of tokens. This is seen in standard autoregressive large language models (LLMs), which have recently demonstrated strong chain-of-thought and reasoning capabilities from inference-time scaling.
- Refinement steps count the iterative updates that improve existing tokens without changing the sequence length. This is seen in standard score-based diffusion models, where more steps mean smaller discretization error in the numerical solver. Note that refinement is not restricted to denoising: any process counts as refinement as long as the sequence length does not increase.
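The link between refinement steps and discretization error can be illustrated on a toy problem. The sketch below is not taken from any particular diffusion codebase. It assumes one-dimensional Gaussian data and a noise schedule `sigma(t) = t`, so the score is known in closed form and the sampling ODE `dx/dt = -t * score(x, t)` has an exact solution to compare against. The names `DATA_STD`, `T_MAX`, `euler_sample`, and `discretization_error` are illustrative choices.

```python
import math

DATA_STD = 1.0  # assumed standard deviation of the toy 1-D Gaussian data
T_MAX = 10.0    # assumed largest noise level, where sampling starts


def score(x, t):
    """Exact score of N(0, DATA_STD^2 + t^2), the data noised to level t."""
    return -x / (DATA_STD**2 + t**2)


def euler_sample(x_start, num_steps):
    """Integrate dx/dt = -t * score(x, t) from t = T_MAX down to t = 0."""
    x = x_start
    dt = T_MAX / num_steps
    for i in range(num_steps):
        t = T_MAX - i * dt
        drift = -t * score(x, t)
        x = x - dt * drift  # one Euler step backwards in time
    return x


def exact_sample(x_start):
    """Closed-form solution of the same ODE at t = 0."""
    return x_start * DATA_STD / math.sqrt(DATA_STD**2 + T_MAX**2)


def discretization_error(num_steps, x_start=5.0):
    """Absolute gap between the Euler sample and the exact solution."""
    return abs(euler_sample(x_start, num_steps) - exact_sample(x_start))


if __name__ == "__main__":
    for n in (10, 100, 1000):
        print(n, discretization_error(n))
```

In this toy setting the sequence length is fixed (a single scalar), and the only way to spend more inference compute is to raise `num_steps`, which shrinks the error of the first-order Euler solver.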
| Scalability | Example methods |
| --- | --- |
| Not scalable in either sequence length or refinement steps | VAE, GAN, Normalizing Flows |
| Scalable in sequence length but not in refinement steps | GPT, PixelCNN, MaskGiT, VAR |
| Scalable in refinement steps but not in sequence length | Diffusion models, Energy-based models, Consistency models, Parallel non-linear equation solving for autoregressive models [1] |
| Scalable in both, with sequence length in the outer loop | AR-Diffusion, Rolling diffusion, MAR, Blockwise parallel decoding [2] |
| Scalable in both, with refinement steps in the outer loop | Diffusion Forcing [3], Autoregressive distribution smoothing [4] |
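The difference between the last two rows is the order in which the two axes are nested. The following schematic, which does not reproduce any of the listed methods, only records the order in which (token position, refinement step) pairs would be visited. The function names `sequence_outer` and `refinement_outer` are hypothetical, and real methods may use more flexible schedules, such as per-token noise levels.

```python
def sequence_outer(num_tokens, num_refine_steps):
    """Fully refine each token (or block) before moving on to the next one."""
    schedule = []
    for pos in range(num_tokens):              # outer loop: sequence length
        for step in range(num_refine_steps):   # inner loop: refinement steps
            schedule.append((pos, step))
    return schedule


def refinement_outer(num_tokens, num_refine_steps):
    """Sweep over the whole sequence once per refinement step."""
    schedule = []
    for step in range(num_refine_steps):       # outer loop: refinement steps
        for pos in range(num_tokens):          # inner loop: sequence length
            schedule.append((pos, step))
    return schedule


if __name__ == "__main__":
    print(sequence_outer(2, 3))
    print(refinement_outer(2, 3))
```

Both orderings visit the same set of pairs, so the total number of model calls in this schematic is identical. What changes is which tokens are already refined when a given token is updated.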
- [1] This uses an iterative approach to sample from all the tokens in parallel, despite being trained with autoregressive objectives.
- [2] Applies "predict, verify, accept" as part of the refinement process.
- [3] Trains with independent per-token noise levels, enabling global iterative optimization of the entire sequence while maintaining causal dependencies.
- [4] Performs an iterative denoising process with an autoregressive model as the inner loop.