笔者在学习 BERT 架构技术时,看到书中提到了 BERT 没有采用原始 Transformer 中的正弦-余弦位置编码,但是没讲原因。
于是笔者到网上查了一番资料进行了学习。
在机器学习和深度学习的领域中,BERT 是一种强大的预训练语言模型。它采用了许多优化策略,其中一个关键设计差异是其位置编码方法。与原始 Transformer 中的正弦-余弦位置编码方法不同,BERT 使用了基于可学习参数的嵌入方式来表示位置。
正弦-余弦位置编码方法回顾
原始 Transformer 论文中提出的正弦-余弦位置编码方法是一种固定的数学方法。它通过以下公式生成位置编码:
import numpy as np
import torch
def get_sinusoidal_positional_encoding(seq_len, d_model):
position = np.arange(seq_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
pe = np.zeros((seq_len, d_model))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
return torch.tensor(pe, dtype=torch.float32)
positional_encoding = get_sinusoidal_positional_encoding(seq_len=10, d_model=16)
print(positional_encoding)
上述方法利用正弦和余弦函数的周期性,使模型能够感知相对位置。
优点
- 固定性:位置编码是确定的,不会随训练过程变化,具有解析性。
- 相对性:编码的相对位置信息通过周期性自然体现。
局限性
- 表达能力有限:正弦和余弦函数的固定模式对复杂语言结构的捕捉能力不足。
- 灵活性不足:无法根据任务或数据分布自适应优化。
- 难以扩展:在超长序列情况下,固定编码可能无法很好地适应。
BERT 的位置编码方法
BERT 选择了一种基于可学习参数的嵌入方式,用于位置编码。这种方法将每个位置作为索引输入一个可训练的嵌入矩阵,从而得到对应的向量表示。代码如下:
import torch
import torch.nn as nn
class LearnedPositionalEncoding(nn.Module):
def __init__(self, seq_len, d_model):
super(LearnedPositionalEncoding, self).__init__()
self.position_embeddings = nn.Embedding(seq_len, d_model)
def forward(self, input_tensor):
seq_length = input_tensor.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_tensor.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_tensor[:, :, 0])
return self.position_embeddings(position_ids)
input_tensor = torch.randn(2, 10, 16) # Batch size 2, sequence length 10, hidden size 16
learned_pe = LearnedPositionalEncoding(seq_len=10, d_model=16)
output = learned_pe(input_tensor)
print(output)
通过这种方式,位置编码可以在训练过程中与其他模型参数一起更新,以适应具体任务的需求。
优势分析
灵活性
与正弦-余弦位置编码相比,可学习的嵌入能够根据任务数据分布自动调整编码模式。例如,在涉及句法或语义分析的任务中,不同的语言结构对位置信息的需求可能有显著差异。固定编码可能无法有效捕捉这些微妙的关系。
表达能力
从数学角度来看,正弦-余弦方法的维度分解固定,受限于频率函数的形态。而 BERT 的可学习嵌入可以在高维空间中自由调整,能够更好地拟合复杂的语言分布。
实验验证
研究显示,BERT 在许多下游任务中的表现优于基于正弦-余弦位置编码的模型。这表明可学习位置编码在实际场景中具有更强的适应能力。
举例说明
假设我们有一段文本 "机器学习改变了世界"
,其位置编码的作用在于帮助模型理解 "机器学习"
和 "改变了世界"
的相对关系。
- 如果使用正弦-余弦编码,这种关系通过函数周期性体现。然而,当文本长度增加时,相对位置关系的周期性可能变得模糊。
- 如果使用可学习位置嵌入,模型可以动态调整每个位置的表示。例如,它可能在训练过程中学会对谓语动词和宾语的位置关系赋予更高权重,从而增强理解能力。
以下代码通过简单示例模拟这一过程:
from transformers import BertModel, BertTokenizer
text = "机器学习改变了世界"
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
inputs = tokenizer(text, return_tensors="pt")
model = BertModel.from_pretrained("bert-base-chinese")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
在 BERT 中,位置编码已融入模型的嵌入层中。通过分析输出隐藏层状态,可以发现不同位置上的表征逐步捕捉了句法和语义信息。
为什么选择动态优化
真实案例
在工业应用中,例如机器翻译,文本长度往往不可控。如果采用固定位置编码,长文本的效果可能显著下降。反之,BERT 的可学习位置编码能够在训练过程中优化长短文本的表现,适应不同的上下文需求。
例如,Google 的搜索引擎使用 BERT 优化搜索结果。可学习的位置编码使模型更好地理解查询中重要词汇的位置关系,从而提高相关性排序。
实验比较
通过比较使用正弦-余弦和可学习位置编码的 Transformer 模型,可以观察到以下差异:
位置编码方式 | 训练灵活性 | 长文本性能 | 短文本性能 |
---|---|---|---|
正弦-余弦编码 | 低 | 一般 | 良好 |
可学习嵌入 | 高 | 优秀 | 优秀 |
实验表明,可学习嵌入在多种任务中表现更加稳定,尤其在涉及长文本或复杂上下文的情况下优势显著。
小结
BERT 不采用正弦-余弦位置编码的主要原因在于其灵活性和表达能力的局限。通过引入可学习的位置嵌入,BERT 能够更好地适应不同任务的需求,从而在多种自然语言处理任务中实现更高的性能。这一设计选择为语言模型的发展奠定了新的基准,也为后续模型优化提供了重要的启发。