searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

flash attention代码解读

2024-06-26 09:44:34
12
0
代码地址:

github.com/tspeterkim/flash-attention-minimal

该代码符号定义基本与论文一致。但是有Br = Bc的隐含假设,不适合实际复杂的情况。

 

一、  输入输出定义

1.输入

B: batch size

nh: number of heads

N: sequence length

d: head embedding dims

Q 尺寸 B x nh x N x d

K 尺寸 B x nh x N x d

V 尺寸 B x nh x N x d

Br: 每个切分的Q的大小

Bc: 每个切分的K,V的大小

Tr = ceil(N / Br)

Tc = ceil(N/ Bc)

2.  中间变量

l尺寸B x nh x N  ,初始化为0

m 尺寸B x nh x N ,初始化为负无穷

l和m开辟在Global memory上,使用时加载到寄存器上

3.  输出

O 尺寸 B x nh x N x d

O开辟在Global memory上,使用时加载到shared memory上

二、  kernel函数调用

grid_dim尺寸 B x nh,即第一个维度对应batch, 第二个维度对应head数目。则每个block处理的是一个head的运算。QKV的数据尺寸都为N x d

 

dim3 grid_dim(B, nh); // batch_size x num_heads

 

block中的thread数目是Bc(实际应该是Br).

 

dim3 block_dim(Bc); // Bc threads per block

 

三、  核函数实现

这里代码假设了Br = Bc。实际可能不是,有问题

1.  共享内存开辟

共享内存大小为Bc x d x 3 + Br x Bc,分别存储Qi, Kj, Vj和中间变量Sij

2.  核函数结构

● 每个block处理一个head。

● 每个block单次计算处理的是一个分块Qi, Kj, Vj的计算。故block内部有Tr*Tc次循环。循环结构与论文定义相同。即外循环为Kj,Vj的循环,内循环为Qi的循环。

● 一个thread单次处理的是Qi中一行的数据。由于Qi中一行会和Kj中的每一行计算。故在计算时,通过一个长度为Bc的循环实现对Kj每行的遍历。对Vj的遍历同理。

● 由于一个thread对应的是Qi的一行,所以该行m和l的状态更新只需要寄存器保存。完成Q的遍历后再写入HBM即可。

3.  核函数流程

  1. 定义外层循 环,从HBM将Kj和Vj加载到shared memory
     
    // 定义外层循环,从HBM将Kj和Vj加载到shared memory
    for (int j = 0; j < Tc; j++) {

    // Load Kj, Vj to SRAM
    // 将Kj和Vj加载到shared memory
    for (int x = 0; x < d; x++) {
    Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
    Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
    }
    __syncthreads();

 

2.  定义内层循环,从HBM将Qi加载到shared_memory。从HBM将mi和li加载到寄存器。

// 定义内层循环,从HBM将Qi加载到shared_memory。从HBM将mi和li加载到寄存器。
for (int i = 0; i < Tr; i++) {

// Load Qi to SRAM, l and m to registers
// 将Qi加载到shared memory,l和m加载到寄存器
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];

// S = QK^T, row_m = rowmax(S)

3 . 执行QK^T计算。由于一个线程代表Qi里面的一行,对Kj的遍历通过一个长度为Bc的循环实现。row_m代表文章里的mij

// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;

if (sum > row_m)
row_m = sum;
}

4.  按行累加求和,计算lij

// P = exp(S - row_m), row_l = rowsum(P)
// 计算l_ij
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}

 

5.  更新m_new和l_new。此时寄存器内有从HBM中加载的上一时刻的m_prev和l_prev。将当前线程计算的m和l代入公式。

// Compute new m and l
// 更新m_i
float row_m_new = max(row_m_prev, row_m);
// 更新l_i
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

 

6.  将O,l,m写入HBM。

for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
// 获取pv相乘的结果,这里p
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;

 

7.  在每个内循环结束后,即Q遍历结束后,需要进行一个同步。因为接下来外循环会更新Kj和Vj,不同步会使得Qi使用错误的Kj和Vj值。

__syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop

 

0条评论
0 / 1000
王****鹏
4文章数
0粉丝数
王****鹏
4 文章 | 0 粉丝