MoE 公式到源码对照
这一页只做一件事:把 MoE 推理里最核心的四个对象连起来——router 概率、top-k 选择、capacity 限制、All-to-All 通信。读完之后,
moe_routing.py里的每个函数都能对应到一条明确的公式。
这页覆盖哪些源码
../../src/simulators/moe_routing.py:softmax、top-k route、辅助损失、capacity、dispatch、drop rate、通信量。
1. Router 是怎么把 token 送进 expert 的
1.1 先把 logits 变成概率
对某个 token,router 给出 $E$ 个 expert 的 logits,记为 $z \in \mathbb{R}^{E}$。softmax 概率为
为了数值稳定,工程实现通常先减去最大值:
源码对应 softmax():
也就是说,router 的第一步和普通分类器没有本质区别,都是“logits -> 概率”。
1.2 top-k 路由为什么还要重新归一化
MoE 通常不会让每个 token 经过所有 expert,而是只保留概率最大的 $K$ 个 expert:
但是被选中的这 $K$ 个 expert 的原始概率和通常小于 1,所以还要在被选集合内重新归一化:
这一步很关键,因为最终 expert 输出的加权和必须保持在一个归一化尺度上。
源码里的 topk_route() 分三步完成这件事:先 softmax,再找 top-k,再对 top-k 概率做归一化。
这段代码里最容易忽略的点是:np.argpartition() 只保证“最大的 k 个元素被挑出来”,不保证它们有序,所以后面还要用 argsort() 再按概率从大到小排一次。
1.3 expert load 是怎么来的
若第 $t$ 个 token 的第 $k$ 个路由槽位选中了 expert $e$,则该 expert 的路由负载可以写成
源码里直接用 np.bincount() 把 topk_indices 展平后统计频次,这正对应上面的计数公式。
2. 负载均衡损失为什么能约束路由塌缩
2.1 辅助损失的定义
MoE 最怕的情况是:router 总是把 token 往少数几个 expert 上堆。为避免这种塌缩,常见的辅助损失写作
其中
这里的 $f_i$ 是“实际被选频率”,$P_i$ 是“router 赋给该 expert 的平均概率质量”。
2.2 为什么均匀分配时损失等于 1
如果路由完全均匀,则对所有 expert 都有
代回去得到
这也是很多实现里把“越接近 1 越平衡”当作直觉参考的原因。
2.3 对应源码
../../src/simulators/moe_routing.py 里的 load_balancing_loss() 完整对应上述定义:
其中:
actual_freq对应 $f_i$。mean_prob对应 $P_i$。loss对应 $E \sum_i f_i P_i$。
3. Capacity factor 怎么从公式变成 drop rate
3.1 每个 expert 最多能接多少 token
若一个 batch 中有 $N$ 个 token,每个 token 选 $K$ 个 expert,共有 $E$ 个 expert,capacity factor 为 $\text{CF}$,则平均每个 expert 的理论负载是 $N K / E$,实际容量通常写成
对应源码:
所以 capacity factor 本质上是在“理想平均负载”之上再乘一个冗余系数。
3.2 dispatch 为什么会产生 dropped mask
当某个 expert 接收的 token 数超过 capacity 时,后来的 token 就必须被拒绝、丢弃或者 fallback。对于第 $e$ 个 expert,可接受 token 数满足
超过容量的路由位置会落入 dropped 集合。源码里的 dispatch_to_experts() 是逐 token、逐路由槽位做这个检查:
这段代码的工程含义很直接:capacity 不是“建议值”,而是 dispatch 阶段的硬门槛。
3.3 drop rate 的公式
如果一共存在 $N K$ 个路由槽位,其中被丢弃的个数是 $N_{\text{drop}}$,则
源码实现就是取布尔掩码的均值:
也就是说,drop_rate() 默认统计的是“路由槽位丢弃率”,而不是“token 级丢弃率”。如果做 top-2 routing,一个 token 丢一个槽位和丢两个槽位,对这个指标的影响不同。
4. All-to-All 的通信量为什么正比于 token 数和 hidden size
4.1 dispatch + combine 的最小模型
在 Expert Parallel 中,token 的 hidden states 要先发到目标 expert 所在设备,计算完成后再发回原设备。因此通信至少包含两次:
若一共有 $N$ 个 token,每个 token 选择 $K$ 个 expert,模型维度为 $d_{\text{model}}$,每个元素字节数为 $s$,则总通信量近似为
这里前面的 2 就对应发出去再收回来两次。
4.2 对应源码
all_to_all_bytes() 正是这个公式:
这也是为什么 MoE 的系统瓶颈经常不在专家计算本身,而在 All-to-All 的尾延迟和负载倾斜。
5. 把整条 MoE 推理链串起来
从单个 batch 的视角,MoE 推理的最小链路可以写成:
把这条链映射到 ../../src/simulators/moe_routing.py 就是:
softmax():把 logits 变成概率。topk_route():保留 top-k,并计算归一化权重和 expert load。load_balancing_loss():分析是否出现路由塌缩。expert_capacity():给出每个 expert 的最大接收量。dispatch_to_experts():真正把 token 分给 expert,并记录 dropped mask。drop_rate():衡量容量限制对质量的影响。all_to_all_bytes():估算系统级通信量。
6. 建议的源码阅读顺序
先读 ../../math_dictionary/moe-routing-math.md,把路由、辅助损失、通信量的大图建立起来。
再读 ../../src/simulators/moe_routing.py 的
softmax()和topk_route(),先看路由结果怎么生成。接着读
load_balancing_loss()和expert_capacity(),把“均衡”和“容量”这两个约束接进来。最后读
dispatch_to_experts()、drop_rate()、all_to_all_bytes(),从算法视角走到系统视角。
这一页记住一句话
MoE 的核心不是“多几个 expert”这么简单,而是把每个 token 的去向、每个 expert 的容量、每次 All-to-All 的通信量都写成可计算的账本;router 决定算哪里,capacity 决定能不能接住,通信决定系统会不会被拖慢。
最后更新于