The conventional wisdom in deep learning: bigger models are better, but bigger models cost more compute per token. Every token processed by GPT-3 activates all 175 billion parameters. If you double the model size, every single inference call costs twice as much. This hard coupling between model capacity and compute seemed like a fundamental constraint. Switch Transformers (Fedus, Zoph, Shazeer, 2021) breaks that constraint with a simple and elegant idea: most of the model is optional. Which parts activate depends on which token you’re processing.
The core idea
The analogy: Imagine a consulting firm with 100 specialists — tax lawyers, software architects, medical researchers, translators, financial analysts. When a client asks a question, you don’t require all 100 specialists to consult on every query. You route the query to the 2-3 most relevant experts, get their input, and respond. The firm has enormous collective expertise (all 100 specialists’ knowledge), but any single query only engages a fraction of that expertise.
Mixture of Experts (MoE) applies this to language models. Instead of one large feedforward network (FFN) per Transformer layer, you have many smaller FFN “experts.” A learned routing function (the “gate”) looks at each input token and picks which expert(s) to send it to. The token only passes through those experts — all others are inactive. You get a model with enormous parameter count (all experts combined), but constant computational cost per token (only the selected experts execute).
Switch Transformer’s specific contribution: simplify MoE. Prior MoE systems routed each token to its top-2 (or top-k) experts, needed complex load-balancing schemes, and were notoriously unstable to train. Switch Transformer routes each token to exactly 1 expert, the simplest possible routing. This reduces router computation, simplifies implementation, and turns out to work as well as or better than top-2.
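The sketch below contrasts top-1 and top-k selection in plain Python. It uses made-up router logits for three tokens and four experts. `top_k_experts` and `expert_calls` are hypothetical helpers, not code from the paper. The sketch only counts how many expert FFN evaluations a batch needs: one per token for top-1, k per token for top-k.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_experts(probs, k):
    """Indices of the k highest-probability experts, best first."""
    return sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

def expert_calls(batch_probs, k):
    """Expert FFN evaluations needed when every token is sent to k experts."""
    return sum(len(top_k_experts(p, k)) for p in batch_probs)

# Hypothetical router logits: 3 tokens x 4 experts.
batch_logits = [[2.0, 0.5, 0.1, -1.0],
                [0.0, 3.0, 1.0, 0.2],
                [1.2, 1.0, 2.5, 0.0]]
batch_probs = [softmax(l) for l in batch_logits]

print([top_k_experts(p, 1) for p in batch_probs])  # [[0], [1], [2]]  (top-1, Switch)
print([top_k_experts(p, 2) for p in batch_probs])  # [[0, 1], [1, 2], [2, 0]]  (top-2)
print(expert_calls(batch_probs, 1), expert_calls(batch_probs, 2))  # 3 6
```

With top-1, the choice is just an argmax, and expert compute per token matches a single dense FFN.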
The mechanism, step by step
Architecture:
A standard Transformer layer has: Attention → FFN (expand from d_model to 4d_model, then contract back).
A Switch Transformer layer replaces the FFN with a Switch FFN layer:
STANDARD TRANSFORMER LAYER:
Input x (d_model)
|
[Multi-Head Self-Attention]
|
[FFN: Linear(d_model → 4d_model) → ReLU → Linear(4d_model → d_model)]
|
Output
SWITCH TRANSFORMER LAYER:
Input x (d_model)
|
[Multi-Head Self-Attention]
|
[Switch FFN Layer]:
|
[Router: Linear(d_model → N_experts) + softmax]
→ selects top-1 expert for each token
|
Token 1 → Expert 3 FFN
Token 2 → Expert 7 FFN
Token 3 → Expert 3 FFN (same expert, different token)
Token 4 → Expert 1 FFN
...
|
[weighted combination: output × routing_probability]
|
Output
The router:
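Each token representation $x$ is multiplied by a router weight matrix $W_r \in \mathbb{R}^{N \times d_{model}}$ (where $N$ is the number of experts) to produce logits $h(x) = W_r x$, which are then softmaxed:

$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$$

The selected expert index is $i^* = \arg\max_i p_i(x)$. The token is passed to expert $E_{i^*}$, and the expert’s output is scaled by the routing probability: $y = p_{i^*}(x)\, E_{i^*}(x)$. This scaling is important. Because $p_{i^*}(x)$ depends on $W_r$, it gives the router a gradient signal for learning which tokens should go to which experts.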
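Here is a minimal pure-Python sketch of one Switch FFN layer as described above. It assumes a router matrix (called `W_r` in the comments) with one row per expert, and N independent two-matrix ReLU experts. It uses arbitrary small random weights and a hypothetical `SwitchFFN` class name. This is an illustration of top-1 routing with probability scaling, not the paper's Mesh TensorFlow implementation.

```python
import math
import random

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def random_matrix(rows, cols, rng, scale=0.1):
    """Arbitrary small Gaussian weights, for illustration only."""
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

class SwitchFFN:
    """Top-1 Switch FFN layer: a router plus N independent ReLU FFN experts."""

    def __init__(self, d_model, d_ff, n_experts, seed=0):
        rng = random.Random(seed)
        self.router = random_matrix(n_experts, d_model, rng)  # W_r: one row per expert
        self.experts = [
            (random_matrix(d_ff, d_model, rng), random_matrix(d_model, d_ff, rng))
            for _ in range(n_experts)
        ]

    def route(self, x):
        """Return (chosen expert index i*, router probability p_{i*}(x))."""
        probs = softmax(matvec(self.router, x))  # p(x) = softmax(W_r x)
        best = max(range(len(probs)), key=lambda i: probs[i])
        return best, probs[best]

    def expert_forward(self, i, x):
        w_in, w_out = self.experts[i]
        hidden = [max(0.0, h) for h in matvec(w_in, x)]  # d_model -> d_ff, ReLU
        return matvec(w_out, hidden)                      # d_ff -> d_model

    def __call__(self, tokens):
        outputs, choices = [], []
        for x in tokens:
            i, p = self.route(x)
            y = self.expert_forward(i, x)         # only expert i runs for this token
            outputs.append([p * v for v in y])    # y = p_{i*}(x) * E_{i*}(x)
            choices.append(i)
        return outputs, choices

layer = SwitchFFN(d_model=8, d_ff=32, n_experts=4)
rng = random.Random(1)
tokens = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(5)]
outputs, choices = layer(tokens)
print(choices)  # which expert each of the 5 tokens was sent to
```

The per-token loop keeps the routing logic visible. Practical implementations instead gather each expert's tokens into one batch, so every expert runs a single matrix multiply.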
Why does this increase model capacity without increasing compute?
In a T5-Base model (250M parameters), the FFN layers hold a significant fraction of the total parameters. Replace each FFN with 64 experts (each the same size as the original FFN), and the FFN parameters grow roughly 64-fold, taking the model from 250M into the billions of parameters. Each token still passes through only one FFN, though, so the FLOP count per token is nearly unchanged. The router adds only a small $d_{model} \times N$ matrix multiply. The extra parameters are additional capacity that is activated only for the tokens routed to it.
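A back-of-the-envelope sketch of the parameter-versus-FLOP arithmetic for a single FFN layer. It makes three assumptions:
- T5-Base-like dimensions ($d_{model} = 768$, $d_{ff} = 3072$).
- Bias-free FFNs.
- The common convention of about 2 FLOPs (one multiply-add) per weight used.

Attention, embeddings, and cross-device communication are ignored. The function names are hypothetical.

```python
def ffn_params(d_model, d_ff):
    """Weights in one two-matrix FFN (no biases, as in T5)."""
    return 2 * d_model * d_ff

def dense_layer_cost(d_model, d_ff):
    """(parameters, approx FLOPs per token) for a dense FFN layer."""
    p = ffn_params(d_model, d_ff)
    return p, 2 * p  # ~2 FLOPs (multiply + add) per weight

def switch_layer_cost(d_model, d_ff, n_experts):
    """(parameters, approx FLOPs per token) for a top-1 Switch FFN layer."""
    router = d_model * n_experts
    params = n_experts * ffn_params(d_model, d_ff) + router
    # Each token runs the router and exactly one expert.
    flops_per_token = 2 * (ffn_params(d_model, d_ff) + router)
    return params, flops_per_token

d_model, d_ff = 768, 3072  # assumed T5-Base-like dimensions
p, f = dense_layer_cost(d_model, d_ff)
print(f"  dense: {p / 1e6:8.1f}M params, {f / 1e6:5.2f}M FLOPs/token")
for n in (8, 64):
    p, f = switch_layer_cost(d_model, d_ff, n)
    print(f"{n:>3} exp: {p / 1e6:8.1f}M params, {f / 1e6:5.2f}M FLOPs/token")
```

Under these assumptions, per-layer parameters grow about 64-fold at 64 experts. FLOPs per token stay near 9.4 to 9.5 million, because only one expert runs per token.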
Load balancing:
The obvious problem: all tokens might prefer the same expert. If 99% of tokens go to Expert 1, that expert is overloaded and others are idle, wasting capacity. Switch Transformers address this with an auxiliary load-balancing loss:
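$$\text{loss} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i, \qquad f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1}\{\arg\max p(x) = i\}, \qquad P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)$$

Here:
- $f_i$ is the fraction of tokens routed to expert $i$.
- $P_i$ is the mean routing probability for expert $i$ over a batch $\mathcal{B}$ of $T$ tokens.
- $\alpha$ is a small scalar coefficient.

This loss is added to the main training objective and penalizes uneven routing. It is minimized when routing is uniform ($f_i = P_i = 1/N$).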
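A minimal sketch of the auxiliary loss $\alpha \cdot N \cdot \sum_i f_i P_i$. Here $f_i$ is the fraction of tokens whose top-1 choice is expert $i$, and $P_i$ is expert $i$'s mean router probability over the batch. The function name and the toy probability tables are hypothetical. The default $\alpha = 10^{-2}$ follows the value the paper reports, and the demo uses $\alpha = 1$ so the raw numbers are easy to read.

```python
def load_balancing_loss(router_probs, alpha=1e-2):
    """Switch-style auxiliary loss: alpha * N * sum_i f_i * P_i.

    router_probs: list of per-token probability lists (T tokens x N experts).
    """
    T = len(router_probs)
    N = len(router_probs[0])
    counts = [0] * N
    prob_sums = [0.0] * N
    for probs in router_probs:
        chosen = max(range(N), key=lambda i: probs[i])  # top-1 assignment
        counts[chosen] += 1
        for i, p in enumerate(probs):
            prob_sums[i] += p
    f = [c / T for c in counts]     # fraction of tokens dispatched to each expert
    P = [s / T for s in prob_sums]  # mean router probability per expert
    return alpha * N * sum(fi * Pi for fi, Pi in zip(f, P))

# Toy batches of 4 tokens over 4 experts.
balanced = [[0.4 if i == t else 0.2 for i in range(4)] for t in range(4)]
collapsed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]

print(round(load_balancing_loss(balanced, alpha=1.0), 6))   # 1.0
print(round(load_balancing_loss(collapsed, alpha=1.0), 6))  # 2.8
```

Only $P_i$ is differentiable, since $f_i$ comes from an argmax. Gradients therefore push down the router probability of overloaded experts.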
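There’s also an expert capacity buffer. Each expert can process at most $C$ tokens per batch, where $C = \frac{\text{tokens per batch}}{N} \times \text{capacity factor}$. Overflow tokens skip their expert and pass to the next layer unchanged through the residual connection (a “dropped token”). A fixed capacity keeps every tensor shape static, which is what makes the design efficient on accelerators that compile for fixed shapes.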
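A minimal sketch of capacity-limited dispatch, assuming hypothetical top-1 assignments for 8 tokens and 4 experts. It makes two further assumptions, which real implementations may not share:
- Capacity is rounded up to an integer.
- Experts fill up in token order.

`expert_capacity` and `dispatch_with_capacity` are illustrative names.

```python
import math

def expert_capacity(tokens_per_batch, n_experts, capacity_factor):
    """Max tokens per expert: (tokens / experts) * capacity factor, rounded up."""
    return math.ceil(tokens_per_batch / n_experts * capacity_factor)

def dispatch_with_capacity(assignments, n_experts, capacity_factor):
    """Fill experts in token order; tokens that arrive after an expert is full are dropped.

    Returns (list of token indices per expert, list of dropped token indices).
    """
    cap = expert_capacity(len(assignments), n_experts, capacity_factor)
    buckets = [[] for _ in range(n_experts)]
    dropped = []
    for t, e in enumerate(assignments):
        if len(buckets[e]) < cap:
            buckets[e].append(t)
        else:
            dropped.append(t)  # skips the expert; only the residual path carries it
    return buckets, dropped

assignments = [0, 0, 0, 1, 2, 3, 3, 3]  # hypothetical top-1 choices for 8 tokens
print(dispatch_with_capacity(assignments, n_experts=4, capacity_factor=1.0))
# ([[0, 1], [3], [4], [5, 6]], [2, 7])
print(dispatch_with_capacity(assignments, n_experts=4, capacity_factor=1.5))
# ([[0, 1, 2], [3], [4], [5, 6, 7]], [])
```

A larger capacity factor drops fewer tokens, but it pads idle experts with unused slots. That padding costs memory and compute.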
Training stability with bfloat16:
Prior MoE models were unstable when trained in low precision. Switch Transformers introduce several tricks:
- Scale down the weight initialization (a scale factor of 0.1 instead of the default 1.0)
- Use bfloat16 for most of the model but compute the router in float32 (selective precision)
- Use "expert dropout" when fine-tuning: a higher dropout rate inside the expert FFNs than elsewhere, which regularizes the large sparse model on small downstream datasets
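The example below is a simplified illustration of why the router is kept in float32. It is not the paper's experiment. It emulates bfloat16 by rounding a float32 bit pattern to its top 16 bits (round-half-to-even, NaN not handled) and applies that rounding to two hypothetical router logits. Near 100, representable bfloat16 values are 0.5 apart, so a logit gap of 0.3 becomes 0.5. The router probabilities shift accordingly. `to_bfloat16` is an illustrative helper.

```python
import math
import struct

def to_bfloat16(x):
    """Round a float to the nearest bfloat16 value (round-half-to-even; no NaN handling)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round on the low 16 bits
    bits &= 0xFFFF0000                                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [100.0, 100.3]  # hypothetical router logits for two experts
print([to_bfloat16(v) for v in logits])           # [100.0, 100.5]
print(softmax(logits))                            # full-precision probabilities
print(softmax([to_bfloat16(v) for v in logits]))  # probabilities from bfloat16 logits
```

The second expert's probability moves from about 0.57 to about 0.62 purely from rounding. That error also feeds into the gate value that scales the expert output.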
Find the instinct
The Shazeer heritage:
Noam Shazeer (a coauthor) was first author of the 2017 Sparsely-Gated Mixture-of-Experts paper ("Outrageously Large Neural Networks"), which introduced top-k gating for sparse MoE layers placed between stacked LSTM layers. That paper showed you could scale to 137 billion parameters in 2017. But the system was complicated: top-k routing with k > 1, auxiliary balancing losses, and instabilities requiring careful tuning. Complexity, communication costs, and training instability kept it from wide adoption.
Switch Transformers’ key insight: the complexity was unnecessary. Routing to 1 expert instead of 2 costs little in quality and dramatically simplifies everything. In the paper’s head-to-head comparison against top-2 MoE, Switch layers run faster per step and come out ahead on a speed-quality basis.
Decoupling parameter count from compute:
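As the paper’s abstract puts it, contrasting MoE with dense models that reuse the same parameters for all inputs: “Mixture of Experts defies this and instead selects different parameters for each incoming example. The result is a sparsely-activated model — with outrageous numbers of parameters — but a constant computational cost.”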
This is the fundamental tension in scaling: more parameters means more quality (up to scaling law limits), but also more compute per token. Dense models can only grow the parameter count by adding compute. MoE breaks the coupling: you can grow the parameter count “for free” by adding experts that are only activated on relevant tokens.
The implication for Chinchilla-style analysis: if you can double parameters without doubling compute per token, you usually should. The extra parameters cost memory and cross-device communication rather than FLOPs, provided you have the hardware to hold them. At a fixed FLOP budget, MoE models can therefore grow parameters and training tokens together, which makes them compelling as scale increases.
Results
On C4 (English web crawl), pre-training speed (steps to reach a target perplexity):
- T5-Base (250M): baseline
- Switch-Base (64 experts, same FLOPs per token as T5-Base): 7× faster to the same quality
The Switch-Base model reaches the quality of T5-Base in 1/7th the training steps, using the same number of FLOPs per step. This is the headline result: 7× speedup from sparsity.
Multilingual (101 languages, mT5-Base baseline):
- Switch-Base improves over mT5-Base on all 101 languages, including low-resource languages where you might worry the expert routing would break down.
On scale:
- Trained models up to 1.6 trillion parameters on the “Colossal Clean Crawled Corpus”
- 4× speedup over T5-XXL at comparable compute budget
What doesn’t work:
- MoE models are hard to distill into smaller dense models: the paper's dense students preserve only about 30% of the sparse teacher's quality gain
- Expert routing complicates distributed training: experts are spread across devices, so tokens must be sent across devices to reach their expert, which adds communication overhead
- Dropped tokens are handled silently: a token that overflows its expert's capacity skips the expert FFN and reaches the next layer only through the residual connection
- Fine-tuning MoE models can be harder; the experts may specialize in pretraining patterns that don’t transfer cleanly
Practical implications
Switch Transformers proved MoE is practical at scale and opened the path to models like Mixtral 8×7B (routes each token to 2 of 8 experts), Mixtral 8×22B, and GPT-4 (rumored to be MoE). The architecture is now widely used for large-scale language models. When you need quality beyond what a dense model can deliver at a given compute budget, MoE is a common choice.
Mixtral’s design (top-2 of 8, not top-1) differs from Switch Transformer’s top-1 recommendation, suggesting the exact routing number is less important than the overall architecture. The key takeaway from Switch Transformer is that sparse activation works, is trainable, and is faster — the details are tunable.
Connections
- mixture-of-experts — the technique this paper simplifies and scales
- transformer — the architecture Switch Transformers extend with sparse MoE layers
- scaling-laws — Switch Transformer changes scaling dynamics by decoupling parameter count from compute
- attention-is-all-you-need — the Transformer architecture that Switch Transformers extend
- scaling-laws-neural-language-models — Switch Transformer changes the scaling dynamics; more params at constant compute
- mixtral-of-experts — Mixtral is the direct successor architecture, using top-2 MoE routing
Citation
Fedus, W., Zoph, B., & Shazeer, N. (2021). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR 2022. https://arxiv.org/abs/2101.03961