跳转到内容

高效注意力:突破序列长度平方瓶颈

直觉版:不要让数据在内存里来回搬

Section titled “直觉版:不要让数据在内存里来回搬”

标准注意力的计算量随序列长度平方增长,更关键的是,它需要从 GPU 显存频繁读写巨大的注意力矩阵。FlashAttention 的直觉是:把计算拆成小块,在 GPU 的高速缓存(SRAM)里完成,减少慢速显存(HBM)的读写次数。这样既不近似也不损失精度,却能显著加速。

工程版:IO-aware 优化与内核融合

Section titled “工程版:IO-aware 优化与内核融合”

FlashAttention 的核心贡献是 IO-aware 算法:通过分块(tiling)和重计算(recomputation),将注意力计算的内存访问从 O(N²) 的 HBM 流量降低到接近 O(N)。FlashAttention-2 进一步优化了线程块划分和 warp 级调度;FlashAttention-3 则针对 Hopper 架构的异步执行和 FP8 做了专项优化。

除了 FlashAttention,长上下文工程还包括:

  • 稀疏注意力:滑动窗口、膨胀注意力、局部-全局混合,用近似降低计算量。
  • 线性注意力:通过核技巧或状态空间模型(SSM)将复杂度降到线性,代表工作如 Mamba。
  • 上下文压缩:把长文本压缩成更短的表示,减少需要参与注意力的 token 数。

选择方案时要评估:是否支持任意因果 mask?是否兼容现有训练框架?对短序列有没有额外开销?以及,在真实长文本任务上的端到端增益。

研究版:注意力是否必须平方?

Section titled “研究版:注意力是否必须平方?”

研究上,一个根本问题是:Transformer 的二次复杂度是否是必要的?状态空间模型、RWKV、RetNet 等工作试图在保留长程依赖能力的同时实现线性复杂度。

但注意力本身也有独特优势:动态路由、可解释性强、对训练数据分布的依赖相对较小。未来的架构可能是混合的:局部用线性方法,全局保留标准注意力,或通过 learned routing 动态选择计算模式。

本文引用论文

  • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — Tri Dao et al. (2022)

    FlashAttention 通过 IO-aware 的分块计算,在不牺牲精度的前提下,将注意力计算的内存 从 O(N²) 降至 O(N),速度提升 2-4 倍。它改变了长上下文训练的可行性边界, 是现代高效 LLM 训练和推理不可或缺的底层优化。

  • dao2023-flashattention2

    用更激进的 warp 级并行和 work partition 把 FlashAttention 再翻倍。今天 vLLM / SGLang / Megatron 训练后端基本都升级到 FA-2。

  • shah2024-flashattention3

    利用 H100 的异步 TMA 与 FP8,把 attention 推到 1.2 PFLOPs,并保持数值精度。是 Hopper 架构上长上下文 + FP8 训练的关键依赖。