searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

IncreFlashAttention源码分析

2024-09-18 09:21:16
65
0

IncreFlashAttention源码分析

在自回归(Auto-regressive)语言模型的推理过程中,随着新词汇的不断生成,输入序列的长度持续增加,这对计算效率提出了严峻挑战。FlashAttention算子,作为一种高效的注意力机制实现,尤其在增量推理场景下展现出其独特优势。在此场景下,FlashAttention的query维度(S轴)被固定为1,而key和value则通过KV Cache机制,将先前推理过程中的状态信息累积并叠加,以适应每个Batch可能不同的实际长度。值得注意的是,尽管输入数据经过padding处理以维持固定长度,但FlashAttention能够灵活应对这种变化。此外,在全量推理场景中,尽管query的S轴大小不再固定,但FlashAttention的推理流程与增量推理保持一致,确保了算法的通用性和高效性。

实现原理

self-attention(自注意力)机制通过挖掘输入样本内部元素间的相互关系,构建了一种强大的注意力模型。具体而言,对于长度为n的输入序列x,其每个元素均为d维向量(即token embedding)。通过三个权重矩阵的变换,得到Q(查询)、K(键)、V(值)三个矩阵,分别代表输入样本在不同特征空间中的表示。self-attention的计算过程涉及Q与K的矩阵乘法、缩放、softmax归一化以及最终与V的矩阵乘法,这一系列操作共同构建了输入序列内部的注意力分布。

self-attention的计算公式一般定义如下,其中Q、K、V为输入样本的重要属性元素,是输入样本经过空间变换得到,且可以统一到一个特征空间中。公式及算子名称中的"Attention"为"self-attention"的简写。

sat.JPG

其中Q和K转置的乘积代表输入x的注意力,为避免该值变得过大,通常除以d的开根号进行缩放,并对每行进行softmax归一化,与V相乘后得到一个n*d的矩阵。在上述基础上,考虑online-softmax等优化,FA的计算流程图如下所示。

IFA.png

按照flashAttention正向计算流程实现,整体计算流程如下:

query与转置后的key做matmul矩阵乘法计算后得到最初步的attention_score注意力得分,然后乘以缩放系数scale_value后再与位置编码pse相加。此时的结果通过atten_mask进行select掩码选择操作,将atten_mask中为true的位置进行遮蔽,得到结果masked_attention_score,即atten_mask中为true的位置在select后结果为负的极小值,经过softmax计算之后变成0从而达到遮蔽效果。

为了实现FlashAttention加速算法,使用FlashSoftmax操作对masked_attention_score进行运算,用以代替原公式中的softmax运算,而后将结果与value做matmul运算。由于FlashSoftmax操作对masked_attention_score的Skv(输入key、value的sequence length)方向进行了切分,故实现过程中存在一个刷新流程,具体如下:

  1. 当i = 0时,计算出的MM[PV]结果直接保存到attention_out[0]的ub中。
  2. 从i = 1开始,需要增加Mul和Add操作,即将上一次的MM[PV]的结果prev_mm2_res和当前exp相乘,相乘完的结果和本次MM[PV]的结果相加得到的结果保存到ttention_out[1]的ub中。以此类推,遍历tiling块完成计算。
  3. 由于FlashSoftmax计算中的除sum被后移到输出attention_out之前,因此最后需要将ub中的attention_out按行除以softmax_sum并将最终完整的结果保存到输出内存attention_out(Final)上。

源码分析

  • IFA伪代码如下
// 单核计算伪码
void compute() {
  // 1. 核内分块数量 loops
  loops = blocks_to_compute_of_this_core(); // 当前核需要计算几个数据块
  for (i = 0; i < loops; i++) {
    // 2. 块内切分 innerloops
    block = get_curr_block(i);
    bidx, nidx, sidx = dims_of_this_block(block);
    innerloops = get_inner_loops_of_this_block_by_actual_seq_len(bidx, nidx, sidx); // 数据块实际内切份数 
    q_offset = get_offset_of_query(bidx, nidx);
    
    // softmax计算中间变量
    softmax_sum = {0};
    softmax_exp = {0};
    softmax_max = {min_float};
      
    for (j = 0; j < innerloops; j++) { // flash attention循环
      kv_offset = get_offset_of_kv_block(j);
      
      // 1. mm1:S=q@k转置
      qk_res = matmul(q + qoffset, k + kv_offset);
      // 2. vec1:mul+add+select
      qk_res = elementwise(qk_res); // scale,pse, atten-mask
      // 3. vec1:online softmax P=softmax(S)
      qk_res, softmax_max, softmax_sum, softmax_exp = softmaxflash(qk_res, softmax_max, softmax_sum);
      // 4. mm2:O=P@V
      res = matmul(qk_res, v + kv_offset);
      // 5. online softmax 刷新 flash attention update
      prev_res = load_prev_res();// 上次结果
      res += prev_res * softmax_exp;  // O = O + O_pre * softmax_exp
      store(res); // 保存本次结果

      if (j == innerloops - 1) {
        // 最后 执行softmax的 除sum过程
        // sum 是所有块计算叠加全部完成后才得到最终的结果
        res = div(res, softmax_sum);
        output(res); // 输出结果到GM
      }
    }
  }
}
  • Ascend C融合算子开源仓中对应的IFA源码(简化保留主流程)

单核计算过程入口函数

// 1. 单核计算过程入口函数
Process() {
  if (g_coreType == AIV && tmpBlockIdx >= usedCoreNum) {
    // 跳过不需要的核心
  } else {
    // 1. 单核处理多块
    for (uint32_t bn2Idx = 0; bn2Idx < bn2LoopTimes; bn2Idx++) {
      // 2. 分块参数计算
      GetBN2id(bn2Idx);
      GetActualSeqLen();
      // 计算BN2方向的offset
      CalcBN2OffsetAndParams();
      // 根据当前块实际长度, 重配flashattention循环条件
      UpdateInnerLoopCond();
      pipe_barrier(PIPE_V);

      // 3. softmax复初始值
      Duplicate(softmaxMaxUb, SOFTMAX_MIN_NUM, BUFFER_SIZE_BYTE_2K / sizeof(T)); // 最小值
      Duplicate(softmaxSumUb, FLOAT_ZERO, BUFFER_SIZE_BYTE_2K / sizeof(T)); // 置0

      // 4. 块内切分多次计算
      // GQA场景需要处理G,1、mm1 A矩阵 singleM=G 2、mm1结果vector1内部切分mm1的M轴 3、涉及souter的地方,需要注意GQA
      for (uint32_t sInnerLoopIdx = 0; sInnerLoopIdx < sInnerLoopTimes; sInnerLoopIdx++) {
        // 5. 计算 块内切分参数 s2方向的offset
        CalcSInnerOffsetAndParams(sInnerLoopIdx);
        
        //6. 单次 fa计算核心
        SInnerLoopFunc(bn2Idx, sInnerLoopIdx);
      }
    }
  }
}

单次 fa计算流程

// 2. 单次 fa计算核心
SInnerLoopFunc(const uint32_t bn2Idx, const uint32_t sInnerLoopIdx) {
  // 1. 配置 mm1 mm2 矩阵乘法高阶函数 原始矩阵大小
  SetMMOrgShapeCommon();
  // 2. mm1:q@k转置
  Bmm1ComputeCommon(bn2Idx, sInnerLoopIdx);
  // 3. v1: mul + add + select + flash_softmax
  ProcessVec1Inner(sInnerLoopIdx);
  // 4. mm2: p@v
  Bmm2Compute(bn2Idx, sInnerLoopIdx);
  // 5. v2: flash_softmax 刷新: mul+add div
  ProcessVec2Inner(sInnerLoopIdx);
}

// 3.配置 mm1 mm2 矩阵乘法高阶函数 原始矩阵大小
SetMMOrgShapeCommon() {
  if (curSingleProcessSInnerSizeAlign != actualSingleProcessSInnerSizeAlign) {
    // mm1 setOrgShape
    uint32_t orgKa = headDim;
    mm.SetOrgShape(gSizeMulMsd, kvSeqSize, orgKa, kvHeadMulDim, actualSingleProcessSInnerSizeAlign);
    // mm2 setOrgShape
    bmm2.SetOrgShape(gSizeMulMsd, kvHeadMulDim, actualSingleProcessSInnerSizeAlign, kvSeqSize, headDimAlign);
    // 更新curSingleProcessSInnerSizeAlign,为了下一次判断是否进行setOrgShape使用
    curSingleProcessSInnerSizeAlign = actualSingleProcessSInnerSizeAlign;
  }
}

mm1: Q@K转置 矩阵乘法cube运算

// 4. mm1: Q@K转置 矩阵乘法cube运算
Bmm1ComputeCommon(uint32_t bn2Idx, uint32_t sInnerLoopIdx) {
  mm.SetTensorA(queryGm[tensorACoreOffset]); // 设置左矩阵Q
  mm.SetTensorB(keyGm[tensorBOffset], true); // 设置右矩阵K 需要转置
  mm.SetTail(gSizeMulMsd, actualSingleProcessSInnerSize, headDim);// 设置分块大小
  mm.template IterateAll<false>(mm1ResGm, false, false, true); // 执行矩阵计算 结果存放 mm1ResGm
  mm.WaitIterateAll(); // 等待计算完成
  mm.End();
}

vec1: 向量计算,mm1之后的计算,mul+add+select+softmax

// 5. vec1: 向量计算,mm1之后的计算,mul+add+select+softmax
ProcessVec1Inner(uint32_t sInnerLoopIdx) {
  // vec相比cube计算性能差,需要再分块分多次计算
  uint32_t gSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / actualSingleProcessSInnerSizeAlign;
  if (gSplitSize > gSize) {
    gSplitSize = gSize; // 每次计算数量
  }
  uint32_t loopCount = (gSize + gSplitSize - 1) / gSplitSize; // 循环次数
  uint32_t tailSplitSize = gSize - (loopCount - 1) * gSplitSize;// 尾次数量

  for (uint32_t i = 0, dealSize = gSplitSize; i < loopCount; i++) {
    if (i == (loopCount - 1)) {
      dealSize = tailSplitSize; // 尾次
    }
    // vec1分块计算核心:q@k mm1之后的vector操作 mul+add+select+softmax
    DealBmm1ResBaseBlock(sInnerLoopIdx, i * gSplitSize, dealSize, actualSingleProcessSInnerSizeAlign,
                          actualSingleProcessSInnerSize);
  }
}


// 5.1 vec1分块计算核心:mm1之后的vector操作 mul+add+select+softmax
DealBmm1ResBaseBlock(uint32_t sInnerLoopIdx, uint32_t startRow, uint32_t dealRowCount, 
                     uint32_t columnCount, uint32_t actualColumnCount) {
  uint32_t computeSize = dealRowCount * columnCount;
  LocalTensor<T> mmResUb = tmpBuff1.Get<T>();
  size_t batchBase = 0;
  // 1. 拿到mm1 结果
  {
    LocalTensor<MM_OUT_T> tmpMmResUb = inputQue1.AllocTensor<MM_OUT_T>();
    DataCopy(tmpMmResUb, mm1ResGm[batchBase + startRow * columnCount], computeSize);
    inputQue1.EnQue(tmpMmResUb);
    inputQue1.DeQue<MM_OUT_T>();
    DataCopy(mmResUb, tmpMmResUb, computeSize);
    inputQue1.FreeTensor(tmpMmResUb);
    pipe_barrier(PIPE_V);
  }
  
  // 2. mul+add+select
  ElewiseCompute(mmResUb, tmpBuff2, startRow, dealRowCount, columnCount, actualColumnCount);
  
  // 3. 在线 softmax
  LocalTensor<T> tmpAFloorUb = tmpBuff2.Get<T>();
  LocalTensor<uint8_t> softmaxTmpUb = tmpAFloorUb.template ReinterpretCast<uint8_t>();
  SoftmaxFlashV2Compute(mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, actualColumnCount);
  pipe_barrier(PIPE_V);
  
  // 4. 结果类型转换为 KV的数据类型
  LocalTensor<KV_T> tmpMMResCastTensor = outputQue1.AllocTensor<KV_T>();
  Cast(tmpMMResCastTensor, mmResUb, AscendC::RoundMode::CAST_ROUND, computeSize); // mm1+vec1之后的结果转换成和 mm2的 value类型相同
  
  // 5. 拷贝到全局内存vec1ResGm中
  outputQue1.EnQue(tmpMMResCastTensor);
  outputQue1.DeQue<KV_T>();
  DataCopy(vec1ResGm[batchBase + startRow * columnCount], tmpMMResCastTensor, computeSize);
  outputQue1.FreeTensor(tmpMMResCastTensor);
}


// 5.2 vec1向量计算 mul+add+select
ElewiseCompute(LocalTensor<T>& mmResUb, TBuf<>& tmpBuf, uint32_t startRow,
               uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) {
  // 1. 乘 缩放  q@k * 1/sqrt(n)
  Muls(mmResUb, mmResUb, static_cast<T>(tilingData->baseParams.scaleValue), dealRowCount * columnCount);
  pipe_barrier(PIPE_V);

  // 2. 加 位置编码 pse shift mask
  if (pseShiftFlag) {
    // 2.1 拷贝编码数据
    PseShiftCopyIn(startRow, dealRowCount, actualColumnCount);
    LocalTensor<pseShiftType> pseShiftUb = inputQue1.DeQue<pseShiftType>();
    LocalTensor<float> pseShiftUbFloat = tmpBuf.Get<float>();
    // 2.2 转换为浮点类型
    for (uint32_t i = 0; i < dealRowCount; ++i) {
      Cast(pseShiftUbFloat[i * columnCount], pseShiftUb[i * pseMaskSizeAlign], AscendC::RoundMode::CAST_NONE,
           pseMaskSizeAlign);
    }
    // 2.3 加上位置编码
    inputQue1.FreeTensor(pseShiftUb);
    pipe_barrier(PIPE_V);
    Add(mmResUb, mmResUb, pseShiftUbFloat, dealRowCount * columnCount);
    pipe_barrier(PIPE_V);
  }

  // 3. 掩码处理 attenMask
  if (attenMaskFlag == 1) {
    // 3.1 拷贝掩码数据
    AttenMaskCopyIn(attenMaskOffset, dealRowCount, actualColumnCount);
    LocalTensor<bool> attenMaskUb = inputQue2.DeQue<bool>();
    for (int i = 1; i < dealRowCount; i++) {
      DataCopy(attenMaskUb[i * attenMaskSizeAlign], attenMaskUb, attenMaskSizeAlign);
    }
    pipe_barrier(PIPE_V);
    
    // 3.2 select 执行掩码选取操作
    LocalTensor<uint8_t> ubWorkSpace = tmpBuf.Get<uint8_t>(selectWithByteMaskTmpMinSize);
    SelectWithBytesMaskShapeInfo selectWithBytesMaskShapeInfo;
    selectWithBytesMaskShapeInfo.firstAxis = dealRowCount;
    selectWithBytesMaskShapeInfo.srcLastAxis = columnCount;
    selectWithBytesMaskShapeInfo.maskLastAxis = attenMaskSizeAlign;
    attenMaskUb.SetSize(dealRowCount * attenMaskSizeAlign);  // Select接口要求mask size与参数匹配
    mmResUb.SetSize(dealRowCount * columnCount);             // Select接口要求src size与参数匹配
    SelectWithBytesMask(mmResUb, mmResUb, BOOL_ATTEN_MASK_SCALAR_VALUE, attenMaskUb, ubWorkSpace,
                        selectWithBytesMaskShapeInfo);
    mmResUb.SetSize(BUFFER_SIZE_BYTE_32K / sizeof(T));  // mmResUb Size复原,mask不用复原,与原来一致
    inputQue2.FreeTensor(attenMaskUb);

    pipe_barrier(PIPE_V);
  }
}

// 5.3vec1向量计算 在线softmax
SoftmaxFlashV2Compute(LocalTensor<T>& mmResUb, LocalTensor<uint8_t>& softmaxTmpUb, 
                      uint32_t startRow, uint32_t dealRowCount,
                      uint32_t columnCount, uint32_t actualColumnCount) {
  uint32_t baseOffset = startRow * BLOCK_ELEMENT_NUM;
  SoftMaxShapeInfo srcShape = {dealRowCount, columnCount, dealRowCount, actualColumnCount};
  // 计算 tiling分块参数
  SoftMaxTiling newTiling =
    SoftMaxFlashV2TilingFunc(srcShape, sizeof(T), sizeof(T), softmaxTmpUb.GetSize(), true, false);
  // 执行 flash softmax运算
  SoftmaxFlashV2<T, true, true, false, false, IFA_SOFTMAX_FLASHV2_CFG> (mmResUb, softmaxSumUb[baseOffset],
    softmaxMaxUb[baseOffset], mmResUb, softmaxExpUb[baseOffset], softmaxSumUb[baseOffset], softmaxMaxUb[baseOffset],
    softmaxTmpUb, newTiling, srcShape);
}

mm2: P@V 矩阵乘法cube运算

// 6. mm2: P@V 矩阵乘法cube运算
Bmm2ComputeCommon(uint32_t bn2Idx, uint32_t sInnerLoopIdx) {

  bmm2.SetTensorA(vec1ResGm); // 设置左矩阵P=softmax(Q@K转置)
  bmm2.SetTensorB(valueGm[valueOffset]);// 设置右矩阵V
  bmm2.SetTail(gSizeMulMsd, headDim, actualSingleProcessSInnerSize);// 设置分块大小
  bmm2.template IterateAll<false>(mm2ResGm, false, false, true);// 执行矩阵计算 结果存放 mm2ResGm
  bmm2.WaitIterateAll();// 等待计算完成
  bmm2.End();
}

vec2: 向量计算,mm2之后的计算 softmax的刷新操作 mul add div

// 7. vec2: 向量计算,mm2之后的计算 softmax的刷新操作 mul add div
ProcessVec2Inner(const uint32_t sInnerLoopIdx) {
  uint32_t gSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / headDimAlign;
  if (gSplitSize > gSize) {
    gSplitSize = gSize; // 每次处理的数量
  }
  uint32_t loopCount = (gSize + gSplitSize - 1) / gSplitSize; // 循环次数
  uint32_t tailSplitSize = gSize - (loopCount - 1) * gSplitSize;// 尾次数量
  
  // 分块处理
  for (uint32_t i = 0, dealSize = gSplitSize; i < loopCount; i++) {
    if (i == (loopCount - 1)) {
      dealSize = tailSplitSize; // 尾次
    }
    // vec2分块计算核心:p@v mm2之后的计算 softmax的刷新操作 mul add div
    DealBmm2ResBaseBlock(sInnerLoopIdx, i * gSplitSize, dealSize, headDimAlign, headDim);
  }
}

// 7.1 vec2分块计算核心:p@v mm2之后的计算 softmax的刷新操作 mul add div
DealBmm2ResBaseBlock(uint32_t sInnerLoopIdx, uint32_t startRow, uint32_t dealRowCount,
                     uint32_t columnCount, uint32_t actualColumnCount) {
  uint32_t vec2ComputeSize = dealRowCount * columnCount;
  uint32_t baseOffset = startRow * BLOCK_ELEMENT_NUM;
  
  LocalTensor<T> bmm2ResUb = tmpBuff1.Get<T>();
  bmm2ResUb.SetSize(vec2ComputeSize);
  // 1. 获取mm2 计算结果
  {
    LocalTensor<MM_OUT_T> tmpBmm2ResUb = inputQue1.AllocTensor<MM_OUT_T>();
    DataCopy(tmpBmm2ResUb, mm2ResGm[batchBase + startRow * columnCount], vec2ComputeSize);
    inputQue1.EnQue(tmpBmm2ResUb);
    inputQue1.DeQue<MM_OUT_T>();
    DataCopy(bmm2ResUb, tmpBmm2ResUb, vec2ComputeSize);
    inputQue1.FreeTensor(tmpBmm2ResUb);
  }

  // 除第一个循环外,均需要更新中间计算结果
  if (sInnerLoopIdx > 0) {
    // 2. 得到上次的结果 O_pre
    LocalTensor<T> bmm2ResPreUb = inputQue2.AllocTensor<T>();
    DataCopy(bmm2ResPreUb, vec2ResGm[batchBase + startRow * columnCount], vec2ComputeSize);
    inputQue2.EnQue(bmm2ResPreUb);
    inputQue2.DeQue<T>();
    pipe_barrier(PIPE_V);

    // 3. 更新结果 O_pre = mul(O_pre, softmax_exp)
    RowMuls(bmm2ResPreUb, bmm2ResPreUb, softmaxExpUb[baseOffset], dealRowCount, columnCount, actualColumnCount);
    pipe_barrier(PIPE_V);

    // 4. 累加 O = O + O_pre
    Add(bmm2ResUb, bmm2ResUb, bmm2ResPreUb, vec2ComputeSize);
    inputQue2.FreeTensor(bmm2ResPreUb);
  }

  // 最后一次输出计算结果,否则将中间结果暂存至workspace
  if (sInnerLoopIdx + 1 == sInnerLoopTimes) {
    pipe_barrier(PIPE_V);
    // 5. 最后一次 除以 softmaxSum
    RowDivs(bmm2ResUb, bmm2ResUb, softmaxSumUb[baseOffset], dealRowCount, columnCount, actualColumnCount);
    pipe_barrier(PIPE_V);

    // 6. 数据类型转换
    LocalTensor<OUT_T> tmpBmm2ResCastTensor = outputQue1.AllocTensor<OUT_T>();
    Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_ROUND, dealRowCount * columnCount);
    outputQue1.EnQue(tmpBmm2ResCastTensor);
    outputQue1.DeQue<OUT_T>();

    // 7. 结果拷贝 到 attentionOutGm 最终结果
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = dealRowCount;
    dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T);
    dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(OUT_T));
    dataCopyParams.dstStride = 0;
    DataCopyPad(attentionOutGm[attenOutOffset + startRow * actualColumnCount], attenOutUb, dataCopyParams);

    outputQue1.FreeTensor(tmpBmm2ResCastTensor);

  } else {
    // 8. 非最后一次 将中间结果暂存至workspace vec2ResGm
    pipe_barrier(PIPE_V);
    LocalTensor<T> tmpBmm2Res = outputQue1.AllocTensor<T>();
    DataCopy(tmpBmm2Res, bmm2ResUb, dealRowCount * columnCount);
    outputQue1.EnQue(tmpBmm2Res);
    outputQue1.DeQue<T>();
    //issue bmm2ResUb 为啥不 直接拷贝到 vec2ResGm ?
    DataCopy(vec2ResGm[batchBase + startRow * columnCount], tmpBmm2Res, vec2ComputeSize);

    outputQue1.FreeTensor(tmpBmm2Res);
  }
}
0条评论
0 / 1000
wanyw
3文章数
1粉丝数
wanyw
3 文章 | 1 粉丝
wanyw
3文章数
1粉丝数
wanyw
3 文章 | 1 粉丝
原创

IncreFlashAttention源码分析

2024-09-18 09:21:16
65
0

IncreFlashAttention源码分析

在自回归(Auto-regressive)语言模型的推理过程中,随着新词汇的不断生成,输入序列的长度持续增加,这对计算效率提出了严峻挑战。FlashAttention算子,作为一种高效的注意力机制实现,尤其在增量推理场景下展现出其独特优势。在此场景下,FlashAttention的query维度(S轴)被固定为1,而key和value则通过KV Cache机制,将先前推理过程中的状态信息累积并叠加,以适应每个Batch可能不同的实际长度。值得注意的是,尽管输入数据经过padding处理以维持固定长度,但FlashAttention能够灵活应对这种变化。此外,在全量推理场景中,尽管query的S轴大小不再固定,但FlashAttention的推理流程与增量推理保持一致,确保了算法的通用性和高效性。

实现原理

self-attention(自注意力)机制通过挖掘输入样本内部元素间的相互关系,构建了一种强大的注意力模型。具体而言,对于长度为n的输入序列x,其每个元素均为d维向量(即token embedding)。通过三个权重矩阵的变换,得到Q(查询)、K(键)、V(值)三个矩阵,分别代表输入样本在不同特征空间中的表示。self-attention的计算过程涉及Q与K的矩阵乘法、缩放、softmax归一化以及最终与V的矩阵乘法,这一系列操作共同构建了输入序列内部的注意力分布。

self-attention的计算公式一般定义如下,其中Q、K、V为输入样本的重要属性元素,是输入样本经过空间变换得到,且可以统一到一个特征空间中。公式及算子名称中的"Attention"为"self-attention"的简写。

sat.JPG

其中Q和K转置的乘积代表输入x的注意力,为避免该值变得过大,通常除以d的开根号进行缩放,并对每行进行softmax归一化,与V相乘后得到一个n*d的矩阵。在上述基础上,考虑online-softmax等优化,FA的计算流程图如下所示。

IFA.png

按照flashAttention正向计算流程实现,整体计算流程如下:

query与转置后的key做matmul矩阵乘法计算后得到最初步的attention_score注意力得分,然后乘以缩放系数scale_value后再与位置编码pse相加。此时的结果通过atten_mask进行select掩码选择操作,将atten_mask中为true的位置进行遮蔽,得到结果masked_attention_score,即atten_mask中为true的位置在select后结果为负的极小值,经过softmax计算之后变成0从而达到遮蔽效果。

为了实现FlashAttention加速算法,使用FlashSoftmax操作对masked_attention_score进行运算,用以代替原公式中的softmax运算,而后将结果与value做matmul运算。由于FlashSoftmax操作对masked_attention_score的Skv(输入key、value的sequence length)方向进行了切分,故实现过程中存在一个刷新流程,具体如下:

  1. 当i = 0时,计算出的MM[PV]结果直接保存到attention_out[0]的ub中。
  2. 从i = 1开始,需要增加Mul和Add操作,即将上一次的MM[PV]的结果prev_mm2_res和当前exp相乘,相乘完的结果和本次MM[PV]的结果相加得到的结果保存到ttention_out[1]的ub中。以此类推,遍历tiling块完成计算。
  3. 由于FlashSoftmax计算中的除sum被后移到输出attention_out之前,因此最后需要将ub中的attention_out按行除以softmax_sum并将最终完整的结果保存到输出内存attention_out(Final)上。

源码分析

  • IFA伪代码如下
// 单核计算伪码
void compute() {
  // 1. 核内分块数量 loops
  loops = blocks_to_compute_of_this_core(); // 当前核需要计算几个数据块
  for (i = 0; i < loops; i++) {
    // 2. 块内切分 innerloops
    block = get_curr_block(i);
    bidx, nidx, sidx = dims_of_this_block(block);
    innerloops = get_inner_loops_of_this_block_by_actual_seq_len(bidx, nidx, sidx); // 数据块实际内切份数 
    q_offset = get_offset_of_query(bidx, nidx);
    
    // softmax计算中间变量
    softmax_sum = {0};
    softmax_exp = {0};
    softmax_max = {min_float};
      
    for (j = 0; j < innerloops; j++) { // flash attention循环
      kv_offset = get_offset_of_kv_block(j);
      
      // 1. mm1:S=q@k转置
      qk_res = matmul(q + qoffset, k + kv_offset);
      // 2. vec1:mul+add+select
      qk_res = elementwise(qk_res); // scale,pse, atten-mask
      // 3. vec1:online softmax P=softmax(S)
      qk_res, softmax_max, softmax_sum, softmax_exp = softmaxflash(qk_res, softmax_max, softmax_sum);
      // 4. mm2:O=P@V
      res = matmul(qk_res, v + kv_offset);
      // 5. online softmax 刷新 flash attention update
      prev_res = load_prev_res();// 上次结果
      res += prev_res * softmax_exp;  // O = O + O_pre * softmax_exp
      store(res); // 保存本次结果

      if (j == innerloops - 1) {
        // 最后 执行softmax的 除sum过程
        // sum 是所有块计算叠加全部完成后才得到最终的结果
        res = div(res, softmax_sum);
        output(res); // 输出结果到GM
      }
    }
  }
}
  • Ascend C融合算子开源仓中对应的IFA源码(简化保留主流程)

单核计算过程入口函数

// 1. 单核计算过程入口函数
Process() {
  if (g_coreType == AIV && tmpBlockIdx >= usedCoreNum) {
    // 跳过不需要的核心
  } else {
    // 1. 单核处理多块
    for (uint32_t bn2Idx = 0; bn2Idx < bn2LoopTimes; bn2Idx++) {
      // 2. 分块参数计算
      GetBN2id(bn2Idx);
      GetActualSeqLen();
      // 计算BN2方向的offset
      CalcBN2OffsetAndParams();
      // 根据当前块实际长度, 重配flashattention循环条件
      UpdateInnerLoopCond();
      pipe_barrier(PIPE_V);

      // 3. softmax复初始值
      Duplicate(softmaxMaxUb, SOFTMAX_MIN_NUM, BUFFER_SIZE_BYTE_2K / sizeof(T)); // 最小值
      Duplicate(softmaxSumUb, FLOAT_ZERO, BUFFER_SIZE_BYTE_2K / sizeof(T)); // 置0

      // 4. 块内切分多次计算
      // GQA场景需要处理G,1、mm1 A矩阵 singleM=G 2、mm1结果vector1内部切分mm1的M轴 3、涉及souter的地方,需要注意GQA
      for (uint32_t sInnerLoopIdx = 0; sInnerLoopIdx < sInnerLoopTimes; sInnerLoopIdx++) {
        // 5. 计算 块内切分参数 s2方向的offset
        CalcSInnerOffsetAndParams(sInnerLoopIdx);
        
        //6. 单次 fa计算核心
        SInnerLoopFunc(bn2Idx, sInnerLoopIdx);
      }
    }
  }
}

单次 fa计算流程

// 2. 单次 fa计算核心
SInnerLoopFunc(const uint32_t bn2Idx, const uint32_t sInnerLoopIdx) {
  // 1. 配置 mm1 mm2 矩阵乘法高阶函数 原始矩阵大小
  SetMMOrgShapeCommon();
  // 2. mm1:q@k转置
  Bmm1ComputeCommon(bn2Idx, sInnerLoopIdx);
  // 3. v1: mul + add + select + flash_softmax
  ProcessVec1Inner(sInnerLoopIdx);
  // 4. mm2: p@v
  Bmm2Compute(bn2Idx, sInnerLoopIdx);
  // 5. v2: flash_softmax 刷新: mul+add div
  ProcessVec2Inner(sInnerLoopIdx);
}

// 3.配置 mm1 mm2 矩阵乘法高阶函数 原始矩阵大小
SetMMOrgShapeCommon() {
  if (curSingleProcessSInnerSizeAlign != actualSingleProcessSInnerSizeAlign) {
    // mm1 setOrgShape
    uint32_t orgKa = headDim;
    mm.SetOrgShape(gSizeMulMsd, kvSeqSize, orgKa, kvHeadMulDim, actualSingleProcessSInnerSizeAlign);
    // mm2 setOrgShape
    bmm2.SetOrgShape(gSizeMulMsd, kvHeadMulDim, actualSingleProcessSInnerSizeAlign, kvSeqSize, headDimAlign);
    // 更新curSingleProcessSInnerSizeAlign,为了下一次判断是否进行setOrgShape使用
    curSingleProcessSInnerSizeAlign = actualSingleProcessSInnerSizeAlign;
  }
}

mm1: Q@K转置 矩阵乘法cube运算

// 4. mm1: Q@K转置 矩阵乘法cube运算
Bmm1ComputeCommon(uint32_t bn2Idx, uint32_t sInnerLoopIdx) {
  mm.SetTensorA(queryGm[tensorACoreOffset]); // 设置左矩阵Q
  mm.SetTensorB(keyGm[tensorBOffset], true); // 设置右矩阵K 需要转置
  mm.SetTail(gSizeMulMsd, actualSingleProcessSInnerSize, headDim);// 设置分块大小
  mm.template IterateAll<false>(mm1ResGm, false, false, true); // 执行矩阵计算 结果存放 mm1ResGm
  mm.WaitIterateAll(); // 等待计算完成
  mm.End();
}

vec1: 向量计算,mm1之后的计算,mul+add+select+softmax

// 5. vec1: 向量计算,mm1之后的计算,mul+add+select+softmax
ProcessVec1Inner(uint32_t sInnerLoopIdx) {
  // vec相比cube计算性能差,需要再分块分多次计算
  uint32_t gSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / actualSingleProcessSInnerSizeAlign;
  if (gSplitSize > gSize) {
    gSplitSize = gSize; // 每次计算数量
  }
  uint32_t loopCount = (gSize + gSplitSize - 1) / gSplitSize; // 循环次数
  uint32_t tailSplitSize = gSize - (loopCount - 1) * gSplitSize;// 尾次数量

  for (uint32_t i = 0, dealSize = gSplitSize; i < loopCount; i++) {
    if (i == (loopCount - 1)) {
      dealSize = tailSplitSize; // 尾次
    }
    // vec1分块计算核心:q@k mm1之后的vector操作 mul+add+select+softmax
    DealBmm1ResBaseBlock(sInnerLoopIdx, i * gSplitSize, dealSize, actualSingleProcessSInnerSizeAlign,
                          actualSingleProcessSInnerSize);
  }
}


// 5.1 vec1分块计算核心:mm1之后的vector操作 mul+add+select+softmax
DealBmm1ResBaseBlock(uint32_t sInnerLoopIdx, uint32_t startRow, uint32_t dealRowCount, 
                     uint32_t columnCount, uint32_t actualColumnCount) {
  uint32_t computeSize = dealRowCount * columnCount;
  LocalTensor<T> mmResUb = tmpBuff1.Get<T>();
  size_t batchBase = 0;
  // 1. 拿到mm1 结果
  {
    LocalTensor<MM_OUT_T> tmpMmResUb = inputQue1.AllocTensor<MM_OUT_T>();
    DataCopy(tmpMmResUb, mm1ResGm[batchBase + startRow * columnCount], computeSize);
    inputQue1.EnQue(tmpMmResUb);
    inputQue1.DeQue<MM_OUT_T>();
    DataCopy(mmResUb, tmpMmResUb, computeSize);
    inputQue1.FreeTensor(tmpMmResUb);
    pipe_barrier(PIPE_V);
  }
  
  // 2. mul+add+select
  ElewiseCompute(mmResUb, tmpBuff2, startRow, dealRowCount, columnCount, actualColumnCount);
  
  // 3. 在线 softmax
  LocalTensor<T> tmpAFloorUb = tmpBuff2.Get<T>();
  LocalTensor<uint8_t> softmaxTmpUb = tmpAFloorUb.template ReinterpretCast<uint8_t>();
  SoftmaxFlashV2Compute(mmResUb, softmaxTmpUb, startRow, dealRowCount, columnCount, actualColumnCount);
  pipe_barrier(PIPE_V);
  
  // 4. 结果类型转换为 KV的数据类型
  LocalTensor<KV_T> tmpMMResCastTensor = outputQue1.AllocTensor<KV_T>();
  Cast(tmpMMResCastTensor, mmResUb, AscendC::RoundMode::CAST_ROUND, computeSize); // mm1+vec1之后的结果转换成和 mm2的 value类型相同
  
  // 5. 拷贝到全局内存vec1ResGm中
  outputQue1.EnQue(tmpMMResCastTensor);
  outputQue1.DeQue<KV_T>();
  DataCopy(vec1ResGm[batchBase + startRow * columnCount], tmpMMResCastTensor, computeSize);
  outputQue1.FreeTensor(tmpMMResCastTensor);
}


// 5.2 vec1向量计算 mul+add+select
ElewiseCompute(LocalTensor<T>& mmResUb, TBuf<>& tmpBuf, uint32_t startRow,
               uint32_t dealRowCount, uint32_t columnCount, uint32_t actualColumnCount) {
  // 1. 乘 缩放  q@k * 1/sqrt(n)
  Muls(mmResUb, mmResUb, static_cast<T>(tilingData->baseParams.scaleValue), dealRowCount * columnCount);
  pipe_barrier(PIPE_V);

  // 2. 加 位置编码 pse shift mask
  if (pseShiftFlag) {
    // 2.1 拷贝编码数据
    PseShiftCopyIn(startRow, dealRowCount, actualColumnCount);
    LocalTensor<pseShiftType> pseShiftUb = inputQue1.DeQue<pseShiftType>();
    LocalTensor<float> pseShiftUbFloat = tmpBuf.Get<float>();
    // 2.2 转换为浮点类型
    for (uint32_t i = 0; i < dealRowCount; ++i) {
      Cast(pseShiftUbFloat[i * columnCount], pseShiftUb[i * pseMaskSizeAlign], AscendC::RoundMode::CAST_NONE,
           pseMaskSizeAlign);
    }
    // 2.3 加上位置编码
    inputQue1.FreeTensor(pseShiftUb);
    pipe_barrier(PIPE_V);
    Add(mmResUb, mmResUb, pseShiftUbFloat, dealRowCount * columnCount);
    pipe_barrier(PIPE_V);
  }

  // 3. 掩码处理 attenMask
  if (attenMaskFlag == 1) {
    // 3.1 拷贝掩码数据
    AttenMaskCopyIn(attenMaskOffset, dealRowCount, actualColumnCount);
    LocalTensor<bool> attenMaskUb = inputQue2.DeQue<bool>();
    for (int i = 1; i < dealRowCount; i++) {
      DataCopy(attenMaskUb[i * attenMaskSizeAlign], attenMaskUb, attenMaskSizeAlign);
    }
    pipe_barrier(PIPE_V);
    
    // 3.2 select 执行掩码选取操作
    LocalTensor<uint8_t> ubWorkSpace = tmpBuf.Get<uint8_t>(selectWithByteMaskTmpMinSize);
    SelectWithBytesMaskShapeInfo selectWithBytesMaskShapeInfo;
    selectWithBytesMaskShapeInfo.firstAxis = dealRowCount;
    selectWithBytesMaskShapeInfo.srcLastAxis = columnCount;
    selectWithBytesMaskShapeInfo.maskLastAxis = attenMaskSizeAlign;
    attenMaskUb.SetSize(dealRowCount * attenMaskSizeAlign);  // Select接口要求mask size与参数匹配
    mmResUb.SetSize(dealRowCount * columnCount);             // Select接口要求src size与参数匹配
    SelectWithBytesMask(mmResUb, mmResUb, BOOL_ATTEN_MASK_SCALAR_VALUE, attenMaskUb, ubWorkSpace,
                        selectWithBytesMaskShapeInfo);
    mmResUb.SetSize(BUFFER_SIZE_BYTE_32K / sizeof(T));  // mmResUb Size复原,mask不用复原,与原来一致
    inputQue2.FreeTensor(attenMaskUb);

    pipe_barrier(PIPE_V);
  }
}

// 5.3vec1向量计算 在线softmax
SoftmaxFlashV2Compute(LocalTensor<T>& mmResUb, LocalTensor<uint8_t>& softmaxTmpUb, 
                      uint32_t startRow, uint32_t dealRowCount,
                      uint32_t columnCount, uint32_t actualColumnCount) {
  uint32_t baseOffset = startRow * BLOCK_ELEMENT_NUM;
  SoftMaxShapeInfo srcShape = {dealRowCount, columnCount, dealRowCount, actualColumnCount};
  // 计算 tiling分块参数
  SoftMaxTiling newTiling =
    SoftMaxFlashV2TilingFunc(srcShape, sizeof(T), sizeof(T), softmaxTmpUb.GetSize(), true, false);
  // 执行 flash softmax运算
  SoftmaxFlashV2<T, true, true, false, false, IFA_SOFTMAX_FLASHV2_CFG> (mmResUb, softmaxSumUb[baseOffset],
    softmaxMaxUb[baseOffset], mmResUb, softmaxExpUb[baseOffset], softmaxSumUb[baseOffset], softmaxMaxUb[baseOffset],
    softmaxTmpUb, newTiling, srcShape);
}

mm2: P@V 矩阵乘法cube运算

// 6. mm2: P@V 矩阵乘法cube运算
Bmm2ComputeCommon(uint32_t bn2Idx, uint32_t sInnerLoopIdx) {

  bmm2.SetTensorA(vec1ResGm); // 设置左矩阵P=softmax(Q@K转置)
  bmm2.SetTensorB(valueGm[valueOffset]);// 设置右矩阵V
  bmm2.SetTail(gSizeMulMsd, headDim, actualSingleProcessSInnerSize);// 设置分块大小
  bmm2.template IterateAll<false>(mm2ResGm, false, false, true);// 执行矩阵计算 结果存放 mm2ResGm
  bmm2.WaitIterateAll();// 等待计算完成
  bmm2.End();
}

vec2: 向量计算,mm2之后的计算 softmax的刷新操作 mul add div

// 7. vec2: 向量计算,mm2之后的计算 softmax的刷新操作 mul add div
ProcessVec2Inner(const uint32_t sInnerLoopIdx) {
  uint32_t gSplitSize = BASE_BLOCK_MAX_ELEMENT_NUM / headDimAlign;
  if (gSplitSize > gSize) {
    gSplitSize = gSize; // 每次处理的数量
  }
  uint32_t loopCount = (gSize + gSplitSize - 1) / gSplitSize; // 循环次数
  uint32_t tailSplitSize = gSize - (loopCount - 1) * gSplitSize;// 尾次数量
  
  // 分块处理
  for (uint32_t i = 0, dealSize = gSplitSize; i < loopCount; i++) {
    if (i == (loopCount - 1)) {
      dealSize = tailSplitSize; // 尾次
    }
    // vec2分块计算核心:p@v mm2之后的计算 softmax的刷新操作 mul add div
    DealBmm2ResBaseBlock(sInnerLoopIdx, i * gSplitSize, dealSize, headDimAlign, headDim);
  }
}

// 7.1 vec2分块计算核心:p@v mm2之后的计算 softmax的刷新操作 mul add div
DealBmm2ResBaseBlock(uint32_t sInnerLoopIdx, uint32_t startRow, uint32_t dealRowCount,
                     uint32_t columnCount, uint32_t actualColumnCount) {
  uint32_t vec2ComputeSize = dealRowCount * columnCount;
  uint32_t baseOffset = startRow * BLOCK_ELEMENT_NUM;
  
  LocalTensor<T> bmm2ResUb = tmpBuff1.Get<T>();
  bmm2ResUb.SetSize(vec2ComputeSize);
  // 1. 获取mm2 计算结果
  {
    LocalTensor<MM_OUT_T> tmpBmm2ResUb = inputQue1.AllocTensor<MM_OUT_T>();
    DataCopy(tmpBmm2ResUb, mm2ResGm[batchBase + startRow * columnCount], vec2ComputeSize);
    inputQue1.EnQue(tmpBmm2ResUb);
    inputQue1.DeQue<MM_OUT_T>();
    DataCopy(bmm2ResUb, tmpBmm2ResUb, vec2ComputeSize);
    inputQue1.FreeTensor(tmpBmm2ResUb);
  }

  // 除第一个循环外,均需要更新中间计算结果
  if (sInnerLoopIdx > 0) {
    // 2. 得到上次的结果 O_pre
    LocalTensor<T> bmm2ResPreUb = inputQue2.AllocTensor<T>();
    DataCopy(bmm2ResPreUb, vec2ResGm[batchBase + startRow * columnCount], vec2ComputeSize);
    inputQue2.EnQue(bmm2ResPreUb);
    inputQue2.DeQue<T>();
    pipe_barrier(PIPE_V);

    // 3. 更新结果 O_pre = mul(O_pre, softmax_exp)
    RowMuls(bmm2ResPreUb, bmm2ResPreUb, softmaxExpUb[baseOffset], dealRowCount, columnCount, actualColumnCount);
    pipe_barrier(PIPE_V);

    // 4. 累加 O = O + O_pre
    Add(bmm2ResUb, bmm2ResUb, bmm2ResPreUb, vec2ComputeSize);
    inputQue2.FreeTensor(bmm2ResPreUb);
  }

  // 最后一次输出计算结果,否则将中间结果暂存至workspace
  if (sInnerLoopIdx + 1 == sInnerLoopTimes) {
    pipe_barrier(PIPE_V);
    // 5. 最后一次 除以 softmaxSum
    RowDivs(bmm2ResUb, bmm2ResUb, softmaxSumUb[baseOffset], dealRowCount, columnCount, actualColumnCount);
    pipe_barrier(PIPE_V);

    // 6. 数据类型转换
    LocalTensor<OUT_T> tmpBmm2ResCastTensor = outputQue1.AllocTensor<OUT_T>();
    Cast(tmpBmm2ResCastTensor, bmm2ResUb, AscendC::RoundMode::CAST_ROUND, dealRowCount * columnCount);
    outputQue1.EnQue(tmpBmm2ResCastTensor);
    outputQue1.DeQue<OUT_T>();

    // 7. 结果拷贝 到 attentionOutGm 最终结果
    DataCopyExtParams dataCopyParams;
    dataCopyParams.blockCount = dealRowCount;
    dataCopyParams.blockLen = actualColumnCount * sizeof(OUT_T);
    dataCopyParams.srcStride = (columnCount - actualColumnCount) / (BYTE_BLOCK / sizeof(OUT_T));
    dataCopyParams.dstStride = 0;
    DataCopyPad(attentionOutGm[attenOutOffset + startRow * actualColumnCount], attenOutUb, dataCopyParams);

    outputQue1.FreeTensor(tmpBmm2ResCastTensor);

  } else {
    // 8. 非最后一次 将中间结果暂存至workspace vec2ResGm
    pipe_barrier(PIPE_V);
    LocalTensor<T> tmpBmm2Res = outputQue1.AllocTensor<T>();
    DataCopy(tmpBmm2Res, bmm2ResUb, dealRowCount * columnCount);
    outputQue1.EnQue(tmpBmm2Res);
    outputQue1.DeQue<T>();
    //issue bmm2ResUb 为啥不 直接拷贝到 vec2ResGm ?
    DataCopy(vec2ResGm[batchBase + startRow * columnCount], tmpBmm2Res, vec2ComputeSize);

    outputQue1.FreeTensor(tmpBmm2Res);
  }
}
文章来自个人专栏
LM
3 文章 | 1 订阅
0条评论
0 / 1000
请输入你的评论
0
0