為什么BF16的FlashAttention會把訓練「炸掉」?

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片

為什么BF16的FlashAttention會把訓練「炸掉」?

文章圖片


一句話總結:社區里困擾了多年的一個 “玄學” 現象終于被拆解清楚了:在 BF16 等低精度訓練里 , FlashAttention 不是隨機出 bug , 而是會在特定條件下觸發有方向的數值偏置 , 借助注意力中涌現的相似低秩更新方向被持續放大 , 最終把權重譜范數和激活推到失控 , 導致 loss 突然爆炸 。 論文還給出一個幾乎不改模型、只在 safe softmax 里做的極小修改 , 實測能顯著穩定訓練 。

因果鏈總覽(論文 Figure 1)

  • 標題:Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
  • 作者:邱海權 , 姚權銘
  • 機構:清華大學 電子工程系
  • 投稿:ICLR 2026 Oral
  • 關鍵詞:低精度訓練 , BF16 , FlashAttention , 數值穩定性 , 舍入誤差(rounding error) , 低秩表示(low-rank)
  • 論文鏈接:https://arxiv.org/abs/2510.04212
  • 代碼鏈接:https://github.com/ucker/why-low-precision-training-fails
背景:低精度訓練越來越 “剛需” , 但注意力比你想的更敏感
大模型訓練的現實是:顯存和吞吐決定一切 。 工業界普遍在混合精度里使用 BF16/FP16 , 甚至把 FFN 推到 FP8 , 以換取更高的訓練效率 。 但工程實踐同樣殘酷:越接近 “極限精度” , 訓練越容易出現難以解釋的不穩定 。
Flash Attention 是長上下文訓練的關鍵加速組件 , 幾乎成了標配 。 問題在于 , 社區長期存在一個可復現卻難以解釋的失敗案例:
  • 用 FlashAttention + BF16 訓練 GPT-2 , 一開始正常收斂 , 但在幾千 step 之后突然 loss 爆炸 。
  • 你可以通過回退到標準注意力、或把關鍵計算提高到 FP32 來 “救火” , 但代價是吞吐和顯存優勢沒了 。
這類問題被報告了多年(相關 issue 在多個開源項目里反復出現) , 卻一直缺少一條能 “從數值誤差一路解釋到 loss 爆炸” 的機制鏈 。

作者的做法很工程 , 且足夠 “可復現”:


機制解釋 1:相似低秩結構 , 讓誤差變成 “持續推力” 而不是噪聲



結果就是:權重更新被 “帶偏” , 譜范數和激活異常增長 , 最終把訓練推到 loss 爆炸 。


低秩結構相似性與偏置累積(論文 Figure 4/5)
【為什么BF16的FlashAttention會把訓練「炸掉」?】機制解釋 2:偏置從哪來?safe softmax + BF16 舍入誤差里藏著一個 “離散觸發器”

作者把問題追到了 FlashAttention 前向里的未歸一化輸出:









  • 檢測一行 S 中最大值是否出現多次
  • 一旦出現 “重復最大值” , 就動態調整 safe softmax 的行移位常數 m , 讓最大位置的指數也變成嚴格小于 1
論文給出的實現(概念上)如下:


實驗結果:穩定訓練不再 “突然炸”
論文在 BF16 設置下驗證了上述分析與修復:
  • GPT-2S:使用修改后的 FlashAttention , 在 AdamW 與 Muon 兩種優化器下 , 都能穩定訓練到 600K steps
  • GPT-2M:同樣能在 AdamW 下穩定訓練(論文展示到 100K steps)
  • 論文還提到該現象與結論在多種硬件上保持一致(包括 A100、RTX 4090、Ascend 910B)

驗證集 loss 曲線對比(論文 Figure 7)
更重要的啟示:別把低精度誤差當成 “零均值噪聲”
這篇論文的價值不只在 “修了一個 bug” , 更在于給出了一個可遷移的診斷范式:
  • 數值誤差未必是隨機噪聲 。 在特定分布與離散事件(如重復最大值、概率精確為 1)下 , 舍入誤差可能形成系統性偏置 。
  • 模型結構會放大偏置 。 注意力里涌現的相似低秩更新方向 , 讓偏置誤差更容易 “同向疊加” 。
  • 經驗修復為什么有效也能被解釋:論文討論了 attention sinks 與多最大值的關系 , 并給出了一個數值層面的連接;同時也指出一些穩定化技巧(如 QK normalization、Gated Attention)可能通過 “打散結構相似性” 來阻止誤差同向累積 。
作者介紹
邱海權是清華大學在讀博士研究生 , 研究方向涵蓋機器學習理論、表示學習與大模型機制分析 。 他的研究圍繞模型表達能力、結構歸納偏置以及參數空間幾何與優化動力學之間的內在聯系展開 , 關注模型在不同結構約束與訓練條件下的泛化行為與可組合性問題 。 整體上 , 他強調以可分析的理論框架刻畫模型的能力邊界與機制來源 , 從結構與原理層面理解深度模型為何有效、何時失效 。
姚權銘 , 清華大學電子工程系副教授 。 長期致力于數據高效學習與智能體系統研究 , 在少樣本學習、圖學習、知識圖譜與生物醫藥智能等方向取得系統性成果 。 發表 Nature 子刊、TPAMI、JMLR、ICML、NeurIPS、ICLR 等論文 130 余篇 , 被引 1.4 萬余次 。 代表性工作包括抗噪學習算法 Co-teaching、小樣本學習綜述、自動化圖學習方法及新藥物相互作用預測模型 。 現任 TPAMI、TMLR 編委及 Neural Networks 資深編委 , 多次擔任 ICML、NeurIPS、ICLR 領域主席 , 入選 IEEE Computing Top 30、IET Fellow 等 。

    推薦閱讀