github.com/tspeterkim/flash-attention-minimal
该代码符号定义基本与论文一致。但是有Br = Bc的隐含假设,不适合实际复杂的情况。
一、 输入输出定义
1.输入
B: batch size
nh: number of heads
N: sequence length
d: head embedding dims
Q 尺寸 B x nh x N x d
K 尺寸 B x nh x N x d
V 尺寸 B x nh x N x d
Br: 每个切分的Q的大小
Bc: 每个切分的K,V的大小
Tr = ceil(N / Br)
Tc = ceil(N/ Bc)
2. 中间变量
l尺寸B x nh x N ,初始化为0
m 尺寸B x nh x N ,初始化为负无穷
l和m开辟在Global memory上,使用时加载到寄存器上
3. 输出
O 尺寸 B x nh x N x d
O开辟在Global memory上,使用时加载到shared memory上
二、 kernel函数调用
grid_dim尺寸 B x nh,即第一个维度对应batch, 第二个维度对应head数目。则每个block处理的是一个head的运算。QKV的数据尺寸都为N x d
dim3 grid_dim(B, nh); // batch_size x num_heads
block中的thread数目是Bc(实际应该是Br).
dim3 block_dim(Bc); // Bc threads per block
三、 核函数实现
这里代码假设了Br = Bc。实际可能不是,有问题
1. 共享内存开辟
共享内存大小为Bc x d x 3 + Br x Bc,分别存储Qi, Kj, Vj和中间变量Sij
2. 核函数结构
● 每个block处理一个head。
● 每个block单次计算处理的是一个分块Qi, Kj, Vj的计算。故block内部有Tr*Tc次循环。循环结构与论文定义相同。即外循环为Kj,Vj的循环,内循环为Qi的循环。
● 一个thread单次处理的是Qi中一行的数据。由于Qi中一行会和Kj中的每一行计算。故在计算时,通过一个长度为Bc的循环实现对Kj每行的遍历。对Vj的遍历同理。
● 由于一个thread对应的是Qi的一行,所以该行m和l的状态更新只需要寄存器保存。完成Q的遍历后再写入HBM即可。
3. 核函数流程
- 定义外层循 环,从HBM将Kj和Vj加载到shared memory
// 定义外层循环,从HBM将Kj和Vj加载到shared memory
for (int j = 0; j < Tc; j++) {
// Load Kj, Vj to SRAM
// 将Kj和Vj加载到shared memory
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads();
2. 定义内层循环,从HBM将Qi加载到shared_memory。从HBM将mi和li加载到寄存器。
// 定义内层循环,从HBM将Qi加载到shared_memory。从HBM将mi和li加载到寄存器。
for (int i = 0; i < Tr; i++) {
// Load Qi to SRAM, l and m to registers
// 将Qi加载到shared memory,l和m加载到寄存器
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
}
float row_m_prev = m[lm_offset + (Br * i) + tx];
float row_l_prev = l[lm_offset + (Br * i) + tx];
// S = QK^T, row_m = rowmax(S)
3 . 执行QK^T计算。由于一个线程代表Qi里面的一行,对Kj的遍历通过一个长度为Bc的循环实现。row_m代表文章里的mij。
// S = QK^T, row_m = rowmax(S)
float row_m = -INFINITY;
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
}
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
}
4. 按行累加求和,计算lij。
// P = exp(S - row_m), row_l = rowsum(P)
// 计算l_ij
float row_l = 0;
for (int y = 0; y < Bc; y++) {
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
row_l += S[(Bc * tx) + y];
}
5. 更新m_new和l_new。此时寄存器内有从HBM中加载的上一时刻的m_prev和l_prev。将当前线程计算的m和l代入公式。
// Compute new m and l
// 更新m_i
float row_m_new = max(row_m_prev, row_m);
// 更新l_i
float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);
6. 将O,l,m写入HBM。
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
// 获取pv相乘的结果,这里p
for (int y = 0; y < Bc; y++) {
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
* ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
+ (__expf(row_m - row_m_new) * pv));
}
m[lm_offset + (Br * i) + tx] = row_m_new;
l[lm_offset + (Br * i) + tx] = row_l_new;
7. 在每个内循环结束后,即Q遍历结束后,需要进行一个同步。因为接下来外循环会更新Kj和Vj,不同步会使得Qi使用错误的Kj和Vj值。
__syncthreads(); // otherwise, thread can use the wrong Kj, Vj in inner loop