searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

REFORMER——更高效的TRANSFORMER

2024-08-13 09:51:44
5
0

算法介绍

1. 核心思想

Reformer主要是为了解决Transformer结构显存占用大、模型复杂度度高的缺点(无法处理长序列、无法在单GPU上进行训练),而进行的模型结构改进,主要通过以下3个方面来进行:

· 利用局部敏感哈希(LSH)来减少长序列注意力的复杂度

· 利用可逆残差网络(RevNet)来减少反向传播时的显存占用

· 利用分块技术来进行前馈网络的计算(Chunking FFN layer)来减少显存占用。

2. 详细过程

2.1 局部敏感哈希(LSH)

·  首先作了K、Q同源的操作,也就是通过1组线性变换得到相同的K和Q,相比原始Transformer的2组不同的线性变换来生成K和Q,得到了一定的显存和计算量的节省。

·  原始的Attention公式如下所示,最大的计算量就在[MISSING IMAGE: ,  ]上面。通过分析可知,Q和K的内积(每个元素相互的关系),其实绝大多数都是0,只有少部分才是有相互关系的,所以[MISSING IMAGE: ,  ]的计算,可以做一个近似,也就是只计算每个元素有关系的那些元素的内积,其余元素就填充0。

 

·  如何找到某个元素,有关系的那些元素。这就是局部敏感哈希(LSH)的任务,作者设计了圆环状的hash桶,通过3次随机变换,如果两个点都在一个hash桶内,就说明两者是相近的(有关系的)。

 

 

·  同时,作者通过增加不同的hash函数,可以更多的找到相近点,减少遗漏。

·  下图就是一个简化的示例,展示了局部敏感哈希桶化、相似排序(聚类)、分块计算注意力等步骤,实现了对注意力机制的简化。

 

 

2.2 可逆残差网络(RevNet)

·  作者相对于ResNet,设计了可逆残差网络(RevNet),其主要结构如下图所示,可以看到对于ResNet,是无法通过[MISSING IMAGE: ,  ]计算[MISSING IMAGE: ,  ]甚至[MISSING IMAGE: ,  ]的,所以这些值在反向传播时都得存储在显存中。而RevNet,设计了[MISSING IMAGE: ,  ][MISSING IMAGE: ,  ]两个激活值,通过他们可以首先计算[MISSING IMAGE: ,  ],再计算出[MISSING IMAGE: ,  ],这样显存中就可以只存储最终的激活值,中间过程的激活值可以通过计算得到,节省显存空间。

 

·  具体到这篇论文,作者就是将Attention作为函数F,FeedForword作为了函数G,来构造了这个可逆残差网络(RevNet)。 

2.3 前馈网络分块(Chunking FFN layer)

·  前馈层中间向量维度很高,计算时需要耗费很大的显存。由于前馈层的计算是序列无关的,所以前向和后向的计算以及反向的计算都可以被分割成块。 

·  这样的处理,对于计算量没有影响,而对于显存的占用,可以做到很大的节省。

3. 应用场景

· Transformer适用的任务(尤其单Gpu限制下的任务场景)

· 更长序列的文本任务(例如enwik8,长度为64K)

· 更长序列的图像生成任务(例如imagenet64,长度为12K)

4. 附图

1.  Q、K同源以及可逆残差网络对精度的影响

 

 

2.  不同LSH数量对精度影响

 

 

3.  不同LSH数量的推理速度

0条评论
0 / 1000
钱****翔
6文章数
0粉丝数
钱****翔
6 文章 | 0 粉丝
原创

REFORMER——更高效的TRANSFORMER

2024-08-13 09:51:44
5
0

算法介绍

1. 核心思想

Reformer主要是为了解决Transformer结构显存占用大、模型复杂度度高的缺点(无法处理长序列、无法在单GPU上进行训练),而进行的模型结构改进,主要通过以下3个方面来进行:

· 利用局部敏感哈希(LSH)来减少长序列注意力的复杂度

· 利用可逆残差网络(RevNet)来减少反向传播时的显存占用

· 利用分块技术来进行前馈网络的计算(Chunking FFN layer)来减少显存占用。

2. 详细过程

2.1 局部敏感哈希(LSH)

·  首先作了K、Q同源的操作,也就是通过1组线性变换得到相同的K和Q,相比原始Transformer的2组不同的线性变换来生成K和Q,得到了一定的显存和计算量的节省。

·  原始的Attention公式如下所示,最大的计算量就在[MISSING IMAGE: ,  ]上面。通过分析可知,Q和K的内积(每个元素相互的关系),其实绝大多数都是0,只有少部分才是有相互关系的,所以[MISSING IMAGE: ,  ]的计算,可以做一个近似,也就是只计算每个元素有关系的那些元素的内积,其余元素就填充0。

 

·  如何找到某个元素,有关系的那些元素。这就是局部敏感哈希(LSH)的任务,作者设计了圆环状的hash桶,通过3次随机变换,如果两个点都在一个hash桶内,就说明两者是相近的(有关系的)。

 

 

·  同时,作者通过增加不同的hash函数,可以更多的找到相近点,减少遗漏。

·  下图就是一个简化的示例,展示了局部敏感哈希桶化、相似排序(聚类)、分块计算注意力等步骤,实现了对注意力机制的简化。

 

 

2.2 可逆残差网络(RevNet)

·  作者相对于ResNet,设计了可逆残差网络(RevNet),其主要结构如下图所示,可以看到对于ResNet,是无法通过[MISSING IMAGE: ,  ]计算[MISSING IMAGE: ,  ]甚至[MISSING IMAGE: ,  ]的,所以这些值在反向传播时都得存储在显存中。而RevNet,设计了[MISSING IMAGE: ,  ][MISSING IMAGE: ,  ]两个激活值,通过他们可以首先计算[MISSING IMAGE: ,  ],再计算出[MISSING IMAGE: ,  ],这样显存中就可以只存储最终的激活值,中间过程的激活值可以通过计算得到,节省显存空间。

 

·  具体到这篇论文,作者就是将Attention作为函数F,FeedForword作为了函数G,来构造了这个可逆残差网络(RevNet)。 

2.3 前馈网络分块(Chunking FFN layer)

·  前馈层中间向量维度很高,计算时需要耗费很大的显存。由于前馈层的计算是序列无关的,所以前向和后向的计算以及反向的计算都可以被分割成块。 

·  这样的处理,对于计算量没有影响,而对于显存的占用,可以做到很大的节省。

3. 应用场景

· Transformer适用的任务(尤其单Gpu限制下的任务场景)

· 更长序列的文本任务(例如enwik8,长度为64K)

· 更长序列的图像生成任务(例如imagenet64,长度为12K)

4. 附图

1.  Q、K同源以及可逆残差网络对精度的影响

 

 

2.  不同LSH数量对精度影响

 

 

3.  不同LSH数量的推理速度

文章来自个人专栏
计算机视觉
6 文章 | 1 订阅
0条评论
0 / 1000
请输入你的评论
0
0