searchusermenu
  • 发布文章
  • 消息中心
点赞
收藏
评论
分享
原创

大型语言模型的解码策略解析

2024-08-14 10:03:57
116
0

1、概述

大型语言模型(LLM)的推理过程面临着严重的性能瓶颈。这些模型通常采用自回归采样方法,需要逐个token进行串行解码,导致推理过程异常缓慢。更关键的是,每生成一个token都需要将所有参数从存储单元传输到计算单元,使得内存访问带宽成为制约因素,严重影响了推理速度。业界已经开发了多种工程优化方案,如改进计算核心实现、多卡并行计算和批处理策略等。然而,这些方法虽然在一定程度上提升了性能,却未能从根本上解决LLM解码过程受制于访存带宽的核心问题。

在这种背景下,探索更高效的解码策略变得尤为重要。本文旨在提供解码策略概览,涵盖从基础到高级的多种方法。首先探讨基础解码方法,包括贪心算法、Beam Search和采样等经典策略。接着,聚焦于最新的高级解码策略,如投机采样、美杜莎解码和对比解码。

基础解码策略

Greedy Search

Greedy Search贪婪搜索策略是一种简单的解码策略,其核心思想是每一步都选择概率最大的单词输出,最后组成整个句子输出。这种方法给出的结果一般情况结果比较差,因为只考虑了​每一步的最优解​,​并不一定是全局最优​,因为贪婪搜索会错过隐藏在低概率单词后面的高概率单词,另外当如果每个位置选错了,后续位置生成的内容很可能也是错误的,具有错误的累加效果。

贪婪搜索由于每步只需要关注最大得分,考虑的因素少,因此实现容易,执行速度快,其实现原理图如下所示。
image-20240809142541425.png

Beam Search

Beam Search 是Greedy Search的一种改进,通过在每一步保持多个候选来平衡搜索的效率和结果的质量,是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的相关概念和得分函数优化方法,然后介绍一种新的 Best-First Beam Search 方法,Best-First Beam Search 结合了优先队列和 A* 启发式搜索方法,可以提升 Beam Search 的速度。

如上文所示,Greedy Search在每一时刻只选择当前最有可能的单词,例如在预测第一个单词时,"我" 的概率最大,则第一个单词预测为 "我";预测第二个单词时,"爱" 的概率最大,则预测为 "爱"。Greedy Searc具有比较高的运行效率,但是每一步考虑的均是局部最优,有时候不能得到全局最优解。

Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size=k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。其具体的实现步骤为:
(1)在每一步解码时,保留最好的 k 个候选序列,其中 k 称为束宽(beam width)。
(2)对于每个候选序列,生成下一个可能的词,并计算新序列的概率。
(3)从所有新生成的序列中选择前 k 个最高概率的序列继续下一步。

下图展示了 Beam Search 的过程,对应的 k=2:

image.png

采样

随机采样

随机采样基于模型在每个时间步预测的概率分布来选择下一个词。不同于贪心搜索总是选择概率最高的词,随机采样会根据预测的概率随机选择一个词。采样的依据就是解码器输出的词典中每个词的概率分布。相比于按概率“掐尖”,这样会增大所选词的范围,引入更多的随机性。

采样的时候有一个可以控制的超参数,称为温度(temperature, T)。解码器的输出层后面通常会跟一个softmax函数来将输出概率归一化,通过改变T可以控制概率的形貌。softmax的公式如下,当T大的时候,概率分布趋向平均,随机性增大;当T小的时候,概率密度趋向于集中,即强者愈强,随机性降低,会更多地采样出“放之四海而皆准”的词汇。

image-20240809143012610.png

TopK

TopK就是在采样前将输出的概率分布截断,取出概率最大的k个词构成一个集合,然后将这个子集词的概率再归一化,最后从新的概率分布中采样词汇。

这个办法据说可以获得比Beam Search好很多的效果,但也有一个问题,就是这个k不太好选。 因为这个概率分布变化比较大,有时候可能很均匀(flat),有的时候比较集中(peaked)。对于集中的情况还好说,当分布均匀时,一个较小的k容易丢掉很多优质候选词。但如果k定的太大,这个方法又会退化回普通采样。

image (1).png

TopP

TopP采用累计概率的方式(也称为核采样, Nucleus Sampling),该方式不再取一个固定的k,而是固定候选集合的概率密度和在整个概率分布中的比例,即从累计概率超过某一阈值p的词汇中进行采样,根据参数p的大小调节,增大了出现概率较小的词汇的生成的概率。如下是Hugging Face中TopK和TopP采样实现的源码:

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
   """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
       Args:
           logits: logits distribution shape (batch size, vocabulary size)
           if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
           if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
               Nucleus filtering is described in Holtzman et al.
           Make sure we keep at least min_tokens_to_keep per batch example in the output
   """
   if top_k > 0:
       top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
       # Remove all tokens with a probability less than the last token of the top-k
       indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
       logits[indices_to_remove] = filter_value
​
   if top_p < 1.0:
       sorted_logits, sorted_indices = torch.sort(logits, descending=True)
       cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
​
       # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
       sorted_indices_to_remove = cumulative_probs > top_p
       if min_tokens_to_keep > 1:
           # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
           sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
       # Shift the indices to the right to keep also the first token above the threshold
       sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
       sorted_indices_to_remove[..., 0] = 0
​
       # scatter sorted tensors to original indexing
       indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
       logits[indices_to_remove] = filter_value
   return logits

惩罚重复

为了解决重复问题,还可以通过惩罚因子将出现过词的概率变小或者强制不使用重复词来解决。惩罚因子来自于同样广为流传的[《CTRL: A Conditional Transformer Language Model for Controllable Generation》],也可参考知乎文章《[Top-K+重复性惩罚]》。

以下是Huggind Face中对重复惩罚的源码实现:

# 输入的同样是logits(lprobs)
# 同时输入了之前出现过的词以及惩罚系数(大于1的)
# 考虑到了logit是正和负时处理方式应该不一样
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """"""
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty
                    
                    
# 去重复词
# 这个函数将会返回一个不可使用的词表
# 生成n-gram的巧妙方式大家可以借鉴一下
# 下面是一个3-gram的例子
# a = [1,2,3,4,5]
# for ngram in zip(*[a[i:] for i in range(3)]):
#    print(ngram)
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].numpy().tolist()
        generated_ngram = generated_ngrams[idx]
        # 就是这巧妙的一句
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
​
    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])
​
    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens

高级解码策略

投机采样解码

投机采样是一种可以从根本上解码计算访存比的方法,保证和使用原始模型的采样分布完全相同。它使用两个模型:一个是原始目标模型,另一个是比原始模型小得多的近似模型。近似模型用于进行自回归串行采样,而大型模型则用于评估采样结果。解码过程中,某些token的解码相对容易,某些token的解码则很困难。因此,简单的token生成可以交给小型模型处理,而困难的token则交给大型模型处理。这里的小型模型可以采用与原始模型相同的结构,但参数更少,或者干脆使用n-gram模型。小型模型不仅计算量较小,更重要的是减少了内存访问的需求。

投机采样(Speculative Decoding)是Google和DeepMind在2022年同时发现的大模型推理加速方法。它可以在不损失生成效果前提下,获得3x以上的加速比。GPT-4泄密报告也提到OpenAI线上模型推理使用了它。

投机采样解码的原理

  1. 就是使用一个小模型来做草稿,然后使用大模型做纠正检查。
  2. 小模型的参数量要远小于原模型参数量一个级别才效果明显。
  3. 小模型和原模型的tokenizer最好一模一样,不然会增加额外的解码、编码时间。

2f38fcf7-7a23-4ce2-9d0a-5b09eec0badd.png

投机采样解码步骤

(1)小模型连续解码γ个token
output (2).png

(2)将Prefix + γ token拼接,送入大模型
output (3).png

此时,大模型可以一次预测γ个token, 然后使用逻辑判断大模型生成的token与小型的是否一致,如果第一个就不一致,直接抛弃,使用大模型新生成的prefix + 651 token作为prompt继续让小模型连续生成γ个token。 如果生成的token一致,则被大模型接受。
(3)一次重复,直到超过限度或生成结束符
output (4).png

投机采样解码的缺点

在投机采样中,Importance Sampling技术被用来生成与原始模型的预测概率分布相同的多样化输出。然而,后来的研究表明,当调高采样温度时,这种方法往往会变得效率较低(因为原始模型的输出会更多样化,导致很大概率原始模型的疏忽与draft模型的输出不一致,导致draf模型的输出被拒绝)。

简单来说,如果你的draft模型与你的target模型一样好,理想情况下,你应该接受它的所有输出,使得过程超级高效。然而,top-p采样可能会拒绝draft模型生成,从而导致并行解码长度很短。

美杜莎解码

美杜莎解码的核心思想是通过并行生成多个可能的未来token序列,来减少模型前向传播的次数,从而加速解码过程。使用较小的参考模型在每一步生成 token 序列,然后通过较大的原始模型进行细化以获得可接受的延续。​不过获得合适的参考模型仍然具有挑战性,并且将草稿模型集成到分布式系统中更加困难​。

来自普林斯顿大学、Together.AI、伊利诺伊大学厄巴纳 - 香槟分校等机构的研究者没有使用单独的参考模型来顺序生成候选输出,而是重新审视并完善了在主干模型之上使用多个解码头加速推理的概念。他们发现,如果该技术得到有效应用,可以克服推测解码的挑战,从而无缝地集成到现有 LLM 系统中。

美杜莎头

通常将解码过程分为两种类型:

  1. 常规解码:通常称为"Next-Token"预测,即每次只预测下一个token。
  2. 多token并行解码:定义为"Next-Next-Tokens"预测,可以同时预测多个未来token。

Medusa解码正是基于后者的思想,通过在现有模型基础上增加多个"Medusa Head"来实现。这些Medusa Head与原模型的Language Model (LM) Head协同工作,共同进行预测任务。这种设计使得模型能够并行生成多个可能的未来token序列,显著减少了模型前向传播的次数,从而大幅提高解码速度。通过动态调整预测长度和验证预生成的序列,Medusa能够灵活地适应不同的生成任务需求。

d149d89f-a6bc-4241-8ac9-5c93c759b981.png
新增美杜莎头的计算公式如下所示:
output.png

在一次模型迭代中,模型原始LM Head预测下一个token,Medusa Head 1预测下下个token,Medusa Head 2预测下下下个token,以此类推。 假设一次推理中,Prompt长度为20, 而Vocab Size为32000,那模型的输出形状如下:

  • Medusa logits shape: torch.Size([4, 1, 20, 32000])
  • logits shape: torch.Size([1, 20, 32000])

如果直接使用贪心算法,则一次可预测多个token:

medusa_pred = torch.argmax(medusa_logits[..., -1, :], dim = -1)
pred        = torch.argmax(logits[..., -1, :], dim = -1)
print('Base model prediction:', tokenizer.batch_decode(pred))
print('Medusa prediction:', tokenizer.batch_decode(medusa_pred))

preds = torch.cat([pred, medusa_pred[:, 0 ], dim = -1) # 将用于Verify
print('Combined prediction:', tokenizer.batch_decode(preds))

Base model prediction: ['Once']
Medusa prediction: ['upon', 'ly', 'time', ',']
Combined prediction: ['Once', 'upon', 'ly', 'time', ',']

通常,贪心算法的单个token预测准确率只有60%左右,而我们一次预测了多个token,准确率会进一步下降。因此我们也需要先投机采样那样,将新生成的token与原始prompt拼接在一起再forward一次,通过验证决定接收长度。

树状注意力

为了进一步提高预测准确率,我们采用TopK的方式进行采样:

  • LM Head 使用Top-1
  • Medusa Head使用Top-k
    bdbdedf4-cf32-4477-b01c-0436bda865e2.png

假设Topk=3, 美杜莎头个数为4(一次前向推理可以预测1+4个token),那么可能得候选token路径就有3^4=81种,这是一个庞大的数字。并且,可能得token路径形成一个树状的结构。论文中提出了一些对树进行剪枝的方法,如下所示,并提出了Tree Attention。
061d4547-66ec-4381-bf76-05593ee0c45a.png

典型接受

剪枝后的候选路径列表已经减少很多,但是大概率仍然不止一个。因此,我们仍然需要使用一些方法选择一个最佳的路径。收到截断采样(Top-K, Top-P)的启发, 美杜莎也设计了一种截断选择token的思路,其数学表达式如下:
output (1).png
简单描述就是说,在路径path里的token,如果它对应的概率值大于一个经验阈值与token所在的head的概率向量的熵乘以一个系数中较小的值,表示可以将该token考虑在内,否则将其抛弃。

def evaluate_posterior(logits, candidates, temperature, posterior_threshold=0.3,
posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = True):
# Predicted logits of shape (batch_size, sequence_length, vocab_size).
# candidates (torch.Tensor): Candidate token sequences.
posterior_prob  = torch.softmax(logits[:, :-1] / temperature, dim=-1)
# 找到每个candidates序列里每个token的概率值
candidates_prob = torch.gather(posterior_prob, dim=-1, 
    index=candidates[:, 1:].unsqueeze(-1)).squeeze(-1)

# 计算每个head概率向量的熵
posterior_entropy = -torch.sum(
    posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
) 

threshold = torch.minimum(
    torch.ones_like(posterior_entropy) * posterior_threshold,
    torch.exp(-posterior_entropy) * posterior_alpha,
)
posterior_mask = candidates_prob > threshold     # 候选path里的token是否满足熵条件

# cumprod的特点是,x方向某个token如果不满足,后面的mask累积全部为零
# 因此,sum求和累积的是从树顶部到底部连续的有效token数量,一旦出现一个无效,后面的有效也没用了
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)

# 找到有效长度最长的那些路径
accept_length = candidates_accept_length.max()
if accept_length == 0:
    # 如果没有,那就选择第一条路径
    best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
    best_candidates = torch.where(candidates_accept_length == accept_length)[0]
    
    # 可能不止一条,选择最好的一条
    likelihood = torch.sum(
        torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
    )
    best_candidate = best_candidates[torch.argmax(likelihood)]

对比解码

LLM一般的解码方法是在每个time step中从截断的下一个词分布中进行抽样。例如,nucleus sampling 从下一个词分布的前p百分位中抽取;top-k 抽样从下一个词分布中的前k个候选者中抽取。另一种常见方法是通过贪婪解码或束搜索来寻找最可能的文本序列;但这会导致重复和乏味的输出。
fd6ebe9b-1ea9-4093-8dfd-2daf3ae5cb36.png

较小的语言模型(LMs)比较大的语言模型更容易产生不良的模式(例如,重复、主题漂移和自相矛盾)。例如,当专家(较大的LM)和业余者(较小的LM)都将最高概率分配给重复的词时,专家LM通常对这个决策的信心较低,也会将一些概率分配给其他好的、非重复的词。

基于此,对比解码(Contrastive Decoding)目标奖励大型专家LMs所偏爱的文本模式,并惩罚小型业余LMs所偏爱的模式,简单理解就是用大模型的预测减去小模型的预测从而消除一些错误的预测。数学上表达为:
8a468816-f387-4ac5-968f-4804082f8452.png

然而,​业余LMs并非总是错误的:小型语言模型仍然捕捉到许多英语语法和常识的简单方面(例如,主谓一致)。因此,无差别地惩罚来自业余LMs的所有行为将惩罚这些正确的简单方面(误报),反过来奖励不合理的词(漏报)。为解决这个问题,我们引入了合理性约束,它补充了CD目标并避免了这些失败模式。简单理解就是对大模型的预测概率分布进行截断,不要考虑太多token。
e7259897-344a-4b70-ae74-bde461f01b75.png

一个具体的例子如下:
fe85fffe-79aa-45fe-b3b2-faba3767152c.png

对比解码的概率分布示例:

  • p(​Honolulu​) = log(0.16) - log(0.08) = 0.6931
  • p(1964) = log(0.1) - log(0.001) = 4.60
  • p(Washington) = log(0.02) - log(0.04) = -0.6931

总的来说,对比解码通过减少表面级复制、减少推理步骤的遗漏和防止抽象推理错误等方式改进了推理任务。然而,对比解码在常识推理任务中的表现可能不稳定,对事实检索的性能略有下降。传统解码方法如贪婪解码和核采样在速度和文本多样性方面具有优势,但可能在生成的文本质量和准确性上存在一些问题。因此,在实际应用中,需要根据具体任务的要求和限制选择合适的解码方法。

0条评论
0 / 1000
c****d
4文章数
0粉丝数
c****d
4 文章 | 0 粉丝
原创

大型语言模型的解码策略解析

2024-08-14 10:03:57
116
0

1、概述

大型语言模型(LLM)的推理过程面临着严重的性能瓶颈。这些模型通常采用自回归采样方法,需要逐个token进行串行解码,导致推理过程异常缓慢。更关键的是,每生成一个token都需要将所有参数从存储单元传输到计算单元,使得内存访问带宽成为制约因素,严重影响了推理速度。业界已经开发了多种工程优化方案,如改进计算核心实现、多卡并行计算和批处理策略等。然而,这些方法虽然在一定程度上提升了性能,却未能从根本上解决LLM解码过程受制于访存带宽的核心问题。

在这种背景下,探索更高效的解码策略变得尤为重要。本文旨在提供解码策略概览,涵盖从基础到高级的多种方法。首先探讨基础解码方法,包括贪心算法、Beam Search和采样等经典策略。接着,聚焦于最新的高级解码策略,如投机采样、美杜莎解码和对比解码。

基础解码策略

Greedy Search

Greedy Search贪婪搜索策略是一种简单的解码策略,其核心思想是每一步都选择概率最大的单词输出,最后组成整个句子输出。这种方法给出的结果一般情况结果比较差,因为只考虑了​每一步的最优解​,​并不一定是全局最优​,因为贪婪搜索会错过隐藏在低概率单词后面的高概率单词,另外当如果每个位置选错了,后续位置生成的内容很可能也是错误的,具有错误的累加效果。

贪婪搜索由于每步只需要关注最大得分,考虑的因素少,因此实现容易,执行速度快,其实现原理图如下所示。
image-20240809142541425.png

Beam Search

Beam Search 是Greedy Search的一种改进,通过在每一步保持多个候选来平衡搜索的效率和结果的质量,是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的相关概念和得分函数优化方法,然后介绍一种新的 Best-First Beam Search 方法,Best-First Beam Search 结合了优先队列和 A* 启发式搜索方法,可以提升 Beam Search 的速度。

如上文所示,Greedy Search在每一时刻只选择当前最有可能的单词,例如在预测第一个单词时,"我" 的概率最大,则第一个单词预测为 "我";预测第二个单词时,"爱" 的概率最大,则预测为 "爱"。Greedy Searc具有比较高的运行效率,但是每一步考虑的均是局部最优,有时候不能得到全局最优解。

Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size=k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。其具体的实现步骤为:
(1)在每一步解码时,保留最好的 k 个候选序列,其中 k 称为束宽(beam width)。
(2)对于每个候选序列,生成下一个可能的词,并计算新序列的概率。
(3)从所有新生成的序列中选择前 k 个最高概率的序列继续下一步。

下图展示了 Beam Search 的过程,对应的 k=2:

image.png

采样

随机采样

随机采样基于模型在每个时间步预测的概率分布来选择下一个词。不同于贪心搜索总是选择概率最高的词,随机采样会根据预测的概率随机选择一个词。采样的依据就是解码器输出的词典中每个词的概率分布。相比于按概率“掐尖”,这样会增大所选词的范围,引入更多的随机性。

采样的时候有一个可以控制的超参数,称为温度(temperature, T)。解码器的输出层后面通常会跟一个softmax函数来将输出概率归一化,通过改变T可以控制概率的形貌。softmax的公式如下,当T大的时候,概率分布趋向平均,随机性增大;当T小的时候,概率密度趋向于集中,即强者愈强,随机性降低,会更多地采样出“放之四海而皆准”的词汇。

image-20240809143012610.png

TopK

TopK就是在采样前将输出的概率分布截断,取出概率最大的k个词构成一个集合,然后将这个子集词的概率再归一化,最后从新的概率分布中采样词汇。

这个办法据说可以获得比Beam Search好很多的效果,但也有一个问题,就是这个k不太好选。 因为这个概率分布变化比较大,有时候可能很均匀(flat),有的时候比较集中(peaked)。对于集中的情况还好说,当分布均匀时,一个较小的k容易丢掉很多优质候选词。但如果k定的太大,这个方法又会退化回普通采样。

image (1).png

TopP

TopP采用累计概率的方式(也称为核采样, Nucleus Sampling),该方式不再取一个固定的k,而是固定候选集合的概率密度和在整个概率分布中的比例,即从累计概率超过某一阈值p的词汇中进行采样,根据参数p的大小调节,增大了出现概率较小的词汇的生成的概率。如下是Hugging Face中TopK和TopP采样实现的源码:

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
   """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
       Args:
           logits: logits distribution shape (batch size, vocabulary size)
           if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
           if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
               Nucleus filtering is described in Holtzman et al.
           Make sure we keep at least min_tokens_to_keep per batch example in the output
   """
   if top_k > 0:
       top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
       # Remove all tokens with a probability less than the last token of the top-k
       indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
       logits[indices_to_remove] = filter_value
​
   if top_p < 1.0:
       sorted_logits, sorted_indices = torch.sort(logits, descending=True)
       cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
​
       # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
       sorted_indices_to_remove = cumulative_probs > top_p
       if min_tokens_to_keep > 1:
           # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
           sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
       # Shift the indices to the right to keep also the first token above the threshold
       sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
       sorted_indices_to_remove[..., 0] = 0
​
       # scatter sorted tensors to original indexing
       indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
       logits[indices_to_remove] = filter_value
   return logits

惩罚重复

为了解决重复问题,还可以通过惩罚因子将出现过词的概率变小或者强制不使用重复词来解决。惩罚因子来自于同样广为流传的[《CTRL: A Conditional Transformer Language Model for Controllable Generation》],也可参考知乎文章《[Top-K+重复性惩罚]》。

以下是Huggind Face中对重复惩罚的源码实现:

# 输入的同样是logits(lprobs)
# 同时输入了之前出现过的词以及惩罚系数(大于1的)
# 考虑到了logit是正和负时处理方式应该不一样
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """"""
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty
                    
                    
# 去重复词
# 这个函数将会返回一个不可使用的词表
# 生成n-gram的巧妙方式大家可以借鉴一下
# 下面是一个3-gram的例子
# a = [1,2,3,4,5]
# for ngram in zip(*[a[i:] for i in range(3)]):
#    print(ngram)
def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].numpy().tolist()
        generated_ngram = generated_ngrams[idx]
        # 就是这巧妙的一句
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
​
    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])
​
    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens

高级解码策略

投机采样解码

投机采样是一种可以从根本上解码计算访存比的方法,保证和使用原始模型的采样分布完全相同。它使用两个模型:一个是原始目标模型,另一个是比原始模型小得多的近似模型。近似模型用于进行自回归串行采样,而大型模型则用于评估采样结果。解码过程中,某些token的解码相对容易,某些token的解码则很困难。因此,简单的token生成可以交给小型模型处理,而困难的token则交给大型模型处理。这里的小型模型可以采用与原始模型相同的结构,但参数更少,或者干脆使用n-gram模型。小型模型不仅计算量较小,更重要的是减少了内存访问的需求。

投机采样(Speculative Decoding)是Google和DeepMind在2022年同时发现的大模型推理加速方法。它可以在不损失生成效果前提下,获得3x以上的加速比。GPT-4泄密报告也提到OpenAI线上模型推理使用了它。

投机采样解码的原理

  1. 就是使用一个小模型来做草稿,然后使用大模型做纠正检查。
  2. 小模型的参数量要远小于原模型参数量一个级别才效果明显。
  3. 小模型和原模型的tokenizer最好一模一样,不然会增加额外的解码、编码时间。

2f38fcf7-7a23-4ce2-9d0a-5b09eec0badd.png

投机采样解码步骤

(1)小模型连续解码γ个token
output (2).png

(2)将Prefix + γ token拼接,送入大模型
output (3).png

此时,大模型可以一次预测γ个token, 然后使用逻辑判断大模型生成的token与小型的是否一致,如果第一个就不一致,直接抛弃,使用大模型新生成的prefix + 651 token作为prompt继续让小模型连续生成γ个token。 如果生成的token一致,则被大模型接受。
(3)一次重复,直到超过限度或生成结束符
output (4).png

投机采样解码的缺点

在投机采样中,Importance Sampling技术被用来生成与原始模型的预测概率分布相同的多样化输出。然而,后来的研究表明,当调高采样温度时,这种方法往往会变得效率较低(因为原始模型的输出会更多样化,导致很大概率原始模型的疏忽与draft模型的输出不一致,导致draf模型的输出被拒绝)。

简单来说,如果你的draft模型与你的target模型一样好,理想情况下,你应该接受它的所有输出,使得过程超级高效。然而,top-p采样可能会拒绝draft模型生成,从而导致并行解码长度很短。

美杜莎解码

美杜莎解码的核心思想是通过并行生成多个可能的未来token序列,来减少模型前向传播的次数,从而加速解码过程。使用较小的参考模型在每一步生成 token 序列,然后通过较大的原始模型进行细化以获得可接受的延续。​不过获得合适的参考模型仍然具有挑战性,并且将草稿模型集成到分布式系统中更加困难​。

来自普林斯顿大学、Together.AI、伊利诺伊大学厄巴纳 - 香槟分校等机构的研究者没有使用单独的参考模型来顺序生成候选输出,而是重新审视并完善了在主干模型之上使用多个解码头加速推理的概念。他们发现,如果该技术得到有效应用,可以克服推测解码的挑战,从而无缝地集成到现有 LLM 系统中。

美杜莎头

通常将解码过程分为两种类型:

  1. 常规解码:通常称为"Next-Token"预测,即每次只预测下一个token。
  2. 多token并行解码:定义为"Next-Next-Tokens"预测,可以同时预测多个未来token。

Medusa解码正是基于后者的思想,通过在现有模型基础上增加多个"Medusa Head"来实现。这些Medusa Head与原模型的Language Model (LM) Head协同工作,共同进行预测任务。这种设计使得模型能够并行生成多个可能的未来token序列,显著减少了模型前向传播的次数,从而大幅提高解码速度。通过动态调整预测长度和验证预生成的序列,Medusa能够灵活地适应不同的生成任务需求。

d149d89f-a6bc-4241-8ac9-5c93c759b981.png
新增美杜莎头的计算公式如下所示:
output.png

在一次模型迭代中,模型原始LM Head预测下一个token,Medusa Head 1预测下下个token,Medusa Head 2预测下下下个token,以此类推。 假设一次推理中,Prompt长度为20, 而Vocab Size为32000,那模型的输出形状如下:

  • Medusa logits shape: torch.Size([4, 1, 20, 32000])
  • logits shape: torch.Size([1, 20, 32000])

如果直接使用贪心算法,则一次可预测多个token:

medusa_pred = torch.argmax(medusa_logits[..., -1, :], dim = -1)
pred        = torch.argmax(logits[..., -1, :], dim = -1)
print('Base model prediction:', tokenizer.batch_decode(pred))
print('Medusa prediction:', tokenizer.batch_decode(medusa_pred))

preds = torch.cat([pred, medusa_pred[:, 0 ], dim = -1) # 将用于Verify
print('Combined prediction:', tokenizer.batch_decode(preds))

Base model prediction: ['Once']
Medusa prediction: ['upon', 'ly', 'time', ',']
Combined prediction: ['Once', 'upon', 'ly', 'time', ',']

通常,贪心算法的单个token预测准确率只有60%左右,而我们一次预测了多个token,准确率会进一步下降。因此我们也需要先投机采样那样,将新生成的token与原始prompt拼接在一起再forward一次,通过验证决定接收长度。

树状注意力

为了进一步提高预测准确率,我们采用TopK的方式进行采样:

  • LM Head 使用Top-1
  • Medusa Head使用Top-k
    bdbdedf4-cf32-4477-b01c-0436bda865e2.png

假设Topk=3, 美杜莎头个数为4(一次前向推理可以预测1+4个token),那么可能得候选token路径就有3^4=81种,这是一个庞大的数字。并且,可能得token路径形成一个树状的结构。论文中提出了一些对树进行剪枝的方法,如下所示,并提出了Tree Attention。
061d4547-66ec-4381-bf76-05593ee0c45a.png

典型接受

剪枝后的候选路径列表已经减少很多,但是大概率仍然不止一个。因此,我们仍然需要使用一些方法选择一个最佳的路径。收到截断采样(Top-K, Top-P)的启发, 美杜莎也设计了一种截断选择token的思路,其数学表达式如下:
output (1).png
简单描述就是说,在路径path里的token,如果它对应的概率值大于一个经验阈值与token所在的head的概率向量的熵乘以一个系数中较小的值,表示可以将该token考虑在内,否则将其抛弃。

def evaluate_posterior(logits, candidates, temperature, posterior_threshold=0.3,
posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = True):
# Predicted logits of shape (batch_size, sequence_length, vocab_size).
# candidates (torch.Tensor): Candidate token sequences.
posterior_prob  = torch.softmax(logits[:, :-1] / temperature, dim=-1)
# 找到每个candidates序列里每个token的概率值
candidates_prob = torch.gather(posterior_prob, dim=-1, 
    index=candidates[:, 1:].unsqueeze(-1)).squeeze(-1)

# 计算每个head概率向量的熵
posterior_entropy = -torch.sum(
    posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
) 

threshold = torch.minimum(
    torch.ones_like(posterior_entropy) * posterior_threshold,
    torch.exp(-posterior_entropy) * posterior_alpha,
)
posterior_mask = candidates_prob > threshold     # 候选path里的token是否满足熵条件

# cumprod的特点是,x方向某个token如果不满足,后面的mask累积全部为零
# 因此,sum求和累积的是从树顶部到底部连续的有效token数量,一旦出现一个无效,后面的有效也没用了
candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)

# 找到有效长度最长的那些路径
accept_length = candidates_accept_length.max()
if accept_length == 0:
    # 如果没有,那就选择第一条路径
    best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
else:
    best_candidates = torch.where(candidates_accept_length == accept_length)[0]
    
    # 可能不止一条,选择最好的一条
    likelihood = torch.sum(
        torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
    )
    best_candidate = best_candidates[torch.argmax(likelihood)]

对比解码

LLM一般的解码方法是在每个time step中从截断的下一个词分布中进行抽样。例如,nucleus sampling 从下一个词分布的前p百分位中抽取;top-k 抽样从下一个词分布中的前k个候选者中抽取。另一种常见方法是通过贪婪解码或束搜索来寻找最可能的文本序列;但这会导致重复和乏味的输出。
fd6ebe9b-1ea9-4093-8dfd-2daf3ae5cb36.png

较小的语言模型(LMs)比较大的语言模型更容易产生不良的模式(例如,重复、主题漂移和自相矛盾)。例如,当专家(较大的LM)和业余者(较小的LM)都将最高概率分配给重复的词时,专家LM通常对这个决策的信心较低,也会将一些概率分配给其他好的、非重复的词。

基于此,对比解码(Contrastive Decoding)目标奖励大型专家LMs所偏爱的文本模式,并惩罚小型业余LMs所偏爱的模式,简单理解就是用大模型的预测减去小模型的预测从而消除一些错误的预测。数学上表达为:
8a468816-f387-4ac5-968f-4804082f8452.png

然而,​业余LMs并非总是错误的:小型语言模型仍然捕捉到许多英语语法和常识的简单方面(例如,主谓一致)。因此,无差别地惩罚来自业余LMs的所有行为将惩罚这些正确的简单方面(误报),反过来奖励不合理的词(漏报)。为解决这个问题,我们引入了合理性约束,它补充了CD目标并避免了这些失败模式。简单理解就是对大模型的预测概率分布进行截断,不要考虑太多token。
e7259897-344a-4b70-ae74-bde461f01b75.png

一个具体的例子如下:
fe85fffe-79aa-45fe-b3b2-faba3767152c.png

对比解码的概率分布示例:

  • p(​Honolulu​) = log(0.16) - log(0.08) = 0.6931
  • p(1964) = log(0.1) - log(0.001) = 4.60
  • p(Washington) = log(0.02) - log(0.04) = -0.6931

总的来说,对比解码通过减少表面级复制、减少推理步骤的遗漏和防止抽象推理错误等方式改进了推理任务。然而,对比解码在常识推理任务中的表现可能不稳定,对事实检索的性能略有下降。传统解码方法如贪婪解码和核采样在速度和文本多样性方面具有优势,但可能在生成的文本质量和准确性上存在一些问题。因此,在实际应用中,需要根据具体任务的要求和限制选择合适的解码方法。

文章来自个人专栏
AI-大型语言模型
2 文章 | 1 订阅
0条评论
0 / 1000
请输入你的评论
3
2