大模型训练推理显存计算
一、不同精度模型权重
一般有fp32、fp16、bf16、int8等几种模型保存格式,主要是模型参数的保存精度不同,fp32即用32位保存一个模型参数,int8即8位。fp16和bf16用16位保存,但他们在小数点部分所使用的位数不同。
下面列举不同表示方法的模型显存计算:
- int8: 模型显存 = 1*参数量(Byte)
- fp16, bf16: 模型显存 = 2*参数量(Byte)
- fp32: 模型显存 = 4*参数量(Byte)
由于1GB约等于1B的字节,即和1B参数量的数据量级一致。
举个例子:llama 13B模型,采用bf16保存,所有下载的ckpt计算下来,整体显存占用需要 2*13=26GB显存左右。
二、训练
1. 模型权重
- 一般是采用fp32、fp16和混合精度训练。其中fp32所占的显存过大,fp16的精度往往达不到期望的程度,因此混合精度训练方式是普遍使用。
- 纯fp32 / 纯fp16 的训练方式,模型显存如上所示。
- 混合精度即 fp16/bf16 + fp32 精度在训练过程中都使用,但一般保存权重的时候是 fp16/bf16,因此显存可以计算可以直接复用上面的公式。具体来说,fp32一般用在累加算子上,防止误差的累积。
2. 优化器状态
- adamw和adam:训练时需要 主权重、动量和二阶动量,因此优化器的参数量为模型权重参数量的3倍,而且由于优化器采用fp32形式保存参数,因此优化器参数所占显存为 3*4*参数量(Byte)。
- bits-and-bytes类的8位优化器:主权重采用fp32,动量和二阶动量为 int8,所以显存为 (4+1+1)*参数量。
3. 梯度
- 梯度一般可以存储为fp32或fp16,梯度数据类型通常与模型数据类型匹配。混合精度训练中,梯度类型一般为fp16。
- 对 fp32训练,梯度显存为 4*参数量(Byte)。
- 对 fp16/混精训练,梯度显存为 2*参数量(Byte)。
4. 激活状态
在进行大模型训练时,GPU受限的往往是显存的大小,而不是算力的问题。因此,激活重计算(也称为激活检查点)变得非常流行,它是一种以计算力为代价来减少显存使用的方法。
激活重计算的主要思路是在反向传播的时候重新计算某些层的激活,代替前向计算后需要保存占用显存的操作,从而降低GPU显存的使用。具体来说,减少显存的多少取决于我们选择重新计算哪些层的激活。
假设激活数据类型为 fp16,没有使用序列并行:
- 无重计算的激活显存: (s * b * h * l) * (10+24/t+5 * a * s/ h / t) (Byte)
- 选择性重计算的激活显存:(s * b * h * l) * (10+24/t) (Byte)
- 全部重计算的激活显存:2 * (s * b * h * l) (Byte)
其中:
- s 是token 长度;
- b 是 每个GPU的batch size;
- h 是 每个hidden layer的维度;
- l 是 模型的隐层数;
- a 是 transformer 模型中注意力头 (attention heads) 的个数;
- t 是张量并行度 (如果无张量并行,则为 1)。
由于重计算的引入也会引起计算成本的增加,具体增加多少取决于选择了多少层进行重计算,但其上界为所有层都额外多了一次前向传播,因此,需要考虑更新后的前向传播计算成本。
计算成本:2 * 数据集token数 * 模型参数 ≤ C(前向传播)≤ 4 * 数据集token数 * 模型参数。
所以,训练的总的显存占用=模型权重显存+优化器显存+梯度显存+激活显存
三、推理
推理的总显存占用分成两部分:
- 模型权重: 同第一部分计算方式
- 前向计算开销:通常在模型权重的20%左右(经验估算)
因此,推理的总的显存占用在 1.2倍的模型显存左右。
四、显存计算实例
-
llama2-13B:
以Adamw优化器,混合精度,选择性重计算llama2-13B为例,计算训练和推理的显存占用情况。
- 训练时:
- 激活显存为: (2048 * 1 * 5120 * 40) * (10+24/1) (Byte) = 14.2 GB
- 训练总显存=模型显存+优化器显存+梯度显存+激活显存 = 26GB + 12*13GB + 26GB + 14.2GB = 222.2GB
- 推理时:2 * 26G = 31.2GB
-
Qwen-72B:
参数:
bf16
hidden-size 8192
seq-length 8192
num-layers 80
micro-batch-size 1
global-batch-size 512
num_attention_heads 64
TP=8
PP=8
adam优化器
8k长度,无重计算
无重计算的激活显存: (s * b * h * l) * (10+24/t+5 * a * s/ h / t)
step1:
激活显存为:(8192*1*8192*80)*(10+24/8+5*64*8192/8192/8)(Byte) = 284.5 GB
step2:
训练总显存
=模型显存+优化器显存+梯度显存+激活显存
=2*72+12*72+2*72+284.5=144+864+144+284.5=1436.5 GB
step3:
推理总显存 = 1.2 * 144G = 172.8 GB