公式到源码对照
这页专门用来解决“公式能背下来,但落到代码就断层”的问题。阅读顺序固定为:公式 -> 张量形状 -> 核心变量 -> 仓库源码。
这页覆盖哪些源码
../../src/attention/mha_gqa.py:缩放点积注意力、分头、GQA 共享 KV。
../../src/attention/rope_rmsnorm.py:RoPE cache、旋转、RMSNorm。
../../src/attention/flash_attn_sim.py:分块 attention 和在线 Softmax。
1. 缩放点积注意力对应 mha_gqa.py
mha_gqa.py1.1 线性投影
输入张量写成:
线性投影为:
其中:
$Q \in \mathbb{R}^{B \times T \times D}$
$K, V \in \mathbb{R}^{B \times T \times D_{kv}}$
MHA 时 $D_{kv} = D$;GQA 时 $D_{kv} = H_{kv} \cdot d_h$
对应源码:
这三行正对应 mha_gqa_forward() 里的投影部分。这里先在最后一维完成线性映射,暂时还没有拆成多个 head。
1.2 从 [B, T, D] 拆成 [B, H, T, d_h]
[B, T, D] 拆成 [B, H, T, d_h]定义:
mha_gqa.py 里的 _split_heads() 做的是一次 reshape + transpose:
它对应的数学动作是:
1.3 缩放点积和数值稳定 Softmax
单个 head 的注意力公式:
仓库实现:
这里有三个关键点:
np.swapaxes(k, -1, -2)对应 $K^\top$/ np.sqrt(head_dim)对应缩放因子 $1 / \sqrt{d_h}$softmax()里先减去x_max,对应安全 Softmax 的数值稳定写法
安全 Softmax 的公式是:
1.4 GQA 如何共享 KV
定义:
GQA 的核心不是减少 Query 头,而是让每个 KV 头服务 $G$ 个 Query 头:
对应源码:
这里的 np.repeat(..., axis=1) 是把 KV 头在“head 维”上逻辑展开到与 Query 头对齐。最重要的工程意义是:
比如 $H_q = 32, H_{kv} = 8$ 时,KV Cache 直接缩小为原来的 $1/4$。
1.5 输出合并
attention 结果还是 [B, H, T, d_h],最终要回到模型维度:
对应源码:
这一步把所有 head 拼接回去,再乘输出投影 $W_O$。
2. RoPE 与 RMSNorm 对应 rope_rmsnorm.py
rope_rmsnorm.py2.1 RMSNorm
RMSNorm 先算均方根:
再做缩放:
对应源码:
这和公式几乎一一对应:
x * x对应平方np.mean(..., axis=-1)对应对最后一维求均值np.sqrt(... + eps)对应均方根* weight对应可学习缩放参数 $w$
2.2 RoPE 的频率缓存
RoPE 先构造每个二维平面的旋转角频率。常见写法是:
每个位置 $p$ 的相位为:
对应源码:
这里的 np.outer(pos, inv_freq) 就是一次性构造所有位置、所有频率的相位表 freqs。
2.3 RoPE 为什么需要 _rotate_half
_rotate_half对每一对偶数 / 奇数维度,RoPE 做的是二维旋转:
_rotate_half() 做的就是把
从而能写成向量化形式:
对应源码:
3. FlashAttention 对应 flash_attn_sim.py
flash_attn_sim.py3.1 标准注意力为什么会有大中间矩阵
标准 attention 需要显式构造:
当 $T$ 很大时,$S$ 和 $P$ 都会变成巨大的中间矩阵,带来显著的 HBM 读写压力。
3.2 分块后的三个状态量
FlashAttention 不保存整个 $S$ 和 $P$,而是对每个 Query block 维护三个量:
在扫描到当前块 $(i, j)$ 时,先得到局部统计:
再做在线更新:
对应源码:
这几行就是整套在线 Softmax 的核心,和论文公式严格对应。
3.3 为什么它和标准 Softmax 等价
关键原因是:不同块虽然各自减去了不同的局部最大值,但在合并时又通过 $\alpha$ 和 $\beta$ 把它们重新 rescale 到共同基准 $m_i^{\text{new}}$ 下,所以最终结果和“先看完整一行再做 Softmax”完全一致。
换句话说,FlashAttention 改变的是计算顺序和数据流,不改变数学定义:
3.4 代码里每个循环分别在做什么
外层循环:固定一个 Query block,维护这一小块输出的在线统计量。
内层循环:顺序扫描所有 Key / Value block,把每个块的局部结果并进来。
scores = (q_blk @ k_blk.T) * scale:对应局部块的缩放点积。out[i:i_end] = o_i:扫描完整个 KV 轴后,写回当前 Query block 的最终输出。
4. 推荐对照顺序
最后回看 mha-vs-gqa-full-derivation.md 和 mha-vs-mla-full-derivation.md,把工程结论串起来。
最后更新于