FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

文章圖片

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

文章圖片

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

文章圖片

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

文章圖片

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

文章圖片

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

文章圖片

FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度

機器之心編輯部
經過一年的努力 , FlashAttention-4 終于正式上線了 。
近日 , 深度學習領域重要底層優化技術 FlashAttention 迎來大版本更新 。
FlashAttention 核心作者、普林斯頓大學助理教授 Tri Dao 表示 , 在 Blackwell GPU 上 , 即使瓶頸截然不同 , 注意力機制的執行速度現在也幾乎與矩陣乘法一樣快了!

當前 , Tensor Core 的速度現在非常快 , 以至于注意力前向傳播的瓶頸呈指數級增長 , 而注意力后向傳播的瓶頸是共享內存帶寬 。
重新設計的算法中包含一些旨在克服這些瓶頸的機制 , 包括使用多項式進行指數模擬 , 新的在線 softmax 可以避免 90% 的 softmax 重新縮放 , 2CTA MMA 指令允許兩個線程塊共享操作數以減少 smem 流量等 。

論文地址:https://github.com/Dao-AILab/flash-attention/blob/main/assets/fa4_paper.pdf 代碼鏈接:https://github.com/Dao-AILab/flash-attention接下來 , 就來詳細了解一下 。
硬件趨勢:不對稱的硬件擴展
長期以來 , Attention 作為無處不在的 Transformer 架構中的核心層 , 一直是大語言模型和長上下文應用的性能瓶頸 。
此前 FlashAttention-3 通過異步執行和 warp 專門化對 Attention 進行了優化 , 但其主要針對的是 Hopper GPU(H100)架構 。
然而 , AI 行業已經迅速轉向部署 Blackwell 架構系統 , 例如 B200 和 GB200 。 而像 Blackwell GPU 這樣的現代加速器延續了一種趨勢:硬件的非對稱擴展(asymmetric hardware scaling) 。
在這種趨勢下 , 張量核心(Tensor Core)的吞吐量增長速度遠快于其他硬件資源 , 像是共享內存帶寬、用于指數運算等超越函數運算的特殊函數單元(SFU) , 以及通用整數與浮點 ALU……
舉個例子 , 從 Hopper H100 到 Blackwell B200 , BF16 張量核心吞吐量增加了 2.25 倍(從 1 到 2.25PFLOPs) , 但 SFU 數量和共享內存帶寬基本保持不變 。
這種擴展不對稱性對像 Attention 這樣的復雜 kernel 優化產生了深遠影響 。
具體來看 , Attention 的核心包含兩個通用矩陣乘法(GEMM):


中間夾著 softmax , 但在真實實踐中 , Attention 還涉及大量輔助工作 , 比如數據搬運、同步、數據布局轉換、元素級運算、調度、mask 處理等 。
傳統的觀點認為 , Attention 的性能完全由 GEMM 的速度決定 。 然而 , 對 B200 進行「速度與饋送」分析顯示:主要的瓶頸不在于張量核心 , 而是:
前向傳播中用于 Softmax 指數運算的 SFU 單元; 反向傳播中的共享內存流量 , 受 shared memory bandwidth 限制 。為此 , 團隊推出 FlashAttention-4 , 一種算法 + kernel 的協同設計 , 核心目標在于 , 通過最大化矩陣乘法與其他瓶頸資源之間的重疊 , 在 B200(BF16)上 , 最高可達 1605TFLOPs/s(71% 的利用率) , 比 cuDNN 9.13 快 1.3 倍 , 比 Triton 快 2.7 倍 。
協同設計的核心思路如下:
新型流水線:為前向和反向傳播分別設計了新的軟件流水線 , 利用 Blackwell 的全異步 MMA 和更大分塊(Tile)尺寸 , 最大化 Tensor Core 計算、softmax 計算以及內存操作之間的重疊執行; 前向傳播 (FWD):在 FMA 單元上通過多項式近似實現指數函數的軟件仿真 , 以提升指數計算吞吐量;同時引入條件式 softmax 重縮放(conditional softmax rescaling) , 跳過不必要的重縮放操作 , 從而緩解 SFU 瓶頸; 反向傳播 (BWD):利用張量內存 (TMEM) 存儲中間結果 , 以緩解共享內存流量壓力;同時 , 結合 Blackwell 新增的 2-CTA MMA 模式 , 進一步降低共享內存訪問 , 并將 atomic reduction 次數減少一半;此外 , 還支持確定性執行模式 , 以實現可復現訓練; 調度優化:引入新的 tile 調度器 , 解決因果掩碼和變長序列導致的負載不均衡 。Blackwell 的新硬件特性
張量內存(TMEM):在 B200 上 , 148 個 SM(流式多處理器)中的每一個都配備了 256 KB 的 TMEM , 與 Tensor Core 直接連接 , 用于 warp 同步的中間結果存儲 。
完全異步的第五代張量核心:指令 tcgen05.mma 支持異步執行 , 并將累加結果存儲在 TMEM 中 。 對于 BF16 和 FP16 , 單個 CTA 可使用的最大 UMMA tile 為 128×256×16 , 約為 Hopper 架構中最大 WGMMA 原子塊的 2 倍 。 UMMA 由單個線程發起 , 從而減輕寄存器壓力 , 使得在不出現 Hopper warpgroup MMA 那種寄存器溢出問題的情況下 , 可以更容易地使用更大的 tile 和更深的流水線 。
此外 , 這也使 warp 專門化更具可行性:部分 warp 負責搬運 tile , 另一些 warp 負責發起 MMA , 從而實現矩陣乘加運算與 softmax 計算以及內存訪問的重疊執行 。 tcgen05.mma 還可以直接從 TMEM 中讀取操作數 A 。
2-CTA MMA:Blackwell 支持在同一 cluster 中由一對 CTA 共同執行一個 UMMA 運算 , 并跨越兩個 CTA 的 TMEM 。 由 leader CTA 中的一個線程發起 MMA , 但在執行期間兩個 CTA 都必須保持活躍 。 通過在這對 CTA 之間拆分 M 和 N 維度 , 可以將 MMA 的 tile 尺寸擴展到 256×256×16 , 從而減少冗余數據傳輸并降低每個 CTA 的資源占用 。 在一個 kernel 中 , CTA 組大?。 ? 或 2)在 TMEM 操作和 Tensor Core 運算之間必須保持一致 。

編程語言與框架:CuTe-DSL
FlashAttention-4(FA4)完全使用 CuTe-DSL 實現 , 這是 CUTLASS 提供的 Python kernel DSL 。
Kernel 代碼使用 Python 編寫 , 隨后 DSL 會將其降級(lower 為 PTX , 再由 CUDA 工具鏈編譯為 GPU 機器代碼 。
該編程模型在抽象層面與 CuTe / CUTLASS 保持一致 , 同時提供 PTX 級別的 escape hatch(底層控制接口) 。 與使用 C++ 模板相比 , 這種方式可以將編譯時間縮短約 20–30 倍 。
對此 , Tri Dao 更是在 X 上發帖稱感到「莫名興奮」 , 這意味著 , 安裝 /「編譯」現在只需幾秒鐘 , 而不是幾分鐘 / 幾小時 。

Attention 性能基準測試
團隊展示了 FlashAttention-4 在 B200(BF16)上的性能結果 , 并將其與 FlashAttention-2 以及 Triton、Gluon 和 cuDNN 的實現進行了對比 。
結果顯示:
前向傳播(forward pass):FlashAttention-4 比 cuDNN 9.13 快 1.1–1.3 倍 , 比 Triton 實現快 2.1–2.7 倍 。 反向傳播(backward pass):在長序列長度場景下 , FlashAttention-4 的表現始終優于其他基準模型 。
【FlashAttention-4正式發布:算法流水線大改,矩陣乘法級速度】


而 FlashAttention-4 一經發布 , 也引起了大家的熱議 。
Pytorch 官方宣布 FlexAttention 現已支持 FlashAttention-4 后端 。

Pytorch 表示 , 很長一段時間以來 , FlexAttention 讓研究人員能夠快速原型化各種自定義 Attention 變體 , 目前已有 1000 多個代碼倉庫采用 , 并有數十篇論文對其進行了引用 。
然而 , 用戶常常會遇到性能瓶頸 , 直到 FlashAttention-4 的出現 。
如今 , 他們已在 Hopper 和 Blackwell GPU 上為 FlexAttention 增加了 FlashAttention-4 后端 。 PyTorch 現在可以自動生成 CuTeDSL 的 score/mask 修改代碼 , 并通過 JIT 編譯為自定義 Attention 變體實例化 FlashAttention-4 。
結果顯示 , 在算力受限的工作負載下 , 相比 Triton , 仍可實現 1.2 倍到 3.2 倍的性能提升 。 研究人員再也不必在「靈活性」和「高性能」之間做單選題 。
一位網友則認為 , 「FlashAttention-4 是一個里程碑 。 」在 Blackwell 架構上 , Attention 已經能夠達到接近矩陣乘法(matmul)速度 , 這意味著計算瓶頸將完全轉移到內存與通信上 。 約 1600TFLOPs 的 Attention 性能堪稱驚人 —— 相比 FlashAttention-3 提升了 2–3 倍 。 「這將直接惠及所有前沿大模型 。 」因為 , 更快的 Attention 意味著更長的有效上下文窗口、更低的推理成本、更強的規?;评砟芰Α?br />
更多內容 , 可查看論文原文獲?。 ?
參考鏈接:
https://x.com/tri_dao/status/2029569881151263082
https://tridao.me/blog/2026/flash4/

    推薦閱讀