Math for Deep Learning II
The linear layer collapse problem (closely related to rank collapse and
representation collapse) occurs when neural network layers learn representations that
lose information and become degenerate.
When multiple linear transformations are stacked without nonlinearities, they collapse into a single
equivalent linear transformation, severely limiting the network’s expressiveness.
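Given layers: y = W₃(W₂(W₁x))
This simplifies to: y = (W₃W₂W₁)x = W_eff x
where W_eff = W₃W₂W₁ is just one matrix, no matter how many layers you stack! Biases do not change this: W₃(W₂(W₁x + b₁) + b₂) + b₃ = W_eff x + b_eff with b_eff = W₃W₂b₁ + W₃b₂ + b₃, which is still a single affine map.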
# Bad - collapses to single linear transform
x → Linear(512) → Linear(256) → Linear(128) → y
# Good - each layer can learn nonlinear features
x → Linear(512) → ReLU → Linear(256) → ReLU → Linear(128) → y
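A minimal NumPy sketch of the collapse follows. It assumes a 100-dimensional input and uses random weights as stand-ins for trained ones. The three stacked matrices give the same outputs as the single product W_eff. The stack also passes an additivity check that any linear map must pass, while the same weights with ReLU in between fail it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 100-dim input; widths 512 -> 256 -> 128 as in the example above
W1 = rng.normal(size=(512, 100))
W2 = rng.normal(size=(256, 512))
W3 = rng.normal(size=(128, 256))


def relu(v):
    return np.maximum(v, 0.0)


def linear_stack(x):
    """Three linear layers with nothing in between."""
    return W3 @ (W2 @ (W1 @ x))


def relu_stack(x):
    """The same weights with ReLU between the layers."""
    return W3 @ relu(W2 @ relu(W1 @ x))


# The whole linear stack collapses into one 128 x 100 matrix
W_eff = W3 @ W2 @ W1

x1, x2 = rng.normal(size=100), rng.normal(size=100)
same_as_one_layer = np.allclose(linear_stack(x1), W_eff @ x1)

# Additivity f(a + b) == f(a) + f(b) holds for any linear map.
# The plain stack passes; the ReLU stack generally fails.
linear_additive = np.allclose(linear_stack(x1 + x2), linear_stack(x1) + linear_stack(x2))
relu_additive = np.allclose(relu_stack(x1 + x2), relu_stack(x1) + relu_stack(x2))

print(W_eff.shape, same_as_one_layer, linear_additive, relu_additive)
```

Extra depth only pays off when a nonlinearity breaks this additivity.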
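During training, weight matrices can lose rank. They rarely become exactly singular. Instead, a few singular values come to dominate, so the matrix behaves as if it had low (effective) rank:
Full-rank matrix (r = min(m, n)):
W = [linearly independent rows] → outputs can span all min(m, n) available dimensions
Rank-collapsed matrix (r << min(m, n)):
W ≈ [repeated/dependent rows] → outputs confined to an r-dimensional subspace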
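To see the "confined to a subspace" effect directly, the sketch below builds a hypothetical 64×64 weight matrix as the product of a 64×2 and a 2×64 factor, so its rank is 2 by construction. It compares that matrix with a random full-rank one. The sizes and the rank of 2 are arbitrary illustration choices.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, r = 64, 64, 2          # hypothetical 64 x 64 layer, collapsed to rank 2

# Full rank: independent random rows
W_full = rng.normal(size=(m, n))

# Rank-collapsed: every row is a combination of the same r = 2 directions
W_collapsed = rng.normal(size=(m, r)) @ rng.normal(size=(r, n))

X = rng.normal(size=(n, 1000))   # 1000 random inputs, one per column

# Dimension of the space actually reached by the layer's outputs
out_dim_full = np.linalg.matrix_rank(W_full @ X)
out_dim_collapsed = np.linalg.matrix_rank(W_collapsed @ X)
print(out_dim_full, out_dim_collapsed)
```

However many inputs are fed in, the collapsed layer's outputs never leave a 2-dimensional plane.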
Consequences:
Hidden-layer activations become too similar, both across hidden units and across inputs:
Good: h₁ = [0.8, 0.1, -0.3, 0.9, ...] (diverse features)
h₂ = [-0.2, 0.7, 0.4, -0.1, ...]
h₃ = [0.5, -0.6, 0.2, 0.3, ...]
Collapsed: h₁ = [0.5, 0.5, 0.5, 0.5, ...] (all similar)
h₂ = [0.48, 0.52, 0.49, 0.51, ...]
h₃ = [0.51, 0.49, 0.50, 0.50, ...]
Hidden units respond almost identically, and different inputs land on nearly the same point → no diversity → poor performance.
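One simple way to quantify the difference between the two examples is the variance of each feature across inputs, averaged over features. The sketch uses the first four entries of each h vector shown above. It treats h₁, h₂, h₃ as activations for three different inputs, which is an assumption about the example.

```python
import numpy as np

# First four entries of the example activations (rows = inputs, columns = hidden units)
good = np.array([[0.8, 0.1, -0.3, 0.9],
                 [-0.2, 0.7, 0.4, -0.1],
                 [0.5, -0.6, 0.2, 0.3]])
collapsed = np.array([[0.5, 0.5, 0.5, 0.5],
                      [0.48, 0.52, 0.49, 0.51],
                      [0.51, 0.49, 0.50, 0.50]])


def mean_feature_variance(h):
    """Variance of each hidden unit across inputs, averaged over units."""
    return h.var(axis=0).mean()


print(f"good:      {mean_feature_variance(good):.4f}")
print(f"collapsed: {mean_feature_variance(collapsed):.6f}")
```

For these numbers the diverse activations average a per-feature variance of about 0.18. The collapsed ones come out below 0.0001.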
Self-supervised contrastive learning (SimCLR, MoCo) must guard against collapse:
Bad: Model learns to map all inputs to the same point
→ A loss that only pulls positive pairs together is minimized (all positives match), but the features are useless
Good: A repulsion term (negative pairs) prevents collapse
→ Forces different samples apart
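A small NumPy sketch shows why the repulsion term matters. `alignment_loss` only pulls positive pairs together, so a collapsed encoder gets a perfect score of 0. `info_nce_loss` is a simplified InfoNCE-style loss, not SimCLR's exact NT-Xent. In it, the other samples in the batch act as negatives, so a collapsed encoder cannot push it below log(N). The temperature, batch size, and noise level are arbitrary assumptions.

```python
import numpy as np


def normalize(z):
    return z / np.linalg.norm(z, axis=1, keepdims=True)


def alignment_loss(z1, z2):
    """Positive pairs only: mean of 1 - cos(z1[i], z2[i])."""
    return np.mean(1.0 - np.sum(normalize(z1) * normalize(z2), axis=1))


def info_nce_loss(z1, z2, temperature=0.1):
    """Simplified InfoNCE: z1[i] must pick out z2[i] among all rows of z2.
    The other rows act as negatives, i.e. the repulsion term."""
    logits = normalize(z1) @ normalize(z2).T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)      # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))


rng = np.random.default_rng(0)
N, d = 8, 16
collapsed = np.tile(np.eye(d)[0], (N, 1))              # every input -> [1, 0, 0, ...]
spread = rng.normal(size=(N, d))
spread_view = spread + 0.05 * rng.normal(size=(N, d))  # a slightly perturbed second view

print(alignment_loss(collapsed, collapsed))            # perfect score for the collapsed encoder
print(info_nce_loss(collapsed, collapsed), np.log(N))  # stuck at log(N)
print(info_nce_loss(spread, spread_view))              # far lower for distinct samples
```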
In self-attention:
Problem: Attention weights become uniform
softmax(QKᵀ/√dₖ) → [0.25, 0.25, 0.25, 0.25] for all queries
Causes:
Solutions:
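The sketch below shows one hypothetical way uniform attention can arise. If the query and key projections shrink toward zero, every score QKᵀ/√dₖ is close to 0 and the softmax becomes flat. Every query then receives the same average of the value vectors. The sizes and the 1e-4 weight scale are illustrative assumptions, not a diagnosis of any particular model.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V):
    """Scaled dot-product attention; returns (weights, output)."""
    d_k = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d_k))
    return weights, weights @ V


rng = np.random.default_rng(0)
n_tokens, d_k = 4, 64
X = rng.normal(size=(n_tokens, d_k))
V = rng.normal(size=(n_tokens, d_k))

# Hypothetical collapsed projections: weights shrunk close to zero,
# so all scores are ~0 and the softmax has nothing to distinguish
W_q = 1e-4 * rng.normal(size=(d_k, d_k))
W_k = 1e-4 * rng.normal(size=(d_k, d_k))
weights, out = attention(X @ W_q, X @ W_k, V)

print(np.round(weights, 3))                         # every row ≈ [0.25, 0.25, 0.25, 0.25]
print(np.allclose(out, V.mean(axis=0), atol=1e-3))  # every query gets the same averaged value
```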
In word embeddings:
Problem: All word vectors become nearly identical
king ≈ queen ≈ dog ≈ computer ≈ [0.3, 0.3, 0.3, ...]
Causes:
Solutions:
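A quick way to spot this kind of collapse is to average the cosine similarity between embedding vectors. For well-spread vectors it sits near 0, and for collapsed ones it approaches 1. The sketch uses random vectors for a hypothetical 8-word vocabulary rather than a trained model.

```python
import numpy as np


def mean_pairwise_cosine(E):
    """Average cosine similarity over all distinct pairs of rows of E."""
    U = E / np.linalg.norm(E, axis=1, keepdims=True)
    S = U @ U.T
    n = len(E)
    return (S.sum() - np.trace(S)) / (n * (n - 1))


rng = np.random.default_rng(0)
vocab, dim = 8, 50   # hypothetical tiny vocabulary (king, queen, dog, computer, ...)

healthy = rng.normal(size=(vocab, dim))
collapsed = 0.3 + 0.01 * rng.normal(size=(vocab, dim))  # every vector ≈ [0.3, 0.3, 0.3, ...]

print(f"healthy:   {mean_pairwise_cosine(healthy):.3f}")
print(f"collapsed: {mean_pairwise_cosine(collapsed):.3f}")
```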
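In SimCLR or BYOL:
Problem: Encoder maps all images to a constant vector
f(cat) = f(dog) = f(car) = [1, 0, 0, ...]
This trivially minimizes any loss that only matches augmented views, but learns nothing! (SimCLR's negative pairs penalize this solution. BYOL, which uses no negatives, avoids it through a predictor head and a momentum-updated target network.)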
Solutions:
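W ~ N(0, 2/(n_in + n_out))   (Xavier/Glorot initialization)
W ~ N(0, 2/n_in)   (He/Kaiming initialization, designed for ReLU)
# Check weight matrix rank (W: a 2-D NumPy array, e.g. one layer's weights)
import numpy as np
# Trained matrices are almost never exactly singular, so count singular values
# above 1% of the largest (an arbitrary cutoff) as the "effective" rank
s = np.linalg.svd(W, compute_uv=False)  # sorted largest first
rank = int((s > 0.01 * s[0]).sum())
full_rank = min(W.shape)
print(f"Rank: {rank}/{full_rank} - Collapse: {rank < 0.9 * full_rank}")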
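The two initialization variances, 2/(n_in + n_out) (Xavier/Glorot) and 2/n_in (He/Kaiming), can be compared with a quick forward-pass experiment. The sketch below pushes random inputs through 30 hypothetical Linear+ReLU layers of width 256 and reports the spread of the final activations. With ReLU, He initialization keeps the scale roughly constant, while Xavier lets the mean square halve at every layer. A too-small fixed standard deviation (0.01, an arbitrary choice) drives the activations toward zero.

```python
import numpy as np

rng = np.random.default_rng(0)
width, depth = 256, 30                 # hypothetical square layers: n_in = n_out = width
x = rng.normal(size=(width, 512))      # 512 random inputs, one per column


def final_activation_std(weight_std):
    """Push x through `depth` Linear+ReLU layers with W ~ N(0, weight_std**2)."""
    h = x
    for _ in range(depth):
        W = rng.normal(scale=weight_std, size=(width, width))
        h = np.maximum(W @ h, 0.0)
    return h.std()


init_stds = {
    "He": np.sqrt(2.0 / width),                # Var = 2 / n_in
    "Xavier": np.sqrt(2.0 / (width + width)),  # Var = 2 / (n_in + n_out)
    "small": 0.01,                             # hypothetical too-small init
}
results = {name: final_activation_std(s) for name, s in init_stds.items()}
for name, s in results.items():
    print(f"{name:6s} final activation std after {depth} ReLU layers: {s:.1e}")
```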
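# Check activation diversity (PyTorch; `model` and `batch` assumed to exist)
import torch
with torch.no_grad():
    activations = model(batch)  # shape: [batch, features]
variance = activations.var(dim=0).mean().item()  # per-feature variance across the batch
print(f"Avg feature variance: {variance:.4f}")  # Low → collapse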
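# Check attention entropy (PyTorch; attention_scores = raw QKᵀ/√dₖ logits, last dim = keys)
import torch
attn_weights = attention_scores.softmax(dim=-1)
entropy = torch.special.entr(attn_weights).sum(dim=-1).mean().item()  # entr(p) = -p·log(p), entr(0) = 0
print(f"Attention entropy: {entropy:.4f}")  # Near log(num_keys) → uniform collapse; near 0 → one-hot attention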
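The entropy check has two useful reference points. Uniform attention over n keys gives the maximum, log(n), and a single dominant key gives nearly 0. Here is a NumPy sketch with made-up scores that computes the entropy through a numerically stable log-softmax.

```python
import numpy as np


def attention_entropy(scores):
    """Mean entropy (nats) of softmax(scores) over the last axis, via a stable log-softmax."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    log_p = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    p = np.exp(log_p)
    return -(p * log_p).sum(axis=-1).mean()


n_keys = 4
uniform_scores = np.zeros((3, n_keys))                  # all scores equal → uniform weights
peaked_scores = np.tile([50.0, 0.0, 0.0, 0.0], (3, 1))  # one key dominates → nearly one-hot

print(attention_entropy(uniform_scores), np.log(n_keys))  # equal: the maximum possible entropy
print(attention_entropy(peaked_scores))                    # close to 0
```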
# Keeps activation magnitudes from vanishing or exploding
x = LayerNorm(x)  # per token: mean 0, std 1 across features, then a learned scale and shift
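Below is a minimal NumPy version of layer normalization, assuming the usual formulation. Each row is normalized across its features, then a learned per-feature scale (gamma) and shift (beta) are applied. The made-up example shows that small, shifted activations come out with mean 0 and a standard deviation close to 1.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize each row (one token or sample) across its features,
    then apply the learned per-feature scale gamma and shift beta."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
d = 8
x = 5.0 + 0.1 * rng.normal(size=(4, d))   # small, shifted activations (made-up example)
y = layer_norm(x, gamma=np.ones(d), beta=np.zeros(d))
print(y.mean(axis=-1).round(6), y.std(axis=-1).round(3))
```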
# Preserves gradient flow, prevents vanishing
x = x + TransformerBlock(x) # skip connection
# Attention divides by √d_k to prevent softmax saturation
import math
scores = Q @ K.T / math.sqrt(d_k)
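Why divide by √dₖ? If query and key entries have unit variance, each dot product sums dₖ independent products, so its variance is about dₖ. The sketch below uses random unit-variance Q and K with arbitrarily chosen sizes. It checks that the raw scores have a standard deviation near √dₖ and that scaling brings it back to about 1. It also shows that unscaled scores push the softmax toward one-hot weights.

```python
import numpy as np


def mean_max_weight(scores):
    """Average over queries of the largest softmax weight (1.0 = one-hot)."""
    s = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(s) / np.exp(s).sum(axis=-1, keepdims=True)
    return p.max(axis=-1).mean()


rng = np.random.default_rng(0)
d_k, n = 512, 1000
Q = rng.normal(size=(n, d_k))     # hypothetical unit-variance queries and keys
K = rng.normal(size=(n, d_k))

raw = Q @ K.T                     # each score sums d_k products -> variance ≈ d_k
scaled = raw / np.sqrt(d_k)       # variance ≈ 1

print(f"std raw: {raw.std():.1f} (sqrt(d_k) = {np.sqrt(d_k):.1f}), std scaled: {scaled.std():.2f}")
# Softmax over 16 keys per query: raw scores saturate it, scaled scores keep it soft
print(mean_max_weight(raw[:, :16]), mean_max_weight(scaled[:, :16]))
```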
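# VICReg loss components (z1, z2 = embeddings of two augmented views; z = either view):
variance_loss = mean(max(0, 1 - std(z, axis=0)))  # hinge on each dimension's std across the batch
invariance_loss = MSE(z1, z2)  # match augmentations
covariance_loss = sum(off_diagonal(cov(z)) ** 2) / dim  # decorrelate features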
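Here is a NumPy sketch of the three VICReg terms. It assumes the formulation described in the VICReg paper: a hinge on each dimension's standard deviation with γ = 1 and a small ε, plus squared off-diagonal covariances divided by the dimension, both applied to each branch. The weighting coefficients that combine the terms are left out, and the batch and dimension sizes are arbitrary.

```python
import numpy as np


def off_diagonal(m):
    """All entries of a square matrix except the diagonal, as a flat array."""
    return m[~np.eye(len(m), dtype=bool)]


def vicreg_terms(z1, z2, gamma=1.0, eps=1e-4):
    """(invariance, variance, covariance) terms for two [batch, dim] embedding arrays."""
    n, d = z1.shape

    def variance_term(z):
        std = np.sqrt(z.var(axis=0, ddof=1) + eps)     # std of each dimension over the batch
        return np.mean(np.maximum(0.0, gamma - std))   # hinge: penalize only std < gamma

    def covariance_term(z):
        zc = z - z.mean(axis=0)
        cov = zc.T @ zc / (n - 1)
        return np.sum(off_diagonal(cov) ** 2) / d      # penalize correlated dimensions

    invariance = np.mean((z1 - z2) ** 2)               # MSE between the two views
    variance = variance_term(z1) + variance_term(z2)
    covariance = covariance_term(z1) + covariance_term(z2)
    return invariance, variance, covariance


rng = np.random.default_rng(0)
batch, dim = 256, 32
collapsed = np.tile(rng.normal(size=dim), (batch, 1))   # every sample -> same embedding
healthy = rng.normal(size=(batch, dim))
healthy_view = healthy + 0.1 * rng.normal(size=(batch, dim))

print(vicreg_terms(collapsed, collapsed))   # invariance 0, variance ≈ 2 * 0.99
print(vicreg_terms(healthy, healthy_view))  # variance near 0, invariance ≈ 0.01
```

A collapsed batch gets a perfect invariance term but a variance term near 2 (0.99 per branch). That penalty is how the variance term keeps embeddings spread out.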
This is why modern architectures (ResNets, Transformers) carefully combine linear layers with
nonlinearities, normalization, and skip connections – each element fights a different aspect of collapse!
Understanding collapse helps you debug training issues, design better architectures, and tell when
your model is learning meaningful representations rather than settling for trivial solutions.
November 8, 2025