searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

FlashAttention

2024-06-26 09:44:45
2
0

背景

Transformer的复杂度

在标准的Transformer计算中,给定大小为(N, d)的三个矩阵QKVSelf-Attention的计算如下:

复杂度就为O(dN^2)。

FLOPS & MAC

        FLOPS直接定义了模型核心计算的密集程度,所以模型的计算量FLOPS与模型的计算速度有很大关系。学界有很多利用各种技巧来降低Transformer FLOPS的方法,通常将由这些方法改进得到的模型称为Efficient Transformer。但大多数Efficient Transformer通常只关注FLOPS。而实际上,FlashAttention的作者们发现,这些Efficient Transformer虽然能够有效降低模型的 FLOPS,但它们的计算速度并没有显著降低。 导致该现象的根本原因是模型的计算速度除了与FLOPS有很大关系,同时也与MACMemory Access Cost,存储访问开销)有关。尤其是当计算本身已经很高效的情况下,MAC的开销更加不能忽略。MAC的开销主要来自两方面。一是从存储中读取数据;二是向存储中写数据。与CPU的情况类似,在GPU中,当需要计算时,需将数据从显存中读取并由计算单元进行计算操作。在计算完毕后,再写回到显存中。

        为了弄清MAC对计算速度的影响,可以根据计算的密集程度,将operator 分为两类: Compute-bound :计算密集型。整个计算的耗时主要在于计算本身,对显存的读写几乎可以忽略。典型的计算包括大矩阵乘法、大channel size的卷积操作等。对于这类operator,它们的 FLOPS决定了计算的时耗。 Memory-bound :存储访问密集型。整个计算的耗时主要集中在存储的访问上,计算本身耗时较低。典型的计算包括逐元素操作(ReLUDropout等)、以及Reduce操作 (求和、softmax BatchNorm 等)。对于这类operator,他们的MAC决定了计算的耗时。

在绝大多数的神经网络中,因为含有大量的Memory-bound操作,所以MAC的开销都不能忽略。但绝大多数Efficient Transformer都忽略了MAC,所以虽然它们整体的FLOPS都降低了,但计算耗时并没有降低。

核心思想

       FlashAttention的目标是降低MAC,即使代价是增加了FLOPS。要理解FlashAttention的优化MAC的方法,需要简单了解一下GPU的结构。

上图所示为A100的存储结构示意图。其中GPU的存储主要由两部分构成:HBM High Bandwidth Memory)和SRAM Static Random-Access Memory)。从上图中可看出,SRAM的读写速度远大于HBM,但其存储空间则远小于HBM。标准Transformer的计算可以抽象为如下过程,它主要使用了HBM

中一共包含八次HBM的矩阵读写操作。这八次读写操作分别为:

第一行对Q,K的读,共两次,对S的写一次,总共三次;

第二行对S读一次,对P写一次,总共两次;

第一行对P,V的读,共两次,对O的写一次,总共三次。

为了减少对HBM的读写,FlashAttention将参与计算的矩阵进行分块送进SRAM,来提高整体读写

速度(减少了HBM读写)。

对于矩阵乘法而言,可以直接通过分块来达到分块计算的目的。但Self-Attention中有 softmax计算,而softmax的分母包含与所有元素相关的求和项,所以对Self-Attention进行分块计算的真正难点在于对softmax的分块计算。

计算流程

下图为FlashAttention伪代码:

MAC分析

标准Transformer MAC分析

根据参与计算的各矩阵大小,可以分析它的MAC次数(以访问单个float值为基准)

Ø  Line1:读QKMAC次数为2Nd,写SMAC次数为N2

Ø  Line2:读SMAC次数为N2,写PMAC次数为N2

Ø  Line3:读PMAC次数为N2,读VMAC次数为Nd,写OMAC次数为Nd

上述所有加起来的总MAC开销为4Nd+4N2。忽略掉其中的常数项,可将复杂度写为O(Nd+N2)

FA MAC分析

主要关注图4Q的开销(Q在内循环,它的MAC占绝大多数)。

Ø  一次完整的内循环需要读取完整的QMAC开销为Nd

Ø       外循环决定了内循环执行的次数,即Tc。忽略掉常数项,可知FlashAttention的开销为O(N2d2M-1)

因为M(100kB)通常远远大于ddhead dimension,通常是64128),所以FlashAttentionMAC远小于标准的Transformer

 

0条评论
0 / 1000
CY
4文章数
0粉丝数
CY
4 文章 | 0 粉丝