Differentiable Routing for Mixture-of-Experts (MoE)
Context
You are working with an MoE layer that routes each token to k experts (often k ∈ {1, 2}). The current router makes hard, non-differentiable routing decisions (e.g., argmax over logits), so no gradient reaches the router parameters and the routing cannot be trained end-to-end by gradient descent.
Let the router produce logits z ∈ R^E for E experts per token. Hard routing uses g = one_hot(argmax(z)) (or top-k), and the layer output is y = Σ_j g_j · Expert_j(x).
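The snippet below is a minimal sketch of this hard-routing layer (PyTorch assumed; the class name, layer sizes, and the use of linear experts are illustrative, not part of the task). It shows where the gradient stops: the top-k/one-hot gate is a piecewise-constant function of z, so the router weights receive no gradient.

```python
import torch
import torch.nn as nn

class HardRouterMoE(nn.Module):
    """Hard top-k routing as described above; names and sizes are illustrative."""

    def __init__(self, d_model: int, num_experts: int, k: int = 1):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts)  # produces logits z in R^E per token
        self.experts = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_experts)]
        )
        self.k = k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [tokens, d_model]
        z = self.router(x)                               # [tokens, E]
        topk = torch.topk(z, self.k, dim=-1).indices     # hard selection: gradient stops here
        g = torch.zeros_like(z).scatter_(-1, topk, 1.0)  # g = one_hot(argmax(z)) when k = 1
        # y = sum_j g_j * Expert_j(x). Because g is piecewise constant in z,
        # dy/dz = 0 almost everywhere, so the router never learns.
        # (This dense loop runs every expert on every token for clarity; real MoE
        # implementations dispatch only the selected tokens to each expert.)
        y = torch.zeros_like(x)
        for j, expert in enumerate(self.experts):
            y = y + g[:, j:j + 1] * expert(x)
        return y
```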
Task
Propose modifications that make the routing decision trainable end-to-end via gradient descent. Compare at least two of the approaches below (you may cover more); a minimal sketch of two of them follows the list:
- Straight-Through Estimator (STE)
- Gumbel-Softmax with temperature annealing
- Very steep sigmoid/softmax relaxation (possibly sparsemax/entmax or a soft top-k)
- REINFORCE (policy gradient)
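As a hedged starting point only (PyTorch assumed; function names are hypothetical and top-1 routing is assumed), here is a sketch of the gate computation for the first two options. Both keep a hard gate in the forward pass while giving the router logits a usable gradient in the backward pass.

```python
import torch
import torch.nn.functional as F

def ste_top1_gates(z: torch.Tensor) -> torch.Tensor:
    """Straight-through estimator: forward uses one_hot(argmax(z)),
    backward uses the gradient of softmax(z) instead."""
    soft = F.softmax(z, dim=-1)
    hard = F.one_hot(soft.argmax(dim=-1), num_classes=z.shape[-1]).to(z.dtype)
    return hard + soft - soft.detach()   # value == hard, gradient == d soft / d z

def gumbel_softmax_gates(z: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Gumbel-Softmax sample of the gate; tau is typically annealed toward a small value.
    hard=True applies the straight-through trick on top of the relaxed sample."""
    return F.gumbel_softmax(z, tau=tau, hard=True)
```

At inference both variants typically fall back to plain top-k on z, which is why output fidelity versus hard routing (first criterion below) needs to be checked explicitly.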
Your comparison should address:
- Output fidelity vs. hard routing at inference
- Training stability and variance
- Computational cost (experts executed per token)
- Implementation details (including any required reparameterization, sampling, or gradient tricks)
Additionally, specify:
- Any auxiliary losses (e.g., load balancing) and their formulas (one common example is sketched below)
- Temperature schedules and exploration strategies
- Regularization and strategies to avoid expert collapse
- Practical defaults (k, capacity factor, loss weights) and any pseudocode if helpful
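As one concrete instance of the auxiliary-loss bullet above, the sketch below implements the widely used Switch-Transformer-style load-balancing loss, L_aux = α · E · Σ_i f_i · P_i, where f_i is the fraction of tokens dispatched to expert i and P_i is the mean router probability assigned to expert i. The weight α ≈ 0.01 and the top-1 dispatch are assumptions for illustration, not requirements of the task.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(z: torch.Tensor, alpha: float = 1e-2) -> torch.Tensor:
    """L_aux = alpha * E * sum_i f_i * P_i for one batch of router logits z: [tokens, E]."""
    num_experts = z.shape[-1]
    probs = F.softmax(z, dim=-1)                                    # router probabilities
    top1 = probs.argmax(dim=-1)                                     # dispatched expert per token
    f = F.one_hot(top1, num_classes=num_experts).float().mean(0)    # fraction of tokens per expert
    P = probs.mean(dim=0)                                           # mean probability per expert
    return alpha * num_experts * torch.sum(f * P)
```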