IncreFlashAttention源码分析
在自回归(Auto-regressive)语言模型的推理过程中,随着新词汇的不断生成,输入序列的长度持续增加,这对计算效率提出了严峻挑战。FlashAttention算子,作为一种高效的注意力机制实现,尤其在增量推理场景下展现出其独特优势。在此场景下,FlashAttention的query维度(S轴)被固定为1,而key和value则通过KV Cache机制,将先前推理过程中的状态信息累积并叠加,以适应每个Batch可能不同的实际长度。值得注意的是,尽管输入数据经过padding处理以维持固定长度,但FlashAttention能够灵活应对这种变化。此外,在全量推理场景中,尽管query的S轴大小不再固定,但FlashAttention的推理流程与增量推理保持一致,确保了算法的通用性和高效性。
实现原理
self-attention(自注意力)机制通过挖掘输入样本内部元素间的相互关系,构建了一种强大的注意力模型。具体而言,对于长度为n的输入序列x,其每个元素均为d维向量(即token embedding)。通过三个权重矩阵的变换,得到Q(查询)、K(键)、V(值)三个矩阵,分别代表输入样本在不同特征空间中的表示。self-attention的计算过程涉及Q与K的矩阵乘法、缩放、softmax归一化以及最终与V的矩阵乘法,这一系列操作共同构建了输入序列内部的注意力分布。
self-attention的计算公式一般定义如下,其中Q、K、V为输入样本的重要属性元素,是输入样本经过空间变换得到,且可以统一到一个特征空间中。公式及算子名称中的"Attention"为"self-attention"的简写。
其中Q和K转置的乘积代表输入x的注意力,为避免该值变得过大,通常除以d的开根号进行缩放,并对每行进行softmax归一化,与V相乘后得到一个n*d的矩阵。在上述基础上,考虑online-softmax等优化,FA的计算流程图如下所示。
按照flashAttention正向计算流程实现,整体计算流程如下:
query与转置后的key做matmul矩阵乘法计算后得到最初步的attention_score注意力得分,然后乘以缩放系数scale_value后再与位置编码pse相加。此时的结果通过atten_mask进行select掩码选择操作,将atten_mask中为true的位置进行遮蔽,得到结果masked_attention_score,即atten_mask中为true的位置在select后结果为负的极小值,经过softmax计算之后变成0从而达到遮蔽效果。
为了实现FlashAttention加速算法,使用FlashSoftmax操作对masked_attention_score进行运算,用以代替原公式中的softmax运算,而后将结果与value做matmul运算。由于FlashSoftmax操作对masked_attention_score的Skv(输入key、value的sequence length)方向进行了切分,故实现过程中存在一个刷新流程,具体如下:
- 当i = 0时,计算出的MM[PV]结果直接保存到attention_out[0]的ub中。
- 从i = 1开始,需要增加Mul和Add操作,即将上一次的MM[PV]的结果prev_mm2_res和当前exp相乘,相乘完的结果和本次MM[PV]的结果相加得到的结果保存到ttention_out[1]的ub中。以此类推,遍历tiling块完成计算。
- 由于FlashSoftmax计算中的除sum被后移到输出attention_out之前,因此最后需要将ub中的attention_out按行除以softmax_sum并将最终完整的结果保存到输出内存attention_out(Final)上。
源码分析
- IFA伪代码如下
// 单核计算伪码
void compute() {
// 1. 核内分块数量 loops
loops = blocks_to_compute_of_this_core(); // 当前核需要计算几个数据块
for (i = 0; i < loops; i++) {
// 2. 块内切分 innerloops
block = get_curr_block(i);
bidx, nidx, sidx = dims_of_this_block(block);
innerloops = get_inner_loops_of_this_block_by_actual_seq_len(bidx, nidx, sidx); // 数据块实际内切份数
q_offset = get_offset_of_query(bidx, nidx);
// softmax计算中间变量
softmax_sum = {0};
softmax_exp = {0};
softmax_max = {min_float};
for (j = 0; j < innerloops; j++) { // flash attention循环
kv_offset = get_offset_of_kv_block(j);
// 1. mm1:S=q@k转置
qk_res = matmul(q + qoffset, k + kv_offset);
// 2. vec1:mul+add+select
qk_res = elementwise(qk_res); // scale,pse, atten-mask
// 3. vec1:online softmax P=softmax(S)
qk_res, softmax_max, softmax_sum, softmax_exp = softmaxflash(qk_res, softmax_max, softmax_sum);
// 4. mm2:O=P@V
res = matmul(qk_res, v + kv_offset);
// 5. online softmax 刷新 flash attention update
prev_res = load_prev_res();// 上次结果
res += prev_res * softmax_exp; // O = O + O_pre * softmax_exp
store(res); // 保存本次结果
if (j == innerloops - 1) {
// 最后 执行softmax的 除sum过程
// sum 是所有块计算叠加全部完成后才得到最终的结果
res = div(res, softmax_sum);
output(res); // 输出结果到GM
}
}
}
}
- Ascend C融合算子开源仓中对应的IFA源码(简化保留主流程)
单核计算过程入口函数
// 1. 单核计算过程入口函数
Process() {
if (g_coreType == AIV && tmpBlockIdx >= usedCoreNum) {
// 跳过不需要的核心
} else {
// 1. 单核处理多块
for (uint32_t bn2Idx = 0; bn2Idx < bn2LoopTimes; bn2Idx++) {
// 2. 分块参数计算
GetBN2id(bn2Idx);
GetActualSeqLen();
// 计算BN2方向的offset
CalcBN2OffsetAndParams();
// 根据当前块实际长度, 重配flashattention循环条件
UpdateInnerLoopCond();
pipe_barrier(PIPE_V);
// 3. softmax复初始值
Duplicate(softmaxMaxUb, SOFTMAX_MIN_NUM, BUFFER_SIZE_BYTE_2K / sizeof(T)); // 最小值
Duplicate(softmaxSumUb, FLOAT_ZERO, BUFFER_SIZE_BYTE_2K / sizeof(T)); // 置0
// 4. 块内切分多次计算
// GQA场景需要处理G,1、mm1 A矩阵 singleM=G 2、mm1结果vector1内部切分mm1的M轴 3、涉及souter的地方,需要注意GQA
for (uint32_t sInnerLoopIdx = 0; sInnerLoopIdx < sInnerLoopTimes; sInnerLoopIdx++) {
// 5. 计算 块内切分参数 s2方向的offset
CalcSInnerOffsetAndParams(sInnerLoopIdx);
//6. 单次 fa计算核心
SInnerLoopFunc(bn2Idx, sInnerLoopIdx);
}
}
}
}
单次 fa计算流程
// 2. 单次 fa计算核心
SInnerLoopFunc(const uint32_t bn2Idx, const uint32_t sInnerLoopIdx) {
// 1. 配置 mm1 mm2 矩阵乘法高阶函数 原始矩阵大小
SetMMOrgShapeCommon();
// 2. mm1:q@k转置
Bmm1ComputeCommon(bn2Idx, sInnerLoopIdx);
// 3. v1: mul + add + select + flash_softmax
ProcessVec1Inner(sInnerLoopIdx);
// 4. mm2: p@v
Bmm2Compute(bn2Idx, sInnerLoopIdx);
// 5. v2: flash_softmax 刷新: mul+add div
ProcessVec2Inner(sInnerLoopIdx);
}
// 3.配置 mm1 mm2 矩阵乘法高阶函数 原始矩阵大小
SetMMOrgShapeCommon() {
if (curSingleProcessSInnerSizeAlign != actualSingleProcessSInnerSizeAlign) {
// mm1 setOrgShape
uint32_t orgKa = headDim;
mm.SetOrgShape(gSizeMulMsd, kvSeqSize, orgKa, kvHeadMulDim, actualSingleProcessSInnerSizeAlign);
// mm2 setOrgShape
bmm2.SetOrgShape(gSizeMulMsd, kvHeadMulDim, actualSingleProcessSInnerSizeAlign, kvSeqSize, headDimAlign);
// 更新curSingleProcessSInnerSizeAlign,为了下一次判断是否进行setOrgShape使用
curSingleProcessSInnerSizeAlign = actualSingleProcessSInnerSizeAlign;
}
}
mm1: Q@K转置 矩阵乘法cube运算
// 4. mm1: Q@K转置 矩阵乘法cube运算
Bmm1ComputeCommon(uint32_t bn2Idx, uint32_t sInnerLoopIdx) {
mm.SetTensorA(queryGm[tensorACoreOffset]); // 设置左矩阵Q
mm.SetTensorB(keyGm[tensorBOffset], true); // 设置右矩阵K 需要转置
mm.SetTail(gSizeMulMsd, actualSingleProcessSInnerSize, headDim);// 设置分块大小
mm.template IterateAll<false>(mm1ResGm, false, false, true); // 执行矩阵计算 结果存放 mm1ResGm
mm.WaitIterateAll(); // 等待计算完成
mm.End();
}
vec1: 向量计算,mm1之后的计算,mul+add+select+softmax
// 5. vec1: 向量计算,mm1之后的计算,mul+add+select+softmax
ProcessVec1Inner(uint32_t sInnerLoopIdx) {
// vec相比cube计算性能差,需要再分块分多次计算
uint32_t gSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / actualSingleProcessSInnerSizeAlign;
if (gSplitSize > gSize) {
gSplitSize = gSize; // 每次计算数量
}
uint32_t loopCount = (gSize + gSplitSize - 1) / gSplitSize; // 循环次数
uint32_t tailSplitSize = gSize - (loopCount - 1) * gSplitSize;// 尾次数量
for (uint32_t i = 0, dealSize = gSplitSize; i < loopCount; i++) {
if (i == (loopCount - 1)) {
dealSize = tailSplitSize; // 尾次
}
// vec1分块计算核心:q@k mm1之后的vector操作 mul+add+select+softmax
DealBmm1ResBaseBlock(sInnerLoopIdx, i * gSplitSize, dealSize, actualSingleProcessSInnerSizeAlign,
actualSingleProcessSInnerSize);
}
}
// 5.1 vec1分块计算核心:mm1之后的vector操作 mul+add+select+softmax
DealBmm1ResBaseBlock(uint32_t sInnerLoopIdx, uint32_t startRow, uint32_t dealRowCount,
uint32_t columnCount, uint32_t actualColumnCount) {
uint32_t computeSize = dealRowCount * columnCount;
LocalTensor<T> mmResUb = tmpBuff1.Get<T>();
size_t batchBase = 0;
// 1. 拿到mm1 结果
{
LocalTensor<MM_OUT_T> tmpMmResUb = inputQue1.AllocTensor<MM_OUT_T>();
DataCopy(tmpMmResUb, mm1ResGm[batchBase + startRow * columnCount], computeSize);
inputQue1.EnQue(tmpMmResUb);
inputQue1.DeQue<MM_OUT_T>();
DataCopy(mmResUb, tmpMmResUb, computeSize);
inputQue1.FreeTensor(tmpMmResUb);
pipe_barrier(PIPE_V);
}
// 2. mul+add+select
ElewiseCompute(mmResUb, tmpBuff2, startRow, dealRowCount, columnCount, actualColumnCount);
// 3. 在线 softmax
LocalTensor<T> tmpAFloorUb = tmpBuff2.Get<T>();
LocalTensor<uint8_t> softmaxTmpUb = tmpAFloorUb.template ReinterpretCast<uint8_t>();
SoftmaxFlashV2Compute(mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, actualColumnCount);
pipe_barrier(PIPE_V);
// 4. 结果类型转换为 KV的数据类型
LocalTensor<KV_T> tmpMMResCastTensor = outputQue1.AllocTensor<KV_T>();
Cast(tmpMMResCastTensor, mmResUb, AscendC::RoundMode::CAST_ROUND, computeSize); // mm1+vec1之后的结果转换成和 mm2的 value类型相同
// 5. 拷贝到全局内存vec1ResGm中
outputQue1.EnQue(tmpMMResCastTensor);
outputQue1.DeQue<KV_T>();
DataCopy(vec1ResGm[batchBase + startRow * columnCount], tmpMMResCastTensor, computeSize);
outputQue1.FreeTensor(tmpMMResCastTensor);
}
// 5.2 vec1向量计算 mul+add+select
ElewiseCompute(LocalTensor<T>& mmResUb, TBuf<>& tmpBuf, uint32_t startRow,
uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) {
// 1. 乘 缩放 q@k * 1/sqrt(n)
Muls(mmResUb, mmResUb, static_cast<T>(tilingData->baseParams.scaleValue), dealRowCount * columnCount);
pipe_barrier(PIPE_V);
// 2. 加 位置编码 pse shift mask
if (pseShiftFlag) {
// 2.1 拷贝编码数据
PseShiftCopyIn(startRow, dealRowCount, actualColumnCount);
LocalTensor<pseShiftType> pseShiftUb = inputQue1.DeQue<pseShiftType>();
LocalTensor<float> pseShiftUbFloat = tmpBuf.Get<float>();
// 2.2 转换为浮点类型
for (uint32_t i = 0; i < dealRowCount; ++i) {
Cast(pseShiftUbFloat[i * columnCount], pseShiftUb[i * pseMaskSizeAlign], AscendC::RoundMode::CAST_NONE,
pseMaskSizeAlign);
}
// 2.3 加上位置编码
inputQue1.FreeTensor(pseShiftUb);
pipe_barrier(PIPE_V);
Add(mmResUb, mmResUb, pseShiftUbFloat, dealRowCount * columnCount);
pipe_barrier(PIPE_V);
}
// 3. 掩码处理 attenMask
if (attenMaskFlag == 1) {
// 3.1 拷贝掩码数据
AttenMaskCopyIn(attenMaskOffset, dealRowCount, actualColumnCount);
LocalTensor<bool> attenMaskUb = inputQue2.DeQue<bool>();
for (int i = 1; i < dealRowCount; i++) {
DataCopy(attenMaskUb[i * attenMaskSizeAlign], attenMaskUb, attenMaskSizeAlign);
}
pipe_barrier(PIPE_V);
// 3.2 select 执行掩码选取操作
LocalTensor<uint8_t> ubWorkSpace = tmpBuf.Get<uint8_t>(selectWithByteMaskTmpMinSize);
SelectWithBytesMaskShapeInfo selectWithBytesMaskShapeInfo;
selectWithBytesMaskShapeInfo.firstAxis = dealRowCount;
selectWithBytesMaskShapeInfo.srcLastAxis = columnCount;
selectWithBytesMaskShapeInfo.maskLastAxis = attenMaskSizeAlign;
attenMaskUb.SetSize(dealRowCount * attenMaskSizeAlign); // Select接口要求mask size与参数匹配
mmResUb.SetSize(dealRowCount * columnCount); // Select接口要求src size与参数匹配
SelectWithBytesMask(mmResUb, mmResUb, BOOL_ATTEN_MASK_SCALAR_VALUE, attenMaskUb, ubWorkSpace,
selectWithBytesMaskShapeInfo);
mmResUb.SetSize(BUFFER_SIZE_BYTE_32K / sizeof(T)); // mmResUb Size复原,mask不用复原,与原来一致
inputQue2.FreeTensor(attenMaskUb);
pipe_barrier(PIPE_V);
}
}
// 5.3vec1向量计算 在线softmax
SoftmaxFlashV2Compute(LocalTensor<T>& mmResUb, LocalTensor<uint8_t>& softmaxTmpUb,
uint32_t startRow, uint32_t dealRowCount,
uint32_t columnCount, uint32_t actualColumnCount) {
uint32_t baseOffset = startRow * BLOCK_ELEMENT_NUM;
SoftMaxShapeInfo srcShape = {dealRowCount, columnCount, dealRowCount, actualColumnCount};
// 计算 tiling分块参数
SoftMaxTiling newTiling =
SoftMaxFlashV2TilingFunc(srcShape, sizeof(T), sizeof(T), softmaxTmpUb.GetSize(), true, false);
// 执行 flash softmax运算
SoftmaxFlashV2<T, true, true, false, false, IFA_SOFTMAX_FLASHV2_CFG> (mmResUb, softmaxSumUb[baseOffset],
softmaxMaxUb[baseOffset], mmResUb, softmaxExpUb[baseOffset], softmaxSumUb[baseOffset], softmaxMaxUb[baseOffset],
softmaxTmpUb, newTiling, srcShape);
}
mm2: P@V 矩阵乘法cube运算
// 6. mm2: P@V 矩阵乘法cube运算
Bmm2ComputeCommon(uint32_t bn2Idx, uint32_t sInnerLoopIdx) {
bmm2.SetTensorA(vec1ResGm); // 设置左矩阵P=softmax(Q@K转置)
bmm2.SetTensorB(valueGm[valueOffset]);// 设置右矩阵V
bmm2.SetTail(gSizeMulMsd, headDim, actualSingleProcessSInnerSize);// 设置分块大小
bmm2.template IterateAll<false>(mm2ResGm, false, false, true);// 执行矩阵计算 结果存放 mm2ResGm
bmm2.WaitIterateAll();// 等待计算完成
bmm2.End();
}
vec2: 向量计算,mm2之后的计算 softmax的刷新操作 mul add div
// 7. vec2: 向量计算,mm2之后的计算 softmax的刷新操作 mul add div
ProcessVec2Inner(const uint32_t sInnerLoopIdx) {
uint32_t gSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / headDimAlign;
if (gSplitSize > gSize) {
gSplitSize = gSize; // 每次处理的数量
}
uint32_t loopCount = (gSize + gSplitSize - 1) / gSplitSize; // 循环次数
uint32_t tailSplitSize = gSize - (loopCount - 1) * gSplitSize;// 尾次数量
// 分块处理
for (uint32_t i = 0, dealSize = gSplitSize; i < loopCount; i++) {
if (i == (loopCount - 1)) {
dealSize = tailSplitSize; // 尾次
}
// vec2分块计算核心:p@v mm2之后的计算 softmax的刷新操作 mul add div
DealBmm2ResBaseBlock(sInnerLoopIdx, i * gSplitSize, dealSize, headDimAlign, headDim);
}
}
// 7.1 vec2分块计算核心:p@v mm2之后的计算 softmax的刷新操作 mul add div
DealBmm2ResBaseBlock(uint32_t sInnerLoopIdx, uint32_t startRow, uint32_t dealRowCount,
uint32_t columnCount, uint32_t actualColumnCount) {
uint32_t vec2ComputeSize = dealRowCount * columnCount;
uint32_t baseOffset = startRow * BLOCK_ELEMENT_NUM;
LocalTensor<T> bmm2ResUb = tmpBuff1.Get<T>();
bmm2ResUb.SetSize(vec2ComputeSize);
// 1. 获取mm2 计算结果
{
LocalTensor<MM_OUT_T> tmpBmm2ResUb = inputQue1.AllocTensor<MM_OUT_T>();
DataCopy(tmpBmm2ResUb, mm2ResGm[batchBase + startRow * columnCount], vec2ComputeSize);
inputQue1.EnQue(tmpBmm2ResUb);
inputQue1.DeQue<MM_OUT_T>();
DataCopy(bmm2ResUb, tmpBmm2ResUb, vec2ComputeSize);
inputQue1.FreeTensor(tmpBmm2ResUb);
}
// 除第一个循环外,均需要更新中间计算结果
if (sInnerLoopIdx > 0) {
// 2. 得到上次的结果 O_pre
LocalTensor<T> bmm2ResPreUb = inputQue2.AllocTensor<T>();
DataCopy(bmm2ResPreUb, vec2ResGm[batchBase + startRow * columnCount], vec2ComputeSize);
inputQue2.EnQue(bmm2ResPreUb);
inputQue2.DeQue<T>();
pipe_barrier(PIPE_V);
// 3. 更新结果 O_pre = mul(O_pre, softmax_exp)
RowMuls(bmm2ResPreUb, bmm2ResPreUb, softmaxExpUb[baseOffset], dealRowCount, columnCount, actualColumnCount);
pipe_barrier(PIPE_V);
// 4. 累加 O = O + O_pre
Add(bmm2ResUb, bmm2ResUb, bmm2ResPreUb, vec2ComputeSize);
inputQue2.FreeTensor(bmm2ResPreUb);
}
// 最后一次输出计算结果,否则将中间结果暂存至workspace
if (sInnerLoopIdx + 1 == sInnerLoopTimes) {
pipe_barrier(PIPE_V);
// 5. 最后一次 除以 softmaxSum
RowDivs(bmm2ResUb, bmm2ResUb, softmaxSumUb[baseOffset], dealRowCount, columnCount, actualColumnCount);
pipe_barrier(PIPE_V);
// 6. 数据类型转换
LocalTensor<OUT_T> tmpBmm2ResCastTensor = outputQue1.AllocTensor<OUT_T>();
Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_ROUND, dealRowCount * columnCount);
outputQue1.EnQue(tmpBmm2ResCastTensor);
outputQue1.DeQue<OUT_T>();
// 7. 结果拷贝 到 attentionOutGm 最终结果
DataCopyExtParams dataCopyParams;
dataCopyParams.blockCount = dealRowCount;
dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T);
dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(OUT_T));
dataCopyParams.dstStride = 0;
DataCopyPad(attentionOutGm[attenOutOffset + startRow * actualColumnCount], attenOutUb, dataCopyParams);
outputQue1.FreeTensor(tmpBmm2ResCastTensor);
} else {
// 8. 非最后一次 将中间结果暂存至workspace vec2ResGm
pipe_barrier(PIPE_V);
LocalTensor<T> tmpBmm2Res = outputQue1.AllocTensor<T>();
DataCopy(tmpBmm2Res, bmm2ResUb, dealRowCount * columnCount);
outputQue1.EnQue(tmpBmm2Res);
outputQue1.DeQue<T>();
//issue bmm2ResUb 为啥不 直接拷贝到 vec2ResGm ?
DataCopy(vec2ResGm[batchBase + startRow * columnCount], tmpBmm2Res, vec2ComputeSize);
outputQue1.FreeTensor(tmpBmm2Res);
}
}