FlashAttention解读
背景
随着深度学习技术的不断发展,Transformer模型在图像分类、自然语言处理等领域中逐渐占据了主导地位。然而,Transformer模型中的自注意力机制(Self-Attention)在处理长序列时面临计算复杂度和内存使用效率的挑战。传统的自注意力机制的时间复杂度和空间复杂度都与输入序列长度的平方成正比,这限制了模型处理更长序列的能力。因此,如何优化自注意力机制的计算效率和内存使用效率成为了一个重要的研究方向。
Flash Attention是一种旨在加速大模型中注意力计算的技术,它通过优化内存访问和计算流程,显著提高了计算速度和效率。随着技术的不断发展,Flash Attention已经推出了多个版本,并在大模型中得到了广泛应用。
简介
FlashAttention是一种旨在解决传统Transformer模型中自注意力机制计算复杂度和内存使用效率问题的新型注意力算法。它通过优化数据布局和计算流程,减少了内存访问开销,提高了计算效率。FlashAtention主要解决Transformer计算速度慢和存储占用高的问题,但与绝大多数Eficient
Transformer把改进方法集中在降低模型的计算复杂度FLOPS(foating point operations per second)不
同,FlashAttention将优化重点放在了降低存储访问开销(MemoryAccess Cost*,MAC)上。FlashAttention的出现,使得Transformer模型在处理长序列时更加高效,从而推动了深度学习在更多领域的应用。
大语言模型LLM普遍以Transformer作为核心基础组件,而Transformer的核心计算部分是SelfAttention自注意力模块。在标准的Transformer计算中,给定大小为(N,d)的三个矩阵 Q,K,V,标准的Self-Attentiont的计算如下(其中softmax前省略了Scale和Attention Mask处理):
其中 Q/K/V 由原始输入x,经过Linear线性变换得到,S在一些论文中被称为注意力得分Attention Scores, P是对S进行逐行softmax得到的结果,可理解为归一化的注意力得分Normalized Attention Scores,或注意力权重Attention Weights,O是最终的输出,N是序列长度seqlen,d是维度headdim。其中产生两个中间矩阵S和P,内存需求均是O(N^2),如果当seqlen也就是N很大时,就会消耗过多的内存,限制了模型处理更长序列的能力。
因此FlashAttention出现了,它不需要保留中的S和P矩阵,而是整个Attention计算融合到单个计算核心中。我们知道矩阵乘,具有分块和累加的特性,一个大的矩阵乘法,可以通过Tiling技术,分成小块的可以在片上计算的矩阵乘法,然后通过将各个分块矩阵乘的结果进行累加获得最后的正确结果。
遗憾的是Attention中的Softmax计算,并没有这种分块累加特性,它依赖于一个全局的分母项,必须要所有数据计算完成之后才能进行下一步操作(max、sub/exp/sum、div)。
而online softmax算法的出现,可以将前两步(max、sub/exp/sum)放在一个循环中处理,即分块处理后,存在一个刷新流程。
FlashAttention v1
FlashAttention V1的核心思想是将输入的查询(Query)、键(Key)和值(Value)矩阵切分成多个小块(tile),并在不同的计算块之间进行并行处理。通过优化数据布局和计算流程,以减少内存访问开销,从而提高计算效率。具体实现方式包括:
- 数据分块:将输入的查询(Query)、键(Key)和值(Value)矩阵切分成多个小块(tile),每个块的长度较短,以减少单次计算的复杂度。。这种分块处理可以有效利用GPU的有限内存带宽,避免一次性将整个矩阵加载到内存中导致的性能瓶颈。
- 并行计算:对每个块计算其自注意力,利用硬件的并行计算能力加速计算过程。通过并行处理,可以显著加速计算过程,尤其是对于长序列输入。
- 内存访问优化:FlashAttention V1 优化了内存访问模式,减少了不必要的内存读写操作。通过将小块数据加载到高带宽的SRAM中进行计算,并减少对低带宽HBM的访问频率,进一步提高了计算效率。
- 算子融合:在计算自注意力的核心部分(如 QK^T、Softmax 和加权求和)时,采用了算子融合技术,将多个计算步骤合并为一个步骤,以减少内存访问次数和计算复杂度。
FlashAttention V1通过减少内存访问次数和利用并行计算,显著提高了自注意力机制的计算效率。具体的FlashAttention在online softmax的基础上,在前后加入两个矩阵乘法,以及一些其他的向量计算,公式如下。
FlashAttention v1的核心特点:
计算快:通过降低对显存HBM的访问次数来加快整体运算速度,这主要得益于IO-Awareness技术,即分块计算(tiling)和核函数融合(kernel fusion)。
节省内存:通过特定的技术(如tiling分块机制)将forward和backward阶段的存储压力从O(N^2)降至O(N)。
精准注意力:实现了完全等同于标准attention的实现方式,保证了计算的准确性。
Flash Attention V1的出现,为大模型的注意力计算提供了新的解决方案。它已被成功应用于多个知名的大模型中,如GPT-3、Falcon2、Llama2等,显著提高了这些模型的计算速度和效率。
其伪代码如下。
def flash_attention_v1(Q, K, V, num_tiles):
# 假设Q, K, V的shape为[N, d],N为序列长度,d为维度
outputs = []
tile_size = N // num_tiles
for i in range(num_tiles):
# 提取当前tile的Q, K, V
Q_tile = Q[i * tile_size:(i + 1) * tile_size]
K_tile = K[i * tile_size:(i + 1) * tile_size]
V_tile = V[i * tile_size:(i + 1) * tile_size]
# 计算当前tile的自注意力
# 1. 分块矩阵乘法 mm1
S = matmul(Q_tile, K_tile.transpose(-2, -1))
# 2. 向量计算 vector1
S = mul(S , 1.0/torch.sqrt(torch.tensor(d))) # scale
S = select(S, atten_mask) # 掩码
# online softmax 1
P = sum(exp(sub(S, max(S, dim=-1))))
# 3. 分块矩阵乘法 mm2
O = matmul(P, V_tile)
# 4. 向量计算 vector2 online softmax 2 在线刷新
O_pre = mul(O_pre, softmax_exp)
O = O + O_pre
O_pre = O
outputs.append(O)
# 聚合结果
return torch.cat(outputs, dim=0)
FlashAttention v2
FlashAttention V2在V1的基础上进行了进一步优化,主要改进了循环交换和任务划分策略。
- 优化循环顺序
将V1计算逻辑中的内外循环相互交换,以此减少在shared memory上的读写次数,实现进一步提速。在 V1 中,外循环遍历 Query(Q),内循环遍历 Key(K)和 Value(V)。V2将外循环遍历 Key(K)和 Value(V),内循环遍历 Query(Q)。这种循环交换使得 Q 的数据只需从内存读取一次,减少了内存读写开销,并提高了并行度。
- 任务划分并行计算优化
在cuda层面配套做一些并行计算优化,以更好地利用GPU的并行计算能力。V2 优化了线程块和 warp 之间的任务划分,使得计算任务更加均衡地分配给不同的计算单元,从而提高了整体计算效率。
- 减少非矩阵乘法运算
V2 通过优化算法和实现方式,减少了非矩阵乘法运算的 FLOPs(浮点运算次数),进一步降低了计算复杂度。
- 动态分块
V2 可能引入了动态分块技术,根据序列的实际长度和计算资源动态调整块的大小,以达到最佳的性能和效率。
def flash_attention_v2(Q, K, V, num_tiles):
# 假设Q, K, V的shape为[N, d],N为序列长度,d为维度
outputs = []
tile_size = N // num_tiles
for j in range(num_tiles):
# 提取当前tile的K, V
K_tile = K[j * tile_size:(j + 1) * tile_size]
V_tile = V[j * tile_size:(j + 1) * tile_size]
# 对每个Q元素计算注意力
for i in range(N):
q_i = Q[i:i+1] # 提取单个Q元素
# 1. 分块矩阵乘法 mm1
S = matmul(q_i, K_tile.transpose(-2, -1))
# 2. 向量计算 vector1
S = mul(S , 1.0/torch.sqrt(torch.tensor(d))) # scale
S = select(S, atten_mask) # 掩码
# online softmax 1
P = sum(exp(sub(S, max(S, dim=-1))))
# 3. 分块矩阵乘法 mm2
O = matmul(P, V_tile)
# 4. 向量计算 vector2 online softmax 2 在线刷新
O_pre = mul(O_pre, softmax_exp)
O = O + O_pre
O_pre = O
if i == 0:
O_tile = O_i
else:
O_tile = torch.cat((O_tile, O_i), dim=0)
outputs.append(O_tile)
# 聚合结果
return torch.cat(outputs, dim=1) # 注意这里可能需要调整维度以匹配原始输出
FlashAttention v3
FlashAttention通过最小化内存读写来加速GPU上的注意力机制。 然而,它尚未利用最近硬件中的新功能,FlashAttention-2在H100 GPU上的利用率仅为35%。FlashAttention V3在V2的基础上进一步提升了性能,支持了更高效的硬件(如NVIDIA Hopper架构)和更精细的并行计算策略。V3可能引入了更多的优化技术,如更智能的线程管理和任务调度、更紧凑的数据结构等,以进一步提高计算效率和内存使用效率。
Flash Attention3提出了三种主要技术来加速Hopper GPU上的注意力机制:
- 更高效的并行计算,通过warp-specialization重叠整体计算和数据移动
生产者-消费者异步:定义了一种专门针对 warp 的软件流水线方案,通过将数据的生产者和消费者分成不同的 warp,利用数据移动和张量核心的异步执行,从而扩展算法隐藏内存和指令发出延迟的能力。
- 算法优化,交错块状矩阵乘法和softmax操作
将 softmax 中涉及的相对低吞吐量的非 GEMM 操作(如浮点乘加和指数运算)与 GEMM 的异步 WGMMA 指令重叠。在此过程中,我们重新设计了FlashAttention-2算法,以规避 softmax 和 GEMM 之间的某些顺序依赖。例如,在算法的两阶段版本中,当 softmax 在分数矩阵的一个块上执行时,WGMMA 在异步代理中执行以计算下一个块。
- 更紧凑的数据结构,利用硬件支持FP8低精度计算,以减少内存占用并提高内存利用率
调整了前向传递算法,以便针对FP8张量核心进行GEMM,几乎使测量的TFLOPs/s翻倍。这需要在WGMMA的不同布局一致性要求之间架起桥梁,因为FP32累加器和FP8操作数矩阵的内存布局假设不同。使用块量化和非相干处理技术来减轻转向FP8精度所导致的精度损失。最终,FlashAttention-3,在H100 GPU上实现了1.5-2.0×的加速,FP16达到最高740 TFLOPs/s(75%利用率),FP8接近1.2 PFLOPs/s。
利用Tensor Cores和TMA的异步性:通过warp专门化重叠整体计算和数据移动,实现更高效的计算流程。
交错块级matmul和softmax操作:优化计算过程中的数据移动和计算顺序,减少不必要的内存访问。
块量化和不连贯处理:利用FP8低精度的硬件支持,减少计算量同时保持较低的数值误差。
应用情况:
Flash Attention V3在H100 GPU上实现了显著的加速效果,FP16达到740 TFLOPs/s(75%利用率),FP8接近1.2 PFLOPs/s。这一版本的推出,进一步推动了大模型在计算能力和效率上的提升。它已被用于处理更长的序列和更复杂的任务,为大规模语言模型、图像处理等应用提供了强有力的支持。
总结
V1、V2、V3均旨在优化Transformer模型中自注意力机制的计算效率和内存使用效率。它们都采用了分块和并行计算的思想,但具体实现细节和性能优化策略有所不同。
- V1实现了基本的分块和并行计算策略。
- V2通过循环交换和任务划分策略进一步提高了并行度和减少了内存读写开销。
- V3可能引入了更高效的硬件支持和更精细的并行计算策略。
FlashAttention 的各个版本在数据分块、并行计算、内存访问优化、算子融合、循环交换、任务划分、算法优化和硬件适应性等方面进行了持续改进和优化,以提高计算效率和内存使用效率。
FlashAttention V1、V2、V3通过不断优化数据布局和计算流程,显著提高了Transformer模型中自注意力机制的计算效率和内存使用效率。这些优化技术不仅推动了深度学习在更多领域的应用,也为未来深度学习算法的发展提供了新的思路。