背景
Transformer的复杂度
在标准的Transformer计算中,给定大小为(N, d)的三个矩阵Q、K、V,Self-Attention的计算如下:
复杂度就为O(dN^2)。
FLOPS & MAC
FLOPS直接定义了模型核心计算的密集程度,所以模型的计算量FLOPS与模型的计算速度有很大关系。学界有很多利用各种技巧来降低Transformer FLOPS的方法,通常将由这些方法改进得到的模型称为Efficient Transformer。但大多数Efficient Transformer通常只关注FLOPS。而实际上,FlashAttention的作者们发现,这些Efficient Transformer虽然能够有效降低模型的 FLOPS,但它们的计算速度并没有显著降低。 导致该现象的根本原因是模型的计算速度除了与FLOPS有很大关系,同时也与MAC(Memory Access Cost,存储访问开销)有关。尤其是当计算本身已经很高效的情况下,MAC的开销更加不能忽略。MAC的开销主要来自两方面。一是从存储中读取数据;二是向存储中写数据。与CPU的情况类似,在GPU中,当需要计算时,需将数据从显存中读取并由计算单元进行计算操作。在计算完毕后,再写回到显存中。
为了弄清MAC对计算速度的影响,可以根据计算的密集程度,将operator 分为两类: Compute-bound :计算密集型。整个计算的耗时主要在于计算本身,对显存的读写几乎可以忽略。典型的计算包括大矩阵乘法、大channel size的卷积操作等。对于这类operator,它们的 FLOPS决定了计算的时耗。 Memory-bound :存储访问密集型。整个计算的耗时主要集中在存储的访问上,计算本身耗时较低。典型的计算包括逐元素操作(ReLU,Dropout等)、以及Reduce操作 (求和、softmax、 BatchNorm 等)。对于这类operator,他们的MAC决定了计算的耗时。
在绝大多数的神经网络中,因为含有大量的Memory-bound操作,所以MAC的开销都不能忽略。但绝大多数Efficient Transformer都忽略了MAC,所以虽然它们整体的FLOPS都降低了,但计算耗时并没有降低。
核心思想
FlashAttention的目标是降低MAC,即使代价是增加了FLOPS。要理解FlashAttention的优化MAC的方法,需要简单了解一下GPU的结构。
上图所示为A100的存储结构示意图。其中GPU的存储主要由两部分构成:HBM (High Bandwidth Memory)和SRAM (Static Random-Access Memory)。从上图中可看出,SRAM的读写速度远大于HBM,但其存储空间则远小于HBM。标准Transformer的计算可以抽象为如下过程,它主要使用了HBM:
图中一共包含八次HBM的矩阵读写操作。这八次读写操作分别为:
• 第一行对Q,K的读,共两次,对S的写一次,总共三次;
• 第二行对S读一次,对P写一次,总共两次;
• 第一行对P,V的读,共两次,对O的写一次,总共三次。
为了减少对HBM的读写,FlashAttention将参与计算的矩阵进行分块送进SRAM,来提高整体读写
速度(减少了HBM读写)。
对于矩阵乘法而言,可以直接通过分块来达到分块计算的目的。但Self-Attention中有 softmax计算,而softmax的分母包含与所有元素相关的求和项,所以对Self-Attention进行分块计算的真正难点在于对softmax的分块计算。
计算流程
下图为FlashAttention伪代码:
MAC分析
标准Transformer MAC分析
根据参与计算的各矩阵大小,可以分析它的MAC次数(以访问单个float值为基准)。
Ø Line1:读Q、K的MAC次数为2Nd,写S的MAC次数为N2。
Ø Line2:读S的MAC次数为N2,写P的MAC次数为N2。
Ø Line3:读P的MAC次数为N2,读V的MAC次数为Nd,写O的MAC次数为Nd。
上述所有加起来的总MAC开销为4Nd+4N2。忽略掉其中的常数项,可将复杂度写为O(Nd+N2)。
FA MAC分析
主要关注图4中Q的开销(Q在内循环,它的MAC占绝大多数)。
Ø 一次完整的内循环需要读取完整的Q,MAC开销为Nd。
Ø 外循环决定了内循环执行的次数,即Tc。忽略掉常数项,可知FlashAttention的开销为O(N2d2M-1)。
因为M(100kB)通常远远大于d(d是head dimension,通常是64或128),所以FlashAttention的MAC远小于标准的Transformer。