算法介绍
1. 核心思想
Reformer主要是为了解决Transformer结构显存占用大、模型复杂度度高的缺点(无法处理长序列、无法在单GPU上进行训练),而进行的模型结构改进,主要通过以下3个方面来进行:
· 利用局部敏感哈希(LSH)来减少长序列注意力的复杂度
· 利用可逆残差网络(RevNet)来减少反向传播时的显存占用
· 利用分块技术来进行前馈网络的计算(Chunking FFN layer)来减少显存占用。
2. 详细过程
2.1 局部敏感哈希(LSH)
· 首先作了K、Q同源的操作,也就是通过1组线性变换得到相同的K和Q,相比原始Transformer的2组不同的线性变换来生成K和Q,得到了一定的显存和计算量的节省。
· 原始的Attention公式如下所示,最大的计算量就在[MISSING IMAGE: , ]上面。通过分析可知,Q和K的内积(每个元素相互的关系),其实绝大多数都是0,只有少部分才是有相互关系的,所以[MISSING IMAGE: , ]的计算,可以做一个近似,也就是只计算每个元素有关系的那些元素的内积,其余元素就填充0。
· 如何找到某个元素,有关系的那些元素。这就是局部敏感哈希(LSH)的任务,作者设计了圆环状的hash桶,通过3次随机变换,如果两个点都在一个hash桶内,就说明两者是相近的(有关系的)。
· 同时,作者通过增加不同的hash函数,可以更多的找到相近点,减少遗漏。
· 下图就是一个简化的示例,展示了局部敏感哈希桶化、相似排序(聚类)、分块计算注意力等步骤,实现了对注意力机制的简化。
2.2 可逆残差网络(RevNet)
· 作者相对于ResNet,设计了可逆残差网络(RevNet),其主要结构如下图所示,可以看到对于ResNet,是无法通过[MISSING IMAGE: , ]计算[MISSING IMAGE: , ]甚至[MISSING IMAGE: , ]的,所以这些值在反向传播时都得存储在显存中。而RevNet,设计了[MISSING IMAGE: , ]和[MISSING IMAGE: , ]两个激活值,通过他们可以首先计算[MISSING IMAGE: , ],再计算出[MISSING IMAGE: , ],这样显存中就可以只存储最终的激活值,中间过程的激活值可以通过计算得到,节省显存空间。
· 具体到这篇论文,作者就是将Attention作为函数F,FeedForword作为了函数G,来构造了这个可逆残差网络(RevNet)。
2.3 前馈网络分块(Chunking FFN layer)
· 前馈层中间向量维度很高,计算时需要耗费很大的显存。由于前馈层的计算是序列无关的,所以前向和后向的计算以及反向的计算都可以被分割成块。
· 这样的处理,对于计算量没有影响,而对于显存的占用,可以做到很大的节省。
3. 应用场景
· Transformer适用的任务(尤其单Gpu限制下的任务场景)
· 更长序列的文本任务(例如enwik8,长度为64K)
· 更长序列的图像生成任务(例如imagenet64,长度为12K)
4. 附图
1. Q、K同源以及可逆残差网络对精度的影响
2. 不同LSH数量对精度影响
3. 不同LSH数量的推理速度