The Problem
In a softmax-based classification problem with many classes (a vocabulary of 50K-500K words, or millions of items in a recommender system), every training step requires computing a probability over all classes. That means a forward and backward pass touching every output unit on every step, which is prohibitively expensive at scale. For Word2Vec on 1.6B words with a 50K vocabulary, every training step does a 50K-way softmax. Counting one step per word, a single pass costs on the order of 1.6B × 50K ≈ 80 trillion class-output computations. Untenable.
The Key Insight
Replace the full softmax with binary classification: instead of picking the correct class out of $N$, distinguish the correct class from a few randomly sampled “negative” classes. For each training example, sample $K$ negatives (typically 5-25) and reformulate the task as $K+1$ binary cross-entropy problems: the positive labelled 1, each negative labelled 0. The per-step cost drops from $O(N)$ to $O(K)$, with little loss in quality for ranking and similarity tasks.
Mechanism in Plain English
- For each training example with a true class $c$, sample $K$ negative classes from a noise distribution $P_n$.
- Train the model on $K+1$ binary tasks: classify the true class as positive, each negative as negative.
- The loss is the sum of binary cross-entropies (or log-sigmoid losses).
- After training, the model has learned to score the true class higher than random classes, which is enough for ranking and similarity.
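The steps above can be sketched in plain Python. This is a minimal illustration rather than word2vec's actual implementation: the function name `negative_sampling_loss`, the uniform noise distribution, and the toy sizes are all assumptions made for the example.

```python
import math
import random

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def negative_sampling_loss(h, W, true_class, noise_probs, K, rng):
    """Sum of K+1 binary log-sigmoid losses for one training example."""
    classes = list(range(len(W)))
    # Step 1: draw K negative classes from the noise distribution.
    negatives = rng.choices(classes, weights=noise_probs, k=K)
    # Step 2: the true class is labelled 1, so push sigmoid(score) toward 1.
    loss = -log_sigmoid(dot(W[true_class], h))
    # Step 3: each negative is labelled 0, so push sigmoid(score) toward 0.
    for w in negatives:
        loss -= log_sigmoid(-dot(W[w], h))
    return loss, negatives

rng = random.Random(0)
N, dim, K = 1000, 8, 5  # toy sizes (assumption)
W = [[rng.gauss(0, 0.1) for _ in range(dim)] for _ in range(N)]
h = [rng.gauss(0, 0.1) for _ in range(dim)]
noise_probs = [1.0 / N] * N  # uniform noise, purely for illustration

loss, negatives = negative_sampling_loss(h, W, true_class=42, noise_probs=noise_probs, K=K, rng=rng)
print(f"touched {K + 1} of {N} classes, loss = {loss:.4f}")
```

With near-zero initial scores, each of the six binary terms starts close to log 2, so the initial loss sits near 6 · log 2.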
ASCII Diagram
FULL SOFTMAX:
    prediction = softmax(W * h)                  # over all N classes
    cost: O(N) per training step. For N=50K: huge.

NEGATIVE SAMPLING:
    positive:  y_pos = sigmoid(W[c] dot h)       # true class -> 1
    negatives: for each w in K samples from P_n:
                   y_neg = sigmoid(W[w] dot h)   # negative -> 0
    cost: O(K+1) per training step. For N=50K, K=10: ~4500x fewer scores.
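The cost comparison in the diagram is simple arithmetic. A small sketch, assuming cost is measured as the number of output scores computed per training example (the helper `scores_per_step` is a hypothetical name):

```python
def scores_per_step(num_classes, num_negatives=None):
    """Output scores computed for one training example."""
    if num_negatives is None:
        return num_classes        # full softmax: one score per class
    return num_negatives + 1      # negative sampling: K negatives + 1 positive

N, K = 50_000, 10
full = scores_per_step(N)
sampled = scores_per_step(N, K)
ratio = full / sampled
print(f"full softmax: {full} scores, negative sampling: {sampled} scores, ratio ~{ratio:.0f}x")
```

For these values the ratio N/(K+1) comes to roughly 4,500. It ignores the cost of drawing the samples, which is small compared with the dot products saved.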
Math with Translation
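Original full softmax loss:
$$L_{\text{softmax}} = -\log \frac{\exp(s_c)}{\sum_{j=1}^{N} \exp(s_j)}$$
- $s_j$ = score for class $j$ (e.g., $s_j = W_j \cdot h$).
- The denominator requires summing over all $N$ classes.
Negative sampling loss:
$$L_{\text{NS}} = -\log \sigma(s_c) - \sum_{k=1}^{K} \log \sigma(-s_{w_k}), \qquad w_k \sim P_n$$
- $\sigma$ = sigmoid.
- First term: drive the score of the true class up (toward $+\infty$).
- Second term: drive the scores of negatives down (toward $-\infty$).
- Only $K+1$ scores need to be computed; $K \ll N$.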
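Differentiating the negative sampling loss with respect to the scores shows why so few parameters are touched. A short derivation, assuming the sum-over-samples form of the loss with true class $c$, scores $s_j$, and distinct sampled negatives $w_1, \dots, w_K \sim P_n$ (the symbols $w_k$ and $L_{\text{NS}}$ are notation chosen here):

```latex
\begin{aligned}
L_{\text{NS}} &= -\log \sigma(s_c) - \sum_{k=1}^{K} \log \sigma(-s_{w_k}) \\
\frac{\partial L_{\text{NS}}}{\partial s_c} &= \sigma(s_c) - 1 \\
\frac{\partial L_{\text{NS}}}{\partial s_{w_k}} &= \sigma(s_{w_k}) \\
\frac{\partial L_{\text{NS}}}{\partial s_j} &= 0 \quad \text{for every other class } j
\end{aligned}
```

Each gradient has the form "predicted probability minus label" (label 1 for the true class, 0 for a negative), and classes that were not sampled receive no gradient at all. If a negative is drawn more than once, its contributions simply add.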
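Noise distribution choice: word2vec uses a unigram distribution raised to the 3/4 power:
$$P_n(w) = \frac{U(w)^{3/4}}{\sum_{w'} U(w')^{3/4}}$$
where $U(w)$ is the unigram frequency of word $w$.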
This is empirically tuned: pure unigram over-samples frequent words (“the,” “a”) that don’t carry much signal; uniform under-samples them. The 3/4 power is a practical compromise that slightly upweights rare words.
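The effect of the 3/4 exponent is easy to check numerically. A minimal sketch, assuming made-up counts for three words (the function name `smoothed_noise_distribution` and the counts are illustrative, not taken from word2vec's code):

```python
def smoothed_noise_distribution(counts, power=0.75):
    """Raise unigram counts to `power` and renormalize so they sum to 1."""
    weights = {word: count ** power for word, count in counts.items()}
    total = sum(weights.values())
    return {word: weight / total for word, weight in weights.items()}

# Hypothetical toy counts, chosen only to show the effect.
counts = {"the": 10_000, "fox": 100, "blockchain": 1}
total_count = sum(counts.values())

unigram = {w: c / total_count for w, c in counts.items()}
smoothed = smoothed_noise_distribution(counts)

for word in counts:
    print(f"{word:>10}  unigram={unigram[word]:.5f}  smoothed={smoothed[word]:.5f}")
```

With these counts, "the" drops from about 0.990 to about 0.968, while "blockchain" rises from about 0.0001 to about 0.001: frequent words still dominate, but rare words are sampled roughly ten times more often than under the raw unigram distribution.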
Concrete Walkthrough
WORD2VEC SKIP-GRAM, vocabulary = 50K, K = 5 negatives:

TRAINING EXAMPLE:
    center  = "fox"
    context = "brown"   (the true positive)

SAMPLE 5 NEGATIVES from P_n (unigram^0.75):
    ["spaghetti", "the", "philosophical", "of", "blockchain"]

COMPUTE 6 SCORES:
    s_brown         = W_in[fox] dot W_out[brown]
    s_spaghetti     = W_in[fox] dot W_out[spaghetti]
    s_the           = W_in[fox] dot W_out[the]
    s_philosophical = W_in[fox] dot W_out[philosophical]
    s_of            = W_in[fox] dot W_out[of]
    s_blockchain    = W_in[fox] dot W_out[blockchain]

LOSS:
    L = -log_sigmoid(s_brown) - sum(log_sigmoid(-s_neg) for neg in [...])
    Pull s_brown UP (toward +inf), pull s_neg DOWN (toward -inf).

GRADIENT FLOW:
    Updates: W_in[fox], W_out[brown], W_out[spaghetti], W_out[the], ...
    Only 7 vectors get updated this step: W_in[fox] plus 6 of the 50K W_out rows. Massively cheap.

COMPARED TO FULL SOFTMAX:
    Would need to update all 50K W_out rows per step. ~8000x more expensive (50K / 6).
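The walkthrough can be run end to end on a toy vocabulary. The sketch below is an assumption-laden illustration: a ten-word vocabulary, 4-dimensional vectors, plain SGD with a made-up learning rate, and a hypothetical helper `sgns_step`. It performs one update and reports which output rows actually changed.

```python
import math
import random

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def sgns_step(W_in, W_out, center, positive, negatives, lr=0.1):
    """One SGD step of skip-gram negative sampling; returns the W_out rows touched."""
    v = W_in[center]
    grad_center = [0.0] * len(v)
    touched = set()
    # Label 1 for the true context word, 0 for each sampled negative.
    for word, label in [(positive, 1.0)] + [(n, 0.0) for n in negatives]:
        u = W_out[word]
        g = sigmoid(dot(v, u)) - label      # dLoss/dscore = prediction - label
        for i in range(len(v)):
            grad_center[i] += g * u[i]      # accumulate gradient for the center vector
            u[i] -= lr * g * v[i]           # update this output row in place
        touched.add(word)
    for i in range(len(v)):
        v[i] -= lr * grad_center[i]         # single update to W_in[center]
    return touched

vocab = ["fox", "brown", "spaghetti", "the", "philosophical",
         "of", "blockchain", "quick", "lazy", "dog"]
rng = random.Random(0)
dim = 4
W_in = {w: [rng.uniform(-0.5, 0.5) for _ in range(dim)] for w in vocab}
W_out = {w: [rng.uniform(-0.5, 0.5) for _ in range(dim)] for w in vocab}

negatives = ["spaghetti", "the", "philosophical", "of", "blockchain"]
before = {w: list(vec) for w, vec in W_out.items()}

touched = sgns_step(W_in, W_out, "fox", "brown", negatives)
changed = sorted(w for w in vocab if W_out[w] != before[w])
print("W_out rows changed:", changed)
```

Only the six output rows involved in the scores change, along with the single center vector for "fox"; the rows for "quick", "lazy", and "dog" are never read or written.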
What’s Clever
The first clever recognition: you don’t need to know exactly how unlikely each negative is — just that the model has separated them from the positive. The full softmax computes calibrated probabilities; negative sampling only needs the direction of separation. For ranking, retrieval, and similarity tasks, that’s all that matters.
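A quick numeric check makes the distinction concrete. The sketch assumes four hypothetical candidate scores: per-class sigmoids do not form a calibrated distribution (they need not sum to 1), yet they rank the candidates exactly as a softmax would, because both are monotone in the score.

```python
import math

# Hypothetical scores for four candidate context words.
scores = {"brown": 2.1, "quick": 1.3, "the": -0.4, "spaghetti": -3.0}

z = sum(math.exp(s) for s in scores.values())
softmax = {w: math.exp(s) / z for w, s in scores.items()}
sigmoid = {w: 1.0 / (1.0 + math.exp(-s)) for w, s in scores.items()}

rank_softmax = sorted(scores, key=softmax.get, reverse=True)
rank_sigmoid = sorted(scores, key=sigmoid.get, reverse=True)

print("softmax total:", round(sum(softmax.values()), 6))
print("sigmoid total:", round(sum(sigmoid.values()), 6))
print("same ranking:", rank_softmax == rank_sigmoid)
```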
The second clever recognition: negative sampling implicitly factorizes a shifted PMI matrix (Levy & Goldberg, 2014). At its optimum, given enough dimensions, the objective yields word and context vectors such that $\vec{v}_w \cdot \vec{v}_c = \mathrm{PMI}(w, c) - \log K$. This means the resulting embeddings have the same geometric structure as count-based PMI matrices (e.g., LSA), with the bonus that they were produced via a streaming, scalable algorithm.
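To make the target of that factorization tangible, the sketch below computes the shifted PMI values from a tiny co-occurrence table. The counts and the helper name `shifted_pmi` are assumptions for illustration; the code only shows the matrix that the learned dot products would approximate, not the training itself.

```python
import math

def shifted_pmi(cooc, K):
    """PMI(w, c) - log K for every observed (word, context) pair."""
    total = sum(cooc.values())
    word_totals, ctx_totals = {}, {}
    for (w, c), n in cooc.items():
        word_totals[w] = word_totals.get(w, 0) + n
        ctx_totals[c] = ctx_totals.get(c, 0) + n
    return {
        # PMI = log( P(w, c) / (P(w) * P(c)) ), written in terms of raw counts.
        (w, c): math.log(n * total / (word_totals[w] * ctx_totals[c])) - math.log(K)
        for (w, c), n in cooc.items()
    }

# Hypothetical co-occurrence counts, for illustration only.
cooc = {("fox", "brown"): 8, ("fox", "the"): 2, ("dog", "brown"): 2, ("dog", "the"): 8}
target = shifted_pmi(cooc, K=5)

for pair, value in target.items():
    print(pair, round(value, 3))
```

Pairs that co-occur more often than chance get a higher target, and increasing K shifts every entry down by the same amount.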
The third clever (and slightly hacky) move: the 3/4-power smoothing. Pure unigram negative sampling under-emphasizes rare words; uniform over-emphasizes them. The 3/4 power is an empirical Goldilocks: enough rare-word presence to learn discriminative features, not so much that the gradient is dominated by noise. Similar power-law smoothing reappears in some later systems that sample by frequency, though not in all successors: BERT’s MLM, for example, computes a full softmax over its vocabulary rather than sampling negatives.
Code
import torch
import torch.nn as nn
import torch.nn.functional as F

class NegSamplingLoss(nn.Module):
    def __init__(self, vocab_size, dim, K=5, noise_dist=None):
        super().__init__()
        self.W_in = nn.Embedding(vocab_size, dim)   # center-word vectors
        self.W_out = nn.Embedding(vocab_size, dim)  # context-word vectors
        self.K = K
        # noise_dist: torch.tensor of shape (vocab_size,), unigram^0.75 normalized.
        # Assumption: fall back to a uniform distribution when none is given.
        if noise_dist is None:
            noise_dist = torch.ones(vocab_size) / vocab_size
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("noise_dist", noise_dist.float())

    def forward(self, center_ids, pos_ids):
        # center_ids: (batch,)  pos_ids: (batch,)
        batch_size = center_ids.size(0)
        # Sample K negatives per example from the noise distribution.
        # Samples are not filtered, so a negative may occasionally equal the positive.
        neg_ids = torch.multinomial(self.noise_dist, batch_size * self.K, replacement=True)
        neg_ids = neg_ids.view(batch_size, self.K)
        c = self.W_in(center_ids)                  # (batch, dim)
        p = self.W_out(pos_ids)                    # (batch, dim)
        n = self.W_out(neg_ids)                    # (batch, K, dim)
        pos_score = (c * p).sum(-1)                # (batch,)
        neg_score = (c.unsqueeze(1) * n).sum(-1)   # (batch, K)
        # Sum over the K negatives, then average over the batch.
        loss = -(F.logsigmoid(pos_score) + F.logsigmoid(-neg_score).sum(-1)).mean()
        return loss

Key Sources
- word2vec-efficient-estimation-word-representations: introduced negative sampling for embedding learning
- sentence-bert-siamese-bert-networks: uses in-batch negatives, a related but different idea
- bge-c-pack-general-chinese-embeddings: modern contrastive systems use both in-batch and hard-mined negatives
Related Concepts
- contrastive-learning — generalized version; negative sampling is the foundational case
- word-embeddings — primary application of negative sampling
- self-supervised-learning — negative sampling makes self-supervised contrastive training feasible
Open Questions
- In-batch vs explicit negatives: modern contrastive learning often uses other examples in the same batch as negatives, avoiding explicit sampling. When is this better?
- Hard negative mining: random negatives become uninformative late in training. Mining “hard” negatives (those the model currently scores high but are actually wrong) accelerates learning but is fiddly. Sweet spot?
- Negative sample count K: word2vec uses 5-25. Modern systems (CLIP, BGE) effectively use thousands via in-batch negatives. How much does K matter?
- Noise distribution choice: unigram^0.75 is empirically tuned. Are there principled choices? Some recent work uses learned negative samplers (adversarial training).