On October 19, 2023, NVIDIA officially announced that TensorRT-LLM is openly available. Its main features are:
- An end-to-end framework for deploying generative AI applications (model building, customization, format conversion, deployment)
- Multi-GPU, multi-node inference
- Conversion and deployment examples for common large models (LLaMA family, ChatGLM family, GPT family, Baichuan, BLOOM, OPT, Falcon, etc.)
- A Python API for building and converting new models
- Support for the Triton Inference Server framework
- Support for multiple NVIDIA architectures: Volta, Turing, Ampere, Hopper, and Ada Lovelace
- On top of the transformer-structure optimizations inherited from FasterTransformer, a number of new LLM-specific optimizations, such as In-flight Batching, Paged KV Cache for the Attention, INT4/INT8 Weight-Only Quantization, SmoothQuant, Multi-head Attention (MHA), Multi-query Attention (MQA), and Group-query Attention (GQA)
This article follows the official documentation of TensorRT-LLM and of the open-source tensorrtllm_backend for Triton Inference Server, and verifies that the "TensorRT-LLM + Triton" stack can serve ChatGLM2-6B inference. A few notes before the main text:
- The open-source TensorRT-LLM code does not support in-flight batching for the ChatGLM2-6B model; this article modifies the source code to enable it.
Note: in-flight batching is a scheduling optimization. While a batch of requests is being served, sequences that finish are removed immediately and new requests start processing while the other requests in the batch are still in flight, which improves throughput and GPU utilization (a toy scheduling sketch follows these notes).
- This article modifies the Triton client test script provided by tensorrtllm_backend to make interactive use friendlier.
- Multi-GPU inference of ChatGLM2-6B failed in our tests; an issue filed on the official repo was answered with the statement that TensorRT-LLM does not yet support multi-GPU inference for ChatGLM2-6B and that a later code update is needed. This article therefore uses a single machine with a single GPU.
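To make the in-flight batching note above concrete, here is a toy scheduler simulation in plain Python. It is only an illustration of the scheduling idea, not TensorRT-LLM's actual batch manager: each step decodes one token for every active request, finished sequences are evicted immediately, and waiting requests are admitted without draining the whole batch.

```python
# Toy simulation of in-flight (continuous) batching; illustrative only.
from collections import deque

MAX_BATCH = 4
# (request id, number of tokens it still needs to generate)
waiting = deque([(f"req{i}", n) for i, n in enumerate([3, 5, 2, 6, 4, 1])])
active = {}

step = 0
while waiting or active:
    # Admit new requests as soon as a batch slot frees up.
    while waiting and len(active) < MAX_BATCH:
        rid, need = waiting.popleft()
        active[rid] = need
    # Decode one token for every request currently in the batch.
    for rid in list(active):
        active[rid] -= 1
        if active[rid] == 0:
            del active[rid]  # finished sequence leaves the batch immediately
            print(f"step {step}: {rid} finished, its slot is reused right away")
    step += 1

print(f"all requests served in {step} decode steps")
```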
1. Building TensorRT-LLM and Creating the Image
At the moment TensorRT-LLM can only be installed manually from source; an official Docker image bundling TensorRT-LLM and the Triton inference backend is planned for a later release.
- Pull the TensorRT-LLM source
```bash
apt-get update && apt-get -y install git git-lfs
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
git submodule update --init --recursive
git lfs install
git lfs pull
```
- Build the TensorRT-LLM Docker image
```bash
make -C docker release_build
```
- Run the Docker container
```bash
make -C docker run
```
This command starts a container named tensorrt_llm-devel-root from the tensorrt_llm/devel:latest image.
2. Converting a Hugging Face Model into a TensorRT-LLM Engine
Note: the following steps are executed inside the tensorrt_llm-devel-root container.
- Fetch the pretrained chatglm2-6b weights into the pyTorchModel directory
```bash
apt-get update
apt-get install git-lfs
cd examples/chatglm2-6b/
git clone https://huggingface.co/THUDM/chatglm2-6b pyTorchModel
```
- Convert the Hugging Face model into a TensorRT-LLM engine
Note:
- The ChatGLM2-6B model in the current open-source TensorRT-LLM code does not support in-flight batching; this article modifies the model files to add that support.
- The --world_size value set at build time must match the number of GPUs used later for inference, otherwise inference will fail (a quick check is sketched right after these notes).
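As a quick guard against that world_size mismatch, something like the following can be run before serving. The helper name is just illustrative; torch is already available inside the container.

```python
# Hypothetical sanity check: the engine's build-time --world_size must not
# exceed the number of GPUs visible at inference time.
import torch

def check_world_size(build_world_size: int) -> None:
    visible = torch.cuda.device_count()
    if visible < build_world_size:
        raise RuntimeError(
            f"Engine was built with --world_size={build_world_size}, "
            f"but only {visible} GPU(s) are visible; rebuild the engine "
            f"or expose more GPUs before serving."
        )

check_world_size(1)  # this article builds a single-GPU (tp=1) engine
```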
(Important) Before converting the model, modify build.py and model.py to support in-flight batching.
Modify examples/chatglm2-6b/build.py
如下:# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # h提提p://3w.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import time import tensorrt as trt import torch import torch.multiprocessing as mp import transformers from weight import load_from_hf_chatglm2_6B import tensorrt_llm from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.builder import Builder from tensorrt_llm.layers import AttentionMaskType from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models import (ChatGLM2HeadModel, smooth_quantize, weight_only_quantize) from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.quantization import QuantMode MODEL_NAME = "chatglm2-6b" def get_engine_name(model, dtype, tp_size, rank): return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) def serialize_engine(engine, path): logger.info(f'Serializing engine to {path}...') tik = time.time() with open(path, 'wb') as f: f.write(bytearray(engine)) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) logger.info(f'Engine serialized. Total time: {t}') def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--world_size', type=int, default=1, help='world size, only support tensor parallelism now') parser.add_argument('--model_dir', type=str, default="./pyTorchModel") parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'float32']) parser.add_argument( '--timing_cache', type=str, default='model.cache', help= 'The path of to read timing cache from, will be ignored if the file does not exist' ) parser.add_argument( '--log_level', type=str, default='verbose', choices=['verbose', 'info', 'warning', 'error', 'internal_error']) parser.add_argument('--vocab_size', type=int, default=65024) parser.add_argument('--n_layer', type=int, default=28) parser.add_argument('--n_positions', type=int, default=2048) parser.add_argument('--n_embd', type=int, default=4096) parser.add_argument('--n_head', type=int, default=32) parser.add_argument('--hidden_act', type=str, default='gelu') parser.add_argument( '--rotary_pct', type=float, default=0.0, help="Setting this to a value > 0.0 (and <= 1.0) activates RoPE.") parser.add_argument('--inter_size', type=int, default=None) parser.add_argument('--no_bias', action="store_false") parser.add_argument('--max_batch_size', type=int, default=4) parser.add_argument('--max_input_len', type=int, default=4096) parser.add_argument('--max_output_len', type=int, default=2048) parser.add_argument('--max_beam_width', type=int, default=1) parser.add_argument( '--use_gpt_attention_plugin', nargs='?', const='float16', default='float16', # default=False, choices=['float16', 'float32', False]) parser.add_argument('--use_gemm_plugin', nargs='?', const='float16', type=str, default='float16', choices=['float16', 'float32']) 
parser.add_argument('--use_layernorm_plugin', nargs='?', const='float16', type=str, default='float16', choices=['float16', 'float32']) parser.add_argument('--parallel_build', default=False, action='store_true') parser.add_argument('--enable_context_fmha', default=False, action='store_true') parser.add_argument('--enable_context_fmha_fp32_acc', default=False, action='store_true') parser.add_argument('--gpus_per_node', type=int, default=8) parser.add_argument('--builder_opt', type=int, default=None) parser.add_argument( '--output_dir', type=str, default='trtModel', help= 'The path to save the serialized engine files, timing cache file and model configs' ) parser.add_argument( "--multi_query_mode", "-mq", default=False, action='store_true', help= "Whether this model uses multi-query attention mechanism (default: False)" ) # Arguments related to the quantization of the model. parser.add_argument( '--use_smooth_quant', default=False, action="store_true", help= 'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.' 'See --per_channel and --per_token for finer-grained quantization options.' ) parser.add_argument( '--use_weight_only', default=False, action="store_true", help='Quantize weights for the various GEMMs to INT4/INT8.' 'See --weight_only_precision to set the precision') parser.add_argument( '--weight_only_precision', const='int8', type=str, nargs='?', default='int8', choices=['int8', 'int4'], help= 'Define the precision for the weights when using weight-only quantization.' 'You must also use --use_weight_only for that argument to have an impact.' ) parser.add_argument( '--per_channel', default=False, action="store_true", help= 'By default, we use a single static scaling factor for the GEMM\'s result. ' 'per_channel instead uses a different static scaling factor for each channel. ' 'The latter is usually more accurate, but a little slower.') parser.add_argument( '--per_token', default=False, action="store_true", help= 'By default, we use a single static scaling factor to scale activations in the int8 range. ' 'per_token chooses at run time, and for each token, a custom scaling factor. ' 'The latter is usually more accurate, but a little slower.') parser.add_argument( '--int8_kv_cache', default=False, action="store_true", help= 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' ) parser.add_argument( '--random_seed', type=int, default=None, help= 'Seed to use when initializing the random number generator for torch.') # 新增代码以支持inflight_batching, by shaun xiong (231024) parser.add_argument('--remove_input_padding', default=False, action='store_true') parser.add_argument( '--use_inflight_batching', action="store_true", default=False, help="Activates inflight batching mode of gptAttentionPlugin.") parser.add_argument( '--paged_kv_cache', action="store_true", default=False, help= 'By default we use contiguous KV cache. 
By setting this flag you enable paged KV cache' ) parser.add_argument('--tokens_per_block', type=int, default=64, help='Number of tokens per block in paged KV cache') ######################################################## args = parser.parse_args() # 新增代码以支持inflight_batching (by shaun xiong 231024) if not args.remove_input_padding: if args.use_gpt_attention_plugin: logger.warning( f"It is recommended to specify --remove_input_padding when using GPT attention plugin" ) if args.use_inflight_batching: if not args.use_gpt_attention_plugin: args.use_gpt_attention_plugin = 'float16' logger.info( f"Using GPT attention plugin for inflight batching mode. " f"Setting to default '{args.use_gpt_attention_plugin}'") if not args.remove_input_padding: args.remove_input_padding = True logger.info( 'Using remove input padding for inflight batching mode.') if not args.paged_kv_cache: args.paged_kv_cache = True logger.info('Using paged KV cache for inflight batching mode.') ######################################################## assert not ( args.use_smooth_quant and args.use_weight_only ), "You cannot enable both SmoothQuant and INT8 weight-only together." if args.use_smooth_quant: args.quant_mode = QuantMode.use_smooth_quant(args.per_token, args.per_channel) elif args.use_weight_only: args.quant_mode = QuantMode.use_weight_only( args.weight_only_precision == 'int4') else: args.quant_mode = QuantMode(0) args.bias = not args.no_bias if args.inter_size is None: args.inter_size = 4 * args.n_embd if args.int8_kv_cache: assert ( args.use_gpt_attention_plugin ), "You have to use GPT attention plugin when int8 KV cache is set" args.quant_mode = args.quant_mode.set_int8_kv_cache() return args def build_rank_engine(builder: Builder, builder_config: tensorrt_llm.builder.BuilderConfig, engine_name, rank, args): ''' @brief: Build the engine on the given rank. @param rank: The rank to build the engine. @param args: The cmd line arguments. @return: The built engine. 
''' str_dtype_to_trt(args.dtype) # Initialize Module tensorrt_llm_ChatGLM2_6BModel = ChatGLM2HeadModel( hidden_size=4096, num_attention_heads=32, kv_channels=128, multi_query_group_num=2, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.causal, qkv_bias=True, linear_bias=False, use_int8_kv_cache=False, mapping=Mapping(world_size=args.world_size, rank=rank, tp_size=args.world_size), ffn_hiden_size=13696, num_layers=28, eps=1e-5, act_func='swiglu', dtype=trt.float16, quant_mode=QuantMode(0), max_seq_length=32768, vocab_size=65024) if args.use_smooth_quant: tensorrt_llm_ChatGLM2_6BModel = smooth_quantize( tensorrt_llm_ChatGLM2_6BModel, args.quant_mode) elif args.use_weight_only: tensorrt_llm_ChatGLM2_6BModel = weight_only_quantize( tensorrt_llm_ChatGLM2_6BModel, args.quant_mode) if args.model_dir is not None: print('loading weights from hugging face model') hf_model = transformers.AutoModel.from_pretrained( args.model_dir, trust_remote_code=True).cpu() tensorrt_llm_ChatGLM2_6BModel = load_from_hf_chatglm2_6B( tensorrt_llm_ChatGLM2_6BModel, hf_model, dtype='float16') print('load finished') del hf_model # Module -> Network network = builder.create_network() network.trt_network.name = engine_name if args.use_gpt_attention_plugin: network.plugin_config.set_gpt_attention_plugin( dtype=args.use_gpt_attention_plugin) if args.use_gemm_plugin: network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) if args.use_layernorm_plugin: network.plugin_config.set_layernorm_plugin( dtype=args.use_layernorm_plugin) assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) if args.enable_context_fmha: network.plugin_config.set_context_fmha(ContextFMHAType.enabled) if args.enable_context_fmha_fp32_acc: network.plugin_config.set_context_fmha( ContextFMHAType.enabled_with_fp32_acc) # Quantization plugins. 
if args.use_smooth_quant: network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype) network.plugin_config.set_layernorm_quantization_plugin( dtype=args.dtype) # FIXME network.plugin_config.set_quantize_tensor_plugin() network.plugin_config.set_quantize_per_token_plugin() elif args.use_weight_only: network.plugin_config.set_weight_only_quant_matmul_plugin( dtype='float16') if args.world_size > 1: network.plugin_config.set_nccl_plugin(args.dtype) #新增代码以支持inflight_batching (by shaun xiong 231024) if args.remove_input_padding: network.plugin_config.enable_remove_input_padding() if args.paged_kv_cache: network.plugin_config.enable_paged_kv_cache(args.tokens_per_block) #### with net_guard(network): # Prepare network.set_named_parameters( tensorrt_llm_ChatGLM2_6BModel.named_parameters()) # Forward inputs = tensorrt_llm_ChatGLM2_6BModel.prepare_inputs( args.max_batch_size, args.max_input_len, args.max_output_len, True, args.max_beam_width) tensorrt_llm_ChatGLM2_6BModel(*inputs) tensorrt_llm.graph_rewriting.optimize(network) engine = None # Network -> Engine engine = builder.build_engine(network, builder_config) if rank == 0: config_path = os.path.join(args.output_dir, 'config.json') builder.save_config(builder_config, config_path) return engine def build(rank, args): torch.cuda.set_device(rank % args.gpus_per_node) tensorrt_llm.logger.set_level(args.log_level) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) # when doing serializing build, all ranks share one engine apply_query_key_layer_scaling = False builder = Builder() cache = None for cur_rank in range(args.world_size): # skip other ranks if parallel_build is enabled if args.parallel_build and cur_rank != rank: continue builder_config = builder.create_builder_config( name=MODEL_NAME, precision=args.dtype, timing_cache=args.timing_cache if cache is None else cache, tensor_parallel=args.world_size, # TP only parallel_build=args.parallel_build, num_layers=args.n_layer, num_heads=args.n_head, hidden_size=args.n_embd, vocab_size=args.vocab_size, hidden_act=args.hidden_act, max_position_embeddings=args.n_positions, apply_query_key_layer_scaling=apply_query_key_layer_scaling, max_batch_size=args.max_batch_size, max_input_len=args.max_input_len, max_output_len=args.max_output_len, int8=(args.quant_mode.has_act_and_weight_quant() or args.quant_mode.has_int8_kv_cache()), opt_level=args.builder_opt, multi_query_mode=args.multi_query_mode) engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size, cur_rank) engine = build_rank_engine(builder, builder_config, engine_name, cur_rank, args) assert engine is not None, f'Failed to build engine for rank {cur_rank}' if cur_rank == 0: # Use in-memory timing cache for multiple builder passes. if not args.parallel_build: cache = builder_config.trt_builder_config.get_timing_cache() serialize_engine(engine, os.path.join(args.output_dir, engine_name)) if rank == 0: ok = builder.save_timing_cache( builder_config, os.path.join(args.output_dir, "model.cache")) assert ok, "Failed to save timing cache." if __name__ == '__main__': args = parse_arguments() if args.random_seed is not None: torch.manual_seed(args.random_seed) logger.set_level(args.log_level) tik = time.time() if args.parallel_build and args.world_size > 1 and \ torch.cuda.device_count() >= args.world_size: logger.warning( f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.' 
) mp.spawn(build, nprocs=args.world_size, args=(args, )) else: args.parallel_build = False logger.info('Serially build TensorRT engines.') build(0, args) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) logger.info(f'Total time of building all {args.world_size} engines: {t}')
Modify /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/chatglm2_6b/model.py
如下:# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # h提提p://3w.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import tensorrt as trt import torch from ..._common import default_net from ..._utils import pad_vocab_size, str_dtype_to_trt from ...functional import (PositionEmbeddingType, Tensor, concat, constant, expand, expand_dims, gather_last_token_logits, gpt_attention, index_select, select, shape, slice, split) from ...layers import (MLP, AttentionMaskType, AttentionParams, ColumnLinear, Embedding, KeyValueCacheParams, RmsNorm, RowLinear) from ...mapping import Mapping from ...module import Module, ModuleList from ...parameter import Parameter from ...quantization import QuantMode from ..generation_mixin import GenerationMixin def apply_rotary_pos_emb_trt(x: Tensor, rope_cache: Tensor) -> Tensor: # x-> [seq, batch, num_heads, 2] x = x.permute((1, 0, 2, 3)) # sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) sq = shape(x, 0) b = shape(x, 1) nh = shape(x, 2) shape(x, 3) # rope_cache shape: seq,batch,heads,2 rot_dim = 2* numheads #rope_cache: seq,batch,num_states/4,2 rot_dim = shape(rope_cache, 2) * constant(np.array(2, dtype=np.int32)) starts = concat([0, 0, 0, 0]) sizes = concat([sq, b, nh, rot_dim]) # first half x_rot = slice(x, starts, sizes) starts = concat([0, 0, 0, rot_dim]) # second half x_pass = slice(x, starts, sizes) # truncate to support variable sizes rope_cache = slice(rope_cache, (0, 0, 0, 0), (concat( [sq, shape(rope_cache, 1), shape(rope_cache, 2), shape(rope_cache, 3)]))) xshaped = x_rot.view(concat([sq, b, nh, rot_dim / 2, 2])) rope_cache = rope_cache.view(concat([sq, b, 1, shape(xshaped, 3), 2])) # first half xshape0 = select(xshaped, 4, 0) # second half xshape1 = select(xshaped, 4, 1) # first half rope_cache0 = select(rope_cache, 4, 0) # second half rope_cache1 = select(rope_cache, 4, 1) out0 = xshape0 * rope_cache0 - xshape1 * rope_cache1 out1 = xshape1 * rope_cache0 + xshape0 * rope_cache1 out0 = expand_dims(out0, 4) out1 = expand_dims(out1, 4) x_out2_v1 = concat([out0, out1], 4) x_out2 = x_out2_v1.view( concat([sq, b, nh, shape(x_out2_v1, 3) * shape(x_out2_v1, 4)])) output = concat([x_out2, x_pass], dim=3) # to batch,seq,num_group,head_states output = output.permute((1, 0, 2, 3)) return output class RotaryEmbedding(Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, seq_len: int): theta = 1.0 / (10000**(torch.arange(0, self.dim, 2) / self.dim)) seq_idx = torch.arange(seq_len) idx_theta = torch.outer(seq_idx, theta).float() cache = torch.stack( [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) cache = cache.half() # create rope embeddings and make it constant cache = constant(cache.numpy()) return cache class ChatGLM2Attention(Module): def __init__( self, hidden_size, num_attention_heads, layer_number, kv_channels=128, multi_query_group_num=2, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.causal, 
qkv_bias=True, linear_bias=False, dtype='float16', use_int8_kv_cache=False, tp_group=None, tp_size=1, ): super().__init__() self.attention_mask_type = attention_mask_type self.attention_head_size = hidden_size // num_attention_heads self.num_attention_heads = num_attention_heads // tp_size self.num_multi_query_groups_per_partition = multi_query_group_num self.num_attention_kv_heads = self.num_attention_heads self.hidden_size = hidden_size // tp_size self.projection_size = num_attention_heads * kv_channels self.hidden_size_per_attention_head = kv_channels self.layer_number = layer_number self.apply_query_key_layer_scaling = apply_query_key_layer_scaling self.q_scaling = 1 if apply_query_key_layer_scaling: self.q_scaling *= self.layer_number self.position_embedding_type = PositionEmbeddingType.learned_absolute self.multi_block_mode = False self.multi_query_mode = False self.rotary_embedding_dim = 0 self.dtype = dtype self.use_int8_kv_cache = use_int8_kv_cache if self.use_int8_kv_cache: self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32') self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32') else: self.register_parameter('kv_orig_quant_scale', None) self.register_parameter('kv_quant_orig_scale', None) # Note: in multi_query_mode, only query heads are split between multiple GPUs, # while key/value head are not split as there is only one head per key/value. # The output feature size is therefore (h/tp + 2) * d, where h is num_heads, # d is head_size, and tp is tensor_parallel_size. # In ColumnLinear op, the output dim is calculated by (h + 2*tp) * d / tp, # which matches the desired output size (h/tp + 2) * d after splitting self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * 2) self.qkv = ColumnLinear(hidden_size, self.qkv_hidden_size, bias=qkv_bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size, gather_output=False) self.dense = RowLinear(hidden_size, hidden_size, bias=linear_bias, dtype=dtype, tp_group=tp_group, tp_size=tp_size) def forward(self, hidden_states: Tensor, rotary_pos_emb, use_cache=True, kv_cache_params=None, attention_params=None): if not default_net().plugin_config.gpt_attention_plugin: raise ValueError( 'ChatGLM2 is only supported with GPTAttention plugin,pleas build it with --use_gpt_attention_plugin argument.' 
) assert isinstance(hidden_states, Tensor) qkv = self.qkv(hidden_states) query, key, value = split(qkv, [ self.num_attention_heads * self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, ], dim=-1) query = query.view( concat([ shape(qkv, 0), shape(qkv, 1), self.num_attention_heads, self.attention_head_size ])) key = key.view( concat([ shape(qkv, 0), shape(qkv, 1), self.num_multi_query_groups_per_partition, self.attention_head_size ])) value = value.view( concat([ shape(qkv, 0), shape(qkv, 1), self.num_multi_query_groups_per_partition, self.attention_head_size ])) if rotary_pos_emb is not None: query = apply_rotary_pos_emb_trt(query, rotary_pos_emb) key = apply_rotary_pos_emb_trt(key, rotary_pos_emb) # batch,seq,num_group,1,head_states key = expand_dims(key, 3) #expand 16x expand_rate = self.num_attention_heads // self.num_multi_query_groups_per_partition key = expand( key, concat([ shape(key, 0), shape(key, 1), shape(key, 2), expand_rate, shape(key, 4) ])) # batch,seq,num_heads,head_states key = key.view( concat([ shape(key, 0), shape(key, 1), shape(key, 2) * shape(key, 3), shape(key, 4) ])) value = expand_dims(value, 3) value = expand( value, concat([ shape(value, 0), shape(value, 1), shape(value, 2), expand_rate, shape(value, 4) ])) value = value.view( concat([ shape(value, 0), shape(value, 1), shape(value, 2) * shape(value, 3), shape(value, 4) ])) qkv = concat([query, key, value], dim=2) qkv = qkv.view( concat([shape(qkv, 0), shape(qkv, 1), self.hidden_size * 3])) assert attention_params.is_valid( default_net().plugin_config.gpt_attention_plugin, default_net().plugin_config.remove_input_padding) assert kv_cache_params.is_valid( default_net().plugin_config.gpt_attention_plugin) kv_orig_quant_scale = self.kv_orig_quant_scale.value if self.use_int8_kv_cache else None kv_quant_orig_scale = self.kv_quant_orig_scale.value if self.use_int8_kv_cache else None context, past_key_value = gpt_attention( tensor=qkv, past_key_value=kv_cache_params.get_first_past_key_value(), sequence_length=attention_params.sequence_length, host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, context_lengths=attention_params.context_lengths, cache_indirection=kv_cache_params.cache_indirection, host_request_types=attention_params.host_request_types, num_heads=self.num_attention_heads, num_kv_heads=self. num_attention_heads, # since self.multi_query_mode is set to False hidden_size_per_head=self.attention_head_size, q_scaling=self.q_scaling, rotary_embedding_dim=self.rotary_embedding_dim, position_embedding_type=self.position_embedding_type, multi_block_mode=self.multi_block_mode, kv_orig_quant_scale=kv_orig_quant_scale, kv_quant_orig_scale=kv_quant_orig_scale, kv_cache_quant_mode=QuantMode.INT8_KV_CACHE if self.use_int8_kv_cache else QuantMode(0), kv_cache_block_pointers=kv_cache_params. 
get_first_kv_cache_block_pointers(), max_context_length=attention_params.max_context_length, host_context_lengths=attention_params.host_context_lengths) # dense layer after self-attention context = self.dense(context) if use_cache: return (context, past_key_value) else: return context class ChatGLM2Block(Module): def __init__(self, hidden_size, num_attention_heads, kv_channels=128, multi_query_group_num=2, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.causal, qkv_bias=True, linear_bias=False, use_int8_kv_cache=False, tp_group=None, tp_size=1, ffn_hiden_size=13696, layer_number=1, eps=1e-5, act_func='swiglu', dtype=trt.float16, quant_mode=QuantMode(0)): super(ChatGLM2Block, self).__init__() self.layer_number = layer_number self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.dtype = dtype self.ffn_hiden_size = ffn_hiden_size self.apply_residual_connection_post_layernorm = False self.fp32_residual_connection = False LayerNormFunc = RmsNorm # Layernorm on the input data. self.input_layernorm = LayerNormFunc(self.hidden_size, eps=eps, dtype=dtype) # Self attention. self.self_attention = ChatGLM2Attention( hidden_size, num_attention_heads, layer_number, kv_channels, multi_query_group_num, apply_query_key_layer_scaling, attention_mask_type, qkv_bias, linear_bias, dtype, use_int8_kv_cache, tp_group, tp_size) self.hidden_dropout = 0.0 # Layernorm on the attention output self.post_attention_layernorm = LayerNormFunc(self.hidden_size, eps=eps, dtype=dtype) self.mlp = MLP(self.hidden_size, ffn_hiden_size, act_func, linear_bias, dtype) def forward(self, hidden_states, rotary_pos_emb, use_cache=True, kv_cache_params=None, attention_params=None): # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, kv_cache = self.self_attention( layernorm_output, rotary_pos_emb, use_cache=use_cache, kv_cache_params=kv_cache_params, attention_params=attention_params) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states layernorm_input = hidden_states + attention_output # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input output = residual + mlp_output return output, kv_cache class ChatGLM2Transformer(Module): """Transformer class.""" def __init__(self, hidden_size, num_attention_heads, kv_channels=128, multi_query_group_num=2, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.causal, qkv_bias=True, linear_bias=False, use_int8_kv_cache=False, tp_group=None, tp_size=1, ffn_hiden_size=13696, num_layers=28, eps=1e-5, act_func='swiglu', dtype=trt.float16, quant_mode=QuantMode(0)): super(ChatGLM2Transformer, self).__init__() self.fp32_residual_connection = False self.post_layer_norm = True # Number of layers. self.num_layers = num_layers # Transformer layers. 
def build_layer(layer_number): return ChatGLM2Block(hidden_size, num_attention_heads, kv_channels, multi_query_group_num, apply_query_key_layer_scaling, attention_mask_type, qkv_bias, linear_bias, use_int8_kv_cache, tp_group, tp_size, ffn_hiden_size, layer_number, eps, act_func, dtype, quant_mode) self.layers = ModuleList( build_layer(i + 1) for i in range(self.num_layers)) if self.post_layer_norm: self.final_layernorm = RmsNorm(hidden_size, eps=eps, dtype=dtype) self.gradient_checkpointing = False def _get_layer(self, layer_number): return self.layers[layer_number] def forward(self, hidden_states, rotary_pos_emb, use_cache=True, kv_cache_params=None, attention_params=None): presents = [] for index in range(self.num_layers): layer = self._get_layer(index) hidden_states, kv_cache = layer( hidden_states, rotary_pos_emb, use_cache=use_cache, kv_cache_params=KeyValueCacheParams( past_key_value=[kv_cache_params.past_key_value[index]], kv_cache_block_pointers=[ kv_cache_params.kv_cache_block_pointers[index] ], host_past_key_value_lengths=kv_cache_params. host_past_key_value_lengths, cache_indirection=kv_cache_params.cache_indirection), attention_params=attention_params) presents.append(kv_cache) if self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states, presents class ChatGLM2Model(Module): def __init__(self, hidden_size, num_attention_heads, kv_channels=128, multi_query_group_num=2, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.causal, qkv_bias=True, linear_bias=False, use_int8_kv_cache=False, mapping=Mapping(), ffn_hiden_size=13696, num_layers=28, eps=1e-5, act_func='swiglu', dtype=trt.float16, quant_mode=QuantMode(0), max_seq_length=32768, vocab_size=65024): super(ChatGLM2Model, self).__init__() self.dtype = dtype self.embedding = Embedding(vocab_size, hidden_size, dtype=dtype) self.num_layers = num_layers self.multi_query_group_num = multi_query_group_num self.kv_channels = kv_channels # Rotary positional embeddings self.max_seq_length = max_seq_length rotary_dim = kv_channels self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, ) self.encoder = ChatGLM2Transformer( hidden_size, num_attention_heads, kv_channels, multi_query_group_num, apply_query_key_layer_scaling, attention_mask_type, qkv_bias, linear_bias, use_int8_kv_cache, mapping.tp_group, mapping.tp_size, ffn_hiden_size, num_layers, eps, act_func, dtype, quant_mode) def forward( self, input_ids: Tensor, position_ids, use_cache=True, kv_cache_params=None, attention_params=None, ): inputs_embeds = self.embedding(input_ids) # Rotary positional embeddings # generate 32768 pos embeddings # max_seq_length,head_dim/4,2 rotary_pos_emb = self.rotary_pos_emb(self.max_seq_length) flat_position = position_ids.view( concat([shape(position_ids, 0) * shape(position_ids, 1)])) selected_pos_emb = index_select(rotary_pos_emb, 0, flat_position) # selected batch,seq from rotary_pos_emb selected_pos_emb = selected_pos_emb.view( concat([ shape(position_ids, 0), shape(position_ids, 1), shape(rotary_pos_emb, 1), shape(rotary_pos_emb, 2) ])) # seq,batch selected_pos_emb = selected_pos_emb.permute((1, 0, 2, 3)) # return inputs_embeds,selected_pos_emb # Run encoder. 
hidden_states, presents = self.encoder( inputs_embeds, selected_pos_emb, use_cache=use_cache, kv_cache_params=kv_cache_params, attention_params=attention_params, ) return hidden_states, presents class ChatGLM2HeadModel(ChatGLM2Model, GenerationMixin): def __init__(self, hidden_size, num_attention_heads, kv_channels=128, multi_query_group_num=2, apply_query_key_layer_scaling=False, attention_mask_type=AttentionMaskType.causal, qkv_bias=True, linear_bias=False, use_int8_kv_cache=False, mapping=Mapping(), ffn_hiden_size=13696, num_layers=28, eps=1e-5, act_func='swiglu', dtype=trt.float16, quant_mode=QuantMode(0), max_seq_length=32768, vocab_size=65024, use_cache=True, kv_cache_block_pointers=None): if isinstance(dtype, str): self._kv_dtype = str_dtype_to_trt(dtype) else: assert isinstance(dtype, trt.DataType) self._kv_dtype = dtype self._dtype = self._kv_dtype if quant_mode.has_int8_kv_cache(): self._kv_dtype = str_dtype_to_trt('int8') elif quant_mode.has_fp8_kv_cache(): self._kv_dtype = str_dtype_to_trt('fp8') self.use_cache = use_cache self.kv_cache_block_pointers = kv_cache_block_pointers self.quant_mode = quant_mode self._num_layers = num_layers self._num_heads = num_attention_heads self._hidden_size = hidden_size self._vocab_size = vocab_size self._tp_size = mapping.tp_size self.mapping = mapping super().__init__(hidden_size, num_attention_heads, kv_channels, multi_query_group_num, apply_query_key_layer_scaling, attention_mask_type, qkv_bias, linear_bias, use_int8_kv_cache, mapping, ffn_hiden_size, num_layers, eps, act_func, dtype, quant_mode, max_seq_length, vocab_size) vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size) self.lm_head = ColumnLinear(hidden_size, vocab_size_padded, bias=False, dtype=dtype, tp_group=mapping.tp_group, tp_size=mapping.tp_size, gather_output=True) def forward(self, input_ids=None, position_ids=None, last_token_ids=None, kv_cache_params=None, attention_params=None): hidden_states = super().forward(input_ids, position_ids, self.use_cache, kv_cache_params, attention_params) if self.use_cache: hidden_states, presents = hidden_states hidden_states = gather_last_token_logits( hidden_states, last_token_ids, default_net().plugin_config.remove_input_padding) lm_logits = self.lm_head(hidden_states) lm_logits.mark_output('logits', self._dtype) if default_net().plugin_config.paged_kv_cache == False: for i, present in enumerate(presents): present.mark_output(f'present_key_value_{i}', self._kv_dtype) return (lm_logits, presents) return lm_logits def prepare_inputs(self, max_batch_size, max_input_len, max_new_tokens, use_cache, max_beam_width: int = 1): '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the ranges of the dimensions of when using TRT dynamic shapes. 
@return: a list contains values which can be fed into the self.forward() ''' # Prepare inputs head_size = self._hidden_size // self._num_heads num_heads = self._num_heads // self._tp_size remove_input_padding = default_net().plugin_config.remove_input_padding use_gpt_attention_plugin = default_net( ).plugin_config.gpt_attention_plugin use_gemm_plugin = default_net().plugin_config.gemm_plugin #以下代码已修改,支持inflight_batching (by shaun xiong 231024) paged_kv_cache = default_net().plugin_config.paged_kv_cache tokens_per_block = default_net().plugin_config.tokens_per_block use_custom_all_reduce = default_net().plugin_config.use_custom_all_reduce model_inputs = self.prepare_basic_inputs( max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_input_len=max_input_len, max_new_tokens=max_new_tokens, num_kv_heads = num_heads, num_heads = num_heads, head_size=head_size, num_layers=self.num_layers, kv_dtype=self._kv_dtype, remove_input_padding=remove_input_padding, use_gpt_attention_plugin=use_gpt_attention_plugin, use_gemm_plugin=use_gemm_plugin, use_custom_all_reduce=use_custom_all_reduce, paged_kv_cache=paged_kv_cache, tokens_per_block=tokens_per_block, mapping=self.mapping ) ################################## return (model_inputs['input_ids'], model_inputs['position_ids'], model_inputs['last_token_ids'], KeyValueCacheParams( past_key_value=model_inputs['past_key_value'], host_past_key_value_lengths=model_inputs[ 'host_past_key_value_lengths'], kv_cache_block_pointers=model_inputs[ 'kv_cache_block_pointers_list'], cache_indirection=model_inputs['cache_indirection'], ), AttentionParams( sequence_length=model_inputs['sequence_length'], context_lengths=model_inputs['context_lengths'], host_context_lengths=model_inputs['host_context_lengths'], max_context_length=max_input_len, host_request_types=model_inputs['host_request_types']))
After the modifications above, run the examples/chatglm2-6b/build.py script to convert the model:
```bash
python3 build.py --model_dir=./pyTorchModel \
                 --dtype float16 \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --use_inflight_batching \
                 --use_weight_only
```
Adding --use_weight_only quantizes the weights to INT8 (a rough sketch of the idea behind weight-only quantization follows the listing below). The converted TensorRT-LLM engine file chatglm2-6b_float16_tp1_rank0.engine is written to the trtModel directory:
```text
trtModel/
├── [6.3G] chatglm2-6b_float16_tp1_rank0.engine
├── [1.3K] config.json
└── [424K] model.cache
```
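As a rough intuition for what --use_weight_only does, per-channel INT8 weight-only quantization stores one scale per weight column and keeps activations in floating point. This is a plain NumPy sketch of the concept, not TensorRT-LLM's implementation:

```python
# Illustrative per-channel INT8 weight-only quantization (conceptual only).
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((4096, 4096)).astype(np.float32)   # fp32 weight matrix

scale = np.abs(w).max(axis=0) / 127.0                        # one scale per column
w_int8 = np.clip(np.round(w / scale), -127, 127).astype(np.int8)

# At inference time the INT8 weights are dequantized (or fused into the GEMM)
# while activations stay in fp16/fp32.
w_deq = w_int8.astype(np.float32) * scale
print("max abs error:", np.abs(w - w_deq).max())
print("memory: fp32 %.0f MiB -> int8 %.0f MiB"
      % (w.nbytes / 2**20, w_int8.nbytes / 2**20))
```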
- Exit the container
```bash
exit
```
3. Adding the TensorRT-LLM Backend to the Triton Inference Server
The latest Triton Inference Server release at the time of writing (23.09) does not yet support TensorRT-LLM; the official repo states that support will arrive in release 23.10 (watch the official site for release announcements). The steps below manually build a Triton Inference Server image that contains the TensorRT-LLM backend.
- Pull the tensorrtllm_backend repository
```bash
git clone https://github.com/triton-inference-server/tensorrtllm_backend.git
```
- Build the TensorRT-LLM backend image manually
```bash
cd tensorrtllm_backend
git submodule update --init --recursive
git lfs install
git lfs pull
DOCKER_BUILDKIT=1 docker build -t triton_trt_llm -f dockerfile/Dockerfile.trt_llm_backend .
```
Running the docker build command may run into the following problems:
a. Unable to reach pypi.org, wget fails to download the CMake package, and similar network errors
Fix: edit tensorrtllm_backend/dockerfile/Dockerfile.trt_llm_backend to switch the pip index and install CMake manually:
```dockerfile
# Switch the pip index
RUN python -m pip install --upgrade pip
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

# CMake
# COPY tensorrt_llm/docker/common/install_cmake.sh /tmp/
# RUN bash /tmp/install_cmake.sh && rm /tmp/install_cmake.sh
# ENV PATH="/usr/local/cmake/bin:${PATH}"

# Install CMake manually instead
COPY tensorrt_llm/docker/common/cmake-3.27.7-linux-x86_64.tar.gz /tmp
RUN tar -xf /tmp/cmake-3.27.7-linux-x86_64.tar.gz -C /usr/local/
RUN ln -s /usr/local/cmake-3.27.7-linux-x86_64 /usr/local/cmake
ENV PATH="/usr/local/cmake/bin:${PATH}"
RUN rm -rf /tmp/cmake-3.27.7-linux-x86_64.tar.gz
```
b. The build.sh step fails with:
```text
fatal: unable to access 'https://github.com/triton-inference-server/common.git/': GnuTLS recv error (-110): The TLS connection was non-properly terminated.
```
Fix: in the build.sh script, add the following two lines before the cmake step to enlarge the HTTP post buffer:
```bash
git config --global http.postBuffer 1048576000
git config --global https.postBuffer 1048576000
```
Once these commands finish, you get a Triton Inference Server image with TensorRT-LLM support, tagged "triton_trt_llm:latest".
4. Deploying the Triton Inference Service
- Prepare the model and configuration files:
Follow the tensorrtllm_backend documentation to prepare the model files and edit the configuration files; set the config to enable in-flight batching as the documentation describes (a sketch of patching the config follows this step). If in-flight batching is left disabled, the Triton Inference Server still starts normally, but the official inference test client assumes in-flight batching, so inference cannot be verified.
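For reference, the in-flight batching switch in the tensorrt_llm model's config.pbtxt can also be flipped with a small script instead of editing by hand. The parameter name gpt_model_type and the value inflight_fused_batching are taken from the tensorrtllm_backend documentation; verify them against your version before relying on this sketch (it is not the official fill_template.py tool).

```python
# Minimal sketch: set gpt_model_type to the in-flight batching mode in
# triton_model_repo/tensorrt_llm/config.pbtxt (adjust the path to your repo).
import re
from pathlib import Path

cfg = Path("triton_model_repo/tensorrt_llm/config.pbtxt")
text = cfg.read_text()

# Replace the string_value that follows the "gpt_model_type" parameter key.
patched = re.sub(
    r'(key:\s*"gpt_model_type".*?string_value:\s*")[^"]*(")',
    r'\g<1>inflight_fused_batching\g<2>',
    text,
    flags=re.S,
)
cfg.write_text(patched)
print("gpt_model_type set to inflight_fused_batching")
```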
- Launch the Triton Inference Server (a readiness probe is sketched after this step)
```bash
# Launch the Triton container
docker run --rm -it --net host --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 \
    --gpus all -v /root/xiongx7/tensorrtllm_backend:/tensorrtllm_backend triton_trt_llm bash
cd /tensorrtllm_backend
# --world_size is the number of GPUs you want to use for serving
python3 scripts/launch_triton_server.py --world_size=1 --model_repo=/tensorrtllm_backend/triton_model_repo
```
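Before running the client, it can help to poll the server until it reports ready. A minimal probe with the tritonclient package might look like this; it assumes the default HTTP port 8000 (reachable because the container runs with --net host) and the "ensemble" model name used by the client script below.

```python
# Poll Triton until the server and the ensemble model report ready.
import time
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")
for _ in range(30):
    try:
        if client.is_server_ready() and client.is_model_ready("ensemble"):
            print("Triton is up and the ensemble model is ready")
            break
    except Exception:
        pass  # server not accepting connections yet
    time.sleep(2)
else:
    raise SystemExit("Triton did not become ready in time")
```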
- Client inference test
```bash
python3 tools/inflight_batcher_llm/end_to_end_streaming_client.py -p "什么是机器学习?" -S -o 245
```
The output is a pile of raw encoded bytes instead of readable text:
The tools/inflight_batcher_llm/end_to_end_streaming_client.py script needs a few changes so that the output is decoded into Chinese and joined into a single line instead of being printed token by token. Specifically, modify the test() function in end_to_end_streaming_client.py:
```python
def test(triton_client, prompt):
    model_name = "ensemble"
    input0 = [[prompt]]
    input0_data = np.array(input0).astype(object)
    output0_len = np.ones_like(input0).astype(np.uint32) * FLAGS.output_len
    bad_words_list = np.array([[""]], dtype=object)
    stop_words_list = np.array([[""]], dtype=object)
    streaming = [[FLAGS.streaming]]
    streaming_data = np.array(streaming, dtype=bool)

    inputs = [
        utils.prepare_tensor("text_input", input0_data, FLAGS.protocol),
        utils.prepare_tensor("max_tokens", output0_len, FLAGS.protocol),
        utils.prepare_tensor("bad_words", bad_words_list, FLAGS.protocol),
        utils.prepare_tensor("stop_words", stop_words_list, FLAGS.protocol),
        utils.prepare_tensor("stream", streaming_data, FLAGS.protocol),
    ]

    user_data = UserData()

    # Establish stream
    triton_client.start_stream(callback=partial(callback, user_data))

    # Create a StringIO object to capture output
    output_catcher = StringIO()
    # Redirect stdout to the StringIO object so the raw encodings are not printed to the terminal
    sys.stdout = output_catcher

    triton_client.async_stream_infer(model_name, inputs)  # Send request

    # Wait for server to close the stream
    triton_client.stop_stream()

    # Restore stdout
    sys.stdout = sys.__stdout__

    # Parse the responses
    response_text = []
    while True:
        try:
            result = user_data._completed_requests.get(block=False)
        except Exception:
            break

        if type(result) == InferenceServerException:
            print("Received an error from server:")
            print(result)
        else:
            text_output = result.as_numpy('text_output')
            # Decode each result item into Chinese text
            for item in text_output:
                chinese_text = item.decode('utf-8')
                response_text.append(chinese_text)

    # Join all tokens into one sentence
    chinese_sentence = ''.join(response_text)
    print(chinese_sentence)
```
Running the client test script again now returns a normal response:
```bash
python3 tools/inflight_batcher_llm/end_to_end_streaming_client.py -p "什么是机器学习?" -S -o 245
```
With that, the ChatGLM2-6B inference verification on TensorRT-LLM and Triton is complete!
- Finally, stop the Triton Inference Server with:
```bash
pgrep tritonserver | xargs kill -9
```
5. Pitfalls Encountered
- With in-flight batching disabled in the config file, the Triton service starts normally, but running the client test script fails with a dimension-mismatch error.
Cause: the config disables in-flight batching, while the provided client script assumes in-flight batching.
- With in-flight batching enabled in the config file, launching the Triton Inference Server fails with:
```text
TrtGptModelInflightBatching requires GPT attention plugin with packed input and paged KV cache.
```
This happens because the engine was converted without in-flight batching support.
Fix for both problems: modify build.py and model.py as described earlier, and enable in-flight batching in the config file. A quick way to confirm what the engine was actually built with is sketched below.
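To tell which of the two situations applies, you can inspect the config.json that build.py writes next to the engine. The field names below are an assumption based on the config produced by this TensorRT-LLM release; adjust them if your file differs.

```python
# Diagnostic sketch: confirm the build enabled what in-flight batching needs.
import json

with open("trtModel/config.json") as f:
    cfg = json.load(f)

plugin_cfg = cfg.get("plugin_config", {})
print("gpt_attention_plugin :", plugin_cfg.get("gpt_attention_plugin"))
print("remove_input_padding :", plugin_cfg.get("remove_input_padding"))
print("paged_kv_cache       :", plugin_cfg.get("paged_kv_cache"))
```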