断点续训练指因硬件故障、系统问题、连接错误以及其他未知问题导致训练任务中断后,下一次训练可以在上一次的训练基础上继续执行。断点续训练可以减少需要长时间训练大模型的时间成本。
断点续训练是通过checkpoint机制实现。Checkpoint机制是在模型训练的过程中,不断地保存训练结果(包括但不限于EPOCH、模型权重、优化器状态、调度器状态)。即便模型训练中断,也可以基于Checkpoint接续训练。当训练任务发生中断后需要接续训练,只需要加载Checkpoint,并使用断点前最近一次Checkpoint存储的信息初始化训练状态即可。
说明目前仅支持使用昇腾芯片训练的任务断点续训,其他芯片正在开发中,敬请期待。
前提条件
- 资源组内存在空闲节点,即未被其他训练任务占用的机器。
- 训练任务所属队列内有多余配额可用且足够时,训练任务可自动触发断点续训;训练任务所属队列内无多余配额或多余配额不足时,需扩容该队列,成功后,训练任务可触发断点续训。
训练过程
一体化计算加速平台·异构计算平台会识别中断任务并使用断点前最近一次Checkpoint存储的信息将训练任务重新调度并拉齐训练任务。您只需在训练开始前设置分布式存储/读取Checkpoint,设置成功后将模型代码上传至存储或自定义镜像,并在”创建训练任务”页面选择对应存储或自定义镜像即可,在“高级配置>训练失败后”操作选择“自动重启”。
Checkpoint有以下两种工具:
工具一:使用原生Pytorch Checkpoint。
需在pytorch中设置分布式存储/读取Checkpoint,例如设置torch.save、torch.load等参数的值。
工具二:在原生Pytorch Checkpoint基础上使用一体化计算加速平台·异构计算平台提供自研CTFlashCkpt加速包,其采用异步存储机制加快训练速度,详见CTFlashCkpt介绍。
CTFlashCkpt加速包的安装和使用方法参见安装CTFlashCkpt和使用 CTFlashCkpt。
说明如果latest_checkpointed_iteration.txt内上一次训练最后保存点和下一次训练日志中开始点不一致,可能是写入存储硬盘速度较慢导致,以日志中开始点为准。