
Hands-On ChatGLM2-6B Model Inference with TensorRT-LLM and Triton

2023-10-27 06:54:20

On October 19, 2023, NVIDIA officially announced the public availability of TensorRT-LLM. Its main features include:

- An end-to-end deployment framework for generative AI applications (model building, customization, format conversion, deployment)
- Multi-GPU, multi-node inference
- Conversion and deployment examples for common large models (the LLaMA family, ChatGLM family, GPT family, Baichuan, BLOOM, OPT, Falcon, etc.)
- A Python API for building and converting new models
- Support for the Triton inference serving framework
- Support for multiple NVIDIA architectures: Volta, Turing, Ampere, Hopper, and Ada Lovelace
- Beyond the transformer-oriented optimizations inherited from FasterTransformer, many LLM-specific optimizations such as In-flight Batching, Paged KV Cache for the Attention, INT4/INT8 Weight-Only Quantization, SmoothQuant, Multi-head Attention (MHA), Multi-query Attention (MQA), Grouped-query Attention (GQA), and more

Following the official documentation of TensorRT-LLM and of tensorrtllm_backend (the open-source TensorRT-LLM backend for Triton Inference Server), this article verifies that the "TensorRT-LLM + Triton" stack can successfully serve ChatGLM2-6B inference. Before getting into the details, a few notes:

  • The open-source TensorRT-LLM code does not support in-flight batching for the ChatGLM2-6B model; this article modifies the source code to enable it.

    Note: in-flight batching is a scheduling optimization. While serving a batch of requests, sequences that finish are evicted immediately and new requests are admitted right away, even though other requests in the batch are still in progress. This improves throughput and GPU utilization (a toy sketch of the idea follows this list).

  • The Triton client test script shipped with tensorrtllm_backend is modified here to make the interaction friendlier.

  • Multi-GPU inference of ChatGLM2-6B failed; an issue filed on the official repo was answered with the reply that TensorRT-LLM does not yet support multi-GPU inference for ChatGLM2-6B and a future code update is needed. This article therefore uses a single GPU on a single node.
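
To make the in-flight batching note above concrete, here is a minimal, framework-independent Python sketch of the scheduling idea. The `step_fn` callback and the `done` flag on the request objects are hypothetical stand-ins for illustration, not TensorRT-LLM APIs:

    from collections import deque

    def inflight_batching_sketch(pending, max_batch_size, step_fn):
        """Toy scheduler: finished sequences leave the batch at once and
        waiting requests are admitted immediately, instead of waiting for
        the whole batch to drain as in static batching."""
        waiting = deque(pending)   # requests not yet scheduled
        running = []               # requests currently decoding ("in flight")
        finished = []
        while waiting or running:
            # Admit new requests into any free batch slots.
            while waiting and len(running) < max_batch_size:
                running.append(waiting.popleft())
            # Run one decoding step for every in-flight request.
            for req in running:
                step_fn(req)
            # Evict completed requests right away; the rest keep decoding.
            finished.extend(r for r in running if r.done)
            running = [r for r in running if not r.done]
        return finished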

1. Building TensorRT-LLM and Its Docker Image

At the time of writing, TensorRT-LLM can only be installed manually from source; NVIDIA plans to release a Docker image later that bundles TensorRT-LLM together with the Triton inference serving backend.

  • Pull the TensorRT-LLM source code

    apt-get update && apt-get -y install git git-lfs
    
    git clone https://github.com/NVIDIA/TensorRT-LLM.git
    cd TensorRT-LLM
    git submodule update --init --recursive
    git lfs install
    git lfs pull
  • Build the TensorRT-LLM Docker image

    make -C docker release_build
  • Run the Docker container
    make -C docker run

    This command starts a container named tensorrt_llm-devel-root from the tensorrt_llm/devel:latest image.

2. Converting a Hugging Face Model into a TensorRT-LLM Engine

Note: the following steps are executed inside the tensorrt_llm-devel-root container.

  • Fetch the chatglm2-6b pretrained weights into the pyTorchModel directory
    apt-get update
    apt-get install git-lfs
    cd examples/chatglm2-6b/
    git clone https://huggingface.co/THUDM/chatglm2-6b pyTorchModel
  • Convert the Hugging Face model into a TensorRT-LLM engine

    Notes:

    • The ChatGLM2-6B model in the current open-source TensorRT-LLM code does not support inflight batching; the model files are modified here to support it.

    • The --world_size set in this build step must match the number of GPUs used for inference later; otherwise inference will fail.

    (Important) Before converting the model, modify build.py and model.py to support inflight batching.

    Modify examples/chatglm2-6b/build.py as follows:

    # SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    # SPDX-License-Identifier: Apache-2.0
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    import argparse
    import os
    import time
    
    import tensorrt as trt
    import torch
    import torch.multiprocessing as mp
    import transformers
    from weight import load_from_hf_chatglm2_6B
    
    import tensorrt_llm
    from tensorrt_llm._utils import str_dtype_to_trt
    from tensorrt_llm.builder import Builder
    from tensorrt_llm.layers import AttentionMaskType
    from tensorrt_llm.logger import logger
    from tensorrt_llm.mapping import Mapping
    from tensorrt_llm.models import (ChatGLM2HeadModel, smooth_quantize,
                                     weight_only_quantize)
    from tensorrt_llm.network import net_guard
    from tensorrt_llm.plugin.plugin import ContextFMHAType
    from tensorrt_llm.quantization import QuantMode
    
    MODEL_NAME = "chatglm2-6b"
    
    
    def get_engine_name(model, dtype, tp_size, rank):
        return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
    
    
    def serialize_engine(engine, path):
        logger.info(f'Serializing engine to {path}...')
        tik = time.time()
        with open(path, 'wb') as f:
            f.write(bytearray(engine))
        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Engine serialized. Total time: {t}')
    
    
    def parse_arguments():
        parser = argparse.ArgumentParser()
        parser.add_argument('--world_size',
                            type=int,
                            default=1,
                            help='world size, only support tensor parallelism now')
        parser.add_argument('--model_dir', type=str, default="./pyTorchModel")
        parser.add_argument('--dtype',
                            type=str,
                            default='float16',
                            choices=['float16', 'float32'])
        parser.add_argument(
            '--timing_cache',
            type=str,
            default='model.cache',
            help=
            'The path of to read timing cache from, will be ignored if the file does not exist'
        )
        parser.add_argument(
            '--log_level',
            type=str,
            default='verbose',
            choices=['verbose', 'info', 'warning', 'error', 'internal_error'])
        parser.add_argument('--vocab_size', type=int, default=65024)
        parser.add_argument('--n_layer', type=int, default=28)
        parser.add_argument('--n_positions', type=int, default=2048)
        parser.add_argument('--n_embd', type=int, default=4096)
        parser.add_argument('--n_head', type=int, default=32)
        parser.add_argument('--hidden_act', type=str, default='gelu')
        parser.add_argument(
            '--rotary_pct',
            type=float,
            default=0.0,
            help="Setting this to a value > 0.0 (and <= 1.0) activates RoPE.")
        parser.add_argument('--inter_size', type=int, default=None)
        parser.add_argument('--no_bias', action="store_false")
        parser.add_argument('--max_batch_size', type=int, default=4)
        parser.add_argument('--max_input_len', type=int, default=4096)
        parser.add_argument('--max_output_len', type=int, default=2048)
        parser.add_argument('--max_beam_width', type=int, default=1)
        parser.add_argument(
            '--use_gpt_attention_plugin',
            nargs='?',
            const='float16',
            default='float16',
            # default=False,
            choices=['float16', 'float32', False])
        parser.add_argument('--use_gemm_plugin',
                            nargs='?',
                            const='float16',
                            type=str,
                            default='float16',
                            choices=['float16', 'float32'])
        parser.add_argument('--use_layernorm_plugin',
                            nargs='?',
                            const='float16',
                            type=str,
                            default='float16',
                            choices=['float16', 'float32'])
        parser.add_argument('--parallel_build', default=False, action='store_true')
        parser.add_argument('--enable_context_fmha',
                            default=False,
                            action='store_true')
        parser.add_argument('--enable_context_fmha_fp32_acc',
                            default=False,
                            action='store_true')
        parser.add_argument('--gpus_per_node', type=int, default=8)
        parser.add_argument('--builder_opt', type=int, default=None)
        parser.add_argument(
            '--output_dir',
            type=str,
            default='trtModel',
            help=
            'The path to save the serialized engine files, timing cache file and model configs'
        )
        parser.add_argument(
            "--multi_query_mode",
            "-mq",
            default=False,
            action='store_true',
            help=
            "Whether this model uses multi-query attention mechanism (default: False)"
        )
        # Arguments related to the quantization of the model.
        parser.add_argument(
            '--use_smooth_quant',
            default=False,
            action="store_true",
            help=
            'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.'
            'See --per_channel and --per_token for finer-grained quantization options.'
        )
        parser.add_argument(
            '--use_weight_only',
            default=False,
            action="store_true",
            help='Quantize weights for the various GEMMs to INT4/INT8.'
            'See --weight_only_precision to set the precision')
        parser.add_argument(
            '--weight_only_precision',
            const='int8',
            type=str,
            nargs='?',
            default='int8',
            choices=['int8', 'int4'],
            help=
            'Define the precision for the weights when using weight-only quantization.'
            'You must also use --use_weight_only for that argument to have an impact.'
        )
        parser.add_argument(
            '--per_channel',
            default=False,
            action="store_true",
            help=
            'By default, we use a single static scaling factor for the GEMM\'s result. '
            'per_channel instead uses a different static scaling factor for each channel. '
            'The latter is usually more accurate, but a little slower.')
        parser.add_argument(
            '--per_token',
            default=False,
            action="store_true",
            help=
            'By default, we use a single static scaling factor to scale activations in the int8 range. '
            'per_token chooses at run time, and for each token, a custom scaling factor. '
            'The latter is usually more accurate, but a little slower.')
        parser.add_argument(
            '--int8_kv_cache',
            default=False,
            action="store_true",
            help=
            'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
        )
        parser.add_argument(
            '--random_seed',
            type=int,
            default=None,
            help=
            'Seed to use when initializing the random number generator for torch.')
    
    
        # Added to support inflight_batching, by shaun xiong (231024)
        parser.add_argument('--remove_input_padding',
                            default=False,
                            action='store_true')
        parser.add_argument(
            '--use_inflight_batching',
            action="store_true",
            default=False,
            help="Activates inflight batching mode of gptAttentionPlugin.")
        parser.add_argument(
            '--paged_kv_cache',
            action="store_true",
            default=False,
            help=
            'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
        )
        parser.add_argument('--tokens_per_block',
                            type=int,
                            default=64,
                            help='Number of tokens per block in paged KV cache')
        ########################################################
    
        args = parser.parse_args()
    
    
        # Added to support inflight_batching (by shaun xiong 231024)
        if not args.remove_input_padding:
            if args.use_gpt_attention_plugin:
                logger.warning(
                    f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
                )
    
        if args.use_inflight_batching:
            if not args.use_gpt_attention_plugin:
                args.use_gpt_attention_plugin = 'float16'
                logger.info(
                    f"Using GPT attention plugin for inflight batching mode. "
                    f"Setting to default '{args.use_gpt_attention_plugin}'")
            if not args.remove_input_padding:
                args.remove_input_padding = True
                logger.info(
                    'Using remove input padding for inflight batching mode.')
            if not args.paged_kv_cache:
                args.paged_kv_cache = True
                logger.info('Using paged KV cache for inflight batching mode.')
        ########################################################
    
        assert not (
            args.use_smooth_quant and args.use_weight_only
        ), "You cannot enable both SmoothQuant and INT8 weight-only together."
    
        if args.use_smooth_quant:
            args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
                                                         args.per_channel)
        elif args.use_weight_only:
            args.quant_mode = QuantMode.use_weight_only(
                args.weight_only_precision == 'int4')
        else:
            args.quant_mode = QuantMode(0)
        args.bias = not args.no_bias
    
        if args.inter_size is None:
            args.inter_size = 4 * args.n_embd
    
        if args.int8_kv_cache:
            assert (
                args.use_gpt_attention_plugin
            ), "You have to use GPT attention plugin when int8 KV cache is set"
            args.quant_mode = args.quant_mode.set_int8_kv_cache()
    
        return args
    
    
    def build_rank_engine(builder: Builder,
                          builder_config: tensorrt_llm.builder.BuilderConfig,
                          engine_name, rank, args):
        '''
           @brief: Build the engine on the given rank.
           @param rank: The rank to build the engine.
           @param args: The cmd line arguments.
           @return: The built engine.
        '''
        str_dtype_to_trt(args.dtype)
    
        # Initialize Module
        tensorrt_llm_ChatGLM2_6BModel = ChatGLM2HeadModel(
            hidden_size=4096,
            num_attention_heads=32,
            kv_channels=128,
            multi_query_group_num=2,
            apply_query_key_layer_scaling=False,
            attention_mask_type=AttentionMaskType.causal,
            qkv_bias=True,
            linear_bias=False,
            use_int8_kv_cache=False,
            mapping=Mapping(world_size=args.world_size,
                            rank=rank,
                            tp_size=args.world_size),
            ffn_hiden_size=13696,
            num_layers=28,
            eps=1e-5,
            act_func='swiglu',
            dtype=trt.float16,
            quant_mode=QuantMode(0),
            max_seq_length=32768,
            vocab_size=65024)
    
        if args.use_smooth_quant:
            tensorrt_llm_ChatGLM2_6BModel = smooth_quantize(
                tensorrt_llm_ChatGLM2_6BModel, args.quant_mode)
        elif args.use_weight_only:
            tensorrt_llm_ChatGLM2_6BModel = weight_only_quantize(
                tensorrt_llm_ChatGLM2_6BModel, args.quant_mode)
        if args.model_dir is not None:
            print('loading weights from hugging face model')
            hf_model = transformers.AutoModel.from_pretrained(
                args.model_dir, trust_remote_code=True).cpu()
            tensorrt_llm_ChatGLM2_6BModel = load_from_hf_chatglm2_6B(
                tensorrt_llm_ChatGLM2_6BModel, hf_model, dtype='float16')
            print('load finished')
            del hf_model
        # Module -> Network
        network = builder.create_network()
        network.trt_network.name = engine_name
        if args.use_gpt_attention_plugin:
            network.plugin_config.set_gpt_attention_plugin(
                dtype=args.use_gpt_attention_plugin)
        if args.use_gemm_plugin:
            network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
        if args.use_layernorm_plugin:
            network.plugin_config.set_layernorm_plugin(
                dtype=args.use_layernorm_plugin)
        assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
        if args.enable_context_fmha:
            network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
        if args.enable_context_fmha_fp32_acc:
            network.plugin_config.set_context_fmha(
                ContextFMHAType.enabled_with_fp32_acc)
        
        
    
        # Quantization plugins.
        if args.use_smooth_quant:
            network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
            network.plugin_config.set_layernorm_quantization_plugin(
                dtype=args.dtype)
            # FIXME
            network.plugin_config.set_quantize_tensor_plugin()
            network.plugin_config.set_quantize_per_token_plugin()
        elif args.use_weight_only:
            network.plugin_config.set_weight_only_quant_matmul_plugin(
                dtype='float16')
    
        if args.world_size > 1:
            network.plugin_config.set_nccl_plugin(args.dtype)
        
        # Added to support inflight_batching (by shaun xiong 231024)
        if args.remove_input_padding:
            network.plugin_config.enable_remove_input_padding()
        if args.paged_kv_cache:
            network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
        ####
        
    
        with net_guard(network):
            # Prepare
            network.set_named_parameters(
                tensorrt_llm_ChatGLM2_6BModel.named_parameters())
    
            # Forward
            inputs = tensorrt_llm_ChatGLM2_6BModel.prepare_inputs(
                args.max_batch_size, args.max_input_len, args.max_output_len, True,
                args.max_beam_width)
            tensorrt_llm_ChatGLM2_6BModel(*inputs)
    
        tensorrt_llm.graph_rewriting.optimize(network)
    
        engine = None
    
        # Network -> Engine
        engine = builder.build_engine(network, builder_config)
        if rank == 0:
            config_path = os.path.join(args.output_dir, 'config.json')
            builder.save_config(builder_config, config_path)
        return engine
    
    
    def build(rank, args):
        torch.cuda.set_device(rank % args.gpus_per_node)
        tensorrt_llm.logger.set_level(args.log_level)
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
    
        # when doing serializing build, all ranks share one engine
        apply_query_key_layer_scaling = False
        builder = Builder()
    
        cache = None
        for cur_rank in range(args.world_size):
            # skip other ranks if parallel_build is enabled
            if args.parallel_build and cur_rank != rank:
                continue
            builder_config = builder.create_builder_config(
                name=MODEL_NAME,
                precision=args.dtype,
                timing_cache=args.timing_cache if cache is None else cache,
                tensor_parallel=args.world_size,  # TP only
                parallel_build=args.parallel_build,
                num_layers=args.n_layer,
                num_heads=args.n_head,
                hidden_size=args.n_embd,
                vocab_size=args.vocab_size,
                hidden_act=args.hidden_act,
                max_position_embeddings=args.n_positions,
                apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                max_batch_size=args.max_batch_size,
                max_input_len=args.max_input_len,
                max_output_len=args.max_output_len,
                int8=(args.quant_mode.has_act_and_weight_quant()
                      or args.quant_mode.has_int8_kv_cache()),
                opt_level=args.builder_opt,
                multi_query_mode=args.multi_query_mode)
    
            engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
                                          cur_rank)
            engine = build_rank_engine(builder, builder_config, engine_name,
                                       cur_rank, args)
            assert engine is not None, f'Failed to build engine for rank {cur_rank}'
    
            if cur_rank == 0:
                # Use in-memory timing cache for multiple builder passes.
                if not args.parallel_build:
                    cache = builder_config.trt_builder_config.get_timing_cache()
    
            serialize_engine(engine, os.path.join(args.output_dir, engine_name))
    
        if rank == 0:
            ok = builder.save_timing_cache(
                builder_config, os.path.join(args.output_dir, "model.cache"))
            assert ok, "Failed to save timing cache."
    
    
    if __name__ == '__main__':
        args = parse_arguments()
    
        if args.random_seed is not None:
            torch.manual_seed(args.random_seed)
    
        logger.set_level(args.log_level)
        tik = time.time()
        if args.parallel_build and args.world_size > 1 and \
                torch.cuda.device_count() >= args.world_size:
            logger.warning(
                f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
            )
            mp.spawn(build, nprocs=args.world_size, args=(args, ))
        else:
            args.parallel_build = False
            logger.info('Serially build TensorRT engines.')
            build(0, args)
    
        tok = time.time()
        t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
        logger.info(f'Total time of building all {args.world_size} engines: {t}')
    

    Modify /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/chatglm2_6b/model.py as follows:

    # SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    # SPDX-License-Identifier: Apache-2.0
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    import numpy as np
    import tensorrt as trt
    import torch
    
    from ..._common import default_net
    from ..._utils import pad_vocab_size, str_dtype_to_trt
    from ...functional import (PositionEmbeddingType, Tensor, concat, constant,
                               expand, expand_dims, gather_last_token_logits,
                               gpt_attention, index_select, select, shape, slice,
                               split)
    from ...layers import (MLP, AttentionMaskType, AttentionParams, ColumnLinear,
                           Embedding, KeyValueCacheParams, RmsNorm, RowLinear)
    from ...mapping import Mapping
    from ...module import Module, ModuleList
    from ...parameter import Parameter
    from ...quantization import QuantMode
    from ..generation_mixin import GenerationMixin
    
    
    def apply_rotary_pos_emb_trt(x: Tensor, rope_cache: Tensor) -> Tensor:
        # x-> [seq, batch, num_heads, 2]
        x = x.permute((1, 0, 2, 3))
        # sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
        sq = shape(x, 0)
        b = shape(x, 1)
        nh = shape(x, 2)
        shape(x, 3)
        # rope_cache shape: seq,batch,heads,2 rot_dim = 2* numheads
        #rope_cache: seq,batch,num_states/4,2
        rot_dim = shape(rope_cache, 2) * constant(np.array(2, dtype=np.int32))
        starts = concat([0, 0, 0, 0])
        sizes = concat([sq, b, nh, rot_dim])
        # first half
        x_rot = slice(x, starts, sizes)
        starts = concat([0, 0, 0, rot_dim])
        # second half
        x_pass = slice(x, starts, sizes)
        # truncate to support variable sizes
        rope_cache = slice(rope_cache, (0, 0, 0, 0), (concat(
            [sq,
             shape(rope_cache, 1),
             shape(rope_cache, 2),
             shape(rope_cache, 3)])))
        xshaped = x_rot.view(concat([sq, b, nh, rot_dim / 2, 2]))
        rope_cache = rope_cache.view(concat([sq, b, 1, shape(xshaped, 3), 2]))
        # first half
        xshape0 = select(xshaped, 4, 0)
        # second half
        xshape1 = select(xshaped, 4, 1)
        # first half
        rope_cache0 = select(rope_cache, 4, 0)
        # second half
        rope_cache1 = select(rope_cache, 4, 1)
        out0 = xshape0 * rope_cache0 - xshape1 * rope_cache1
        out1 = xshape1 * rope_cache0 + xshape0 * rope_cache1
        out0 = expand_dims(out0, 4)
        out1 = expand_dims(out1, 4)
        x_out2_v1 = concat([out0, out1], 4)
        x_out2 = x_out2_v1.view(
            concat([sq, b, nh, shape(x_out2_v1, 3) * shape(x_out2_v1, 4)]))
        output = concat([x_out2, x_pass], dim=3)
        # to batch,seq,num_group,head_states
        output = output.permute((1, 0, 2, 3))
        return output
    
    
    class RotaryEmbedding(Module):
    
        def __init__(self, dim):
            super().__init__()
            self.dim = dim
    
        def forward(self, seq_len: int):
            theta = 1.0 / (10000**(torch.arange(0, self.dim, 2) / self.dim))
            seq_idx = torch.arange(seq_len)
            idx_theta = torch.outer(seq_idx, theta).float()
            cache = torch.stack(
                [torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
            cache = cache.half()
            # create rope embeddings and make it constant
            cache = constant(cache.numpy())
            return cache
    
    
    class ChatGLM2Attention(Module):
    
        def __init__(
            self,
            hidden_size,
            num_attention_heads,
            layer_number,
            kv_channels=128,
            multi_query_group_num=2,
            apply_query_key_layer_scaling=False,
            attention_mask_type=AttentionMaskType.causal,
            qkv_bias=True,
            linear_bias=False,
            dtype='float16',
            use_int8_kv_cache=False,
            tp_group=None,
            tp_size=1,
        ):
            super().__init__()
    
            self.attention_mask_type = attention_mask_type
            self.attention_head_size = hidden_size // num_attention_heads
            self.num_attention_heads = num_attention_heads // tp_size
            self.num_multi_query_groups_per_partition = multi_query_group_num
            self.num_attention_kv_heads = self.num_attention_heads
            self.hidden_size = hidden_size // tp_size
            self.projection_size = num_attention_heads * kv_channels
            self.hidden_size_per_attention_head = kv_channels
            self.layer_number = layer_number
            self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
            self.q_scaling = 1
            if apply_query_key_layer_scaling:
                self.q_scaling *= self.layer_number
            self.position_embedding_type = PositionEmbeddingType.learned_absolute
            self.multi_block_mode = False
            self.multi_query_mode = False
    
            self.rotary_embedding_dim = 0
    
            self.dtype = dtype
    
            self.use_int8_kv_cache = use_int8_kv_cache
            if self.use_int8_kv_cache:
                self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32')
                self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32')
            else:
                self.register_parameter('kv_orig_quant_scale', None)
                self.register_parameter('kv_quant_orig_scale', None)
    
            # Note: in multi_query_mode, only query heads are split between multiple GPUs,
            # while key/value head are not split as there is only one head per key/value.
            # The output feature size is therefore (h/tp + 2) * d, where h is num_heads,
            # d is head_size, and tp is tensor_parallel_size.
            # In ColumnLinear op, the output dim is calculated by (h + 2*tp) * d / tp,
            # which matches the desired output size (h/tp + 2) * d after splitting
            self.qkv_hidden_size = (self.projection_size +
                                    2 * self.hidden_size_per_attention_head * 2)
            self.qkv = ColumnLinear(hidden_size,
                                    self.qkv_hidden_size,
                                    bias=qkv_bias,
                                    dtype=dtype,
                                    tp_group=tp_group,
                                    tp_size=tp_size,
                                    gather_output=False)
            self.dense = RowLinear(hidden_size,
                                   hidden_size,
                                   bias=linear_bias,
                                   dtype=dtype,
                                   tp_group=tp_group,
                                   tp_size=tp_size)
    
        def forward(self,
                    hidden_states: Tensor,
                    rotary_pos_emb,
                    use_cache=True,
                    kv_cache_params=None,
                    attention_params=None):
            if not default_net().plugin_config.gpt_attention_plugin:
                raise ValueError(
                    'ChatGLM2 is only supported with the GPTAttention plugin; please build with the --use_gpt_attention_plugin argument.'
                )
            assert isinstance(hidden_states, Tensor)
            qkv = self.qkv(hidden_states)
            query, key, value = split(qkv, [
                self.num_attention_heads * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition *
                self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition *
                self.hidden_size_per_attention_head,
            ],
                                      dim=-1)
            query = query.view(
                concat([
                    shape(qkv, 0),
                    shape(qkv, 1), self.num_attention_heads,
                    self.attention_head_size
                ]))
            key = key.view(
                concat([
                    shape(qkv, 0),
                    shape(qkv, 1), self.num_multi_query_groups_per_partition,
                    self.attention_head_size
                ]))
            value = value.view(
                concat([
                    shape(qkv, 0),
                    shape(qkv, 1), self.num_multi_query_groups_per_partition,
                    self.attention_head_size
                ]))
    
            if rotary_pos_emb is not None:
                query = apply_rotary_pos_emb_trt(query, rotary_pos_emb)
                key = apply_rotary_pos_emb_trt(key, rotary_pos_emb)
            # batch,seq,num_group,1,head_states
            key = expand_dims(key, 3)
            #expand 16x
            expand_rate = self.num_attention_heads // self.num_multi_query_groups_per_partition
            key = expand(
                key,
                concat([
                    shape(key, 0),
                    shape(key, 1),
                    shape(key, 2), expand_rate,
                    shape(key, 4)
                ]))
            # batch,seq,num_heads,head_states
            key = key.view(
                concat([
                    shape(key, 0),
                    shape(key, 1),
                    shape(key, 2) * shape(key, 3),
                    shape(key, 4)
                ]))
            value = expand_dims(value, 3)
            value = expand(
                value,
                concat([
                    shape(value, 0),
                    shape(value, 1),
                    shape(value, 2), expand_rate,
                    shape(value, 4)
                ]))
            value = value.view(
                concat([
                    shape(value, 0),
                    shape(value, 1),
                    shape(value, 2) * shape(value, 3),
                    shape(value, 4)
                ]))
            qkv = concat([query, key, value], dim=2)
            qkv = qkv.view(
                concat([shape(qkv, 0),
                        shape(qkv, 1), self.hidden_size * 3]))
            assert attention_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin,
                default_net().plugin_config.remove_input_padding)
            assert kv_cache_params.is_valid(
                default_net().plugin_config.gpt_attention_plugin)
            kv_orig_quant_scale = self.kv_orig_quant_scale.value if self.use_int8_kv_cache else None
            kv_quant_orig_scale = self.kv_quant_orig_scale.value if self.use_int8_kv_cache else None
            context, past_key_value = gpt_attention(
                tensor=qkv,
                past_key_value=kv_cache_params.get_first_past_key_value(),
                sequence_length=attention_params.sequence_length,
                host_past_key_value_lengths=kv_cache_params.
                host_past_key_value_lengths,
                context_lengths=attention_params.context_lengths,
                cache_indirection=kv_cache_params.cache_indirection,
                host_request_types=attention_params.host_request_types,
                num_heads=self.num_attention_heads,
                num_kv_heads=self.
                num_attention_heads,  # since self.multi_query_mode is set to False
                hidden_size_per_head=self.attention_head_size,
                q_scaling=self.q_scaling,
                rotary_embedding_dim=self.rotary_embedding_dim,
                position_embedding_type=self.position_embedding_type,
                multi_block_mode=self.multi_block_mode,
                kv_orig_quant_scale=kv_orig_quant_scale,
                kv_quant_orig_scale=kv_quant_orig_scale,
                kv_cache_quant_mode=QuantMode.INT8_KV_CACHE
                if self.use_int8_kv_cache else QuantMode(0),
                kv_cache_block_pointers=kv_cache_params.
                get_first_kv_cache_block_pointers(),
                max_context_length=attention_params.max_context_length,
                host_context_lengths=attention_params.host_context_lengths)
            # dense layer after self-attention
            context = self.dense(context)
            if use_cache:
                return (context, past_key_value)
            else:
                return context
    
    
    class ChatGLM2Block(Module):
    
        def __init__(self,
                     hidden_size,
                     num_attention_heads,
                     kv_channels=128,
                     multi_query_group_num=2,
                     apply_query_key_layer_scaling=False,
                     attention_mask_type=AttentionMaskType.causal,
                     qkv_bias=True,
                     linear_bias=False,
                     use_int8_kv_cache=False,
                     tp_group=None,
                     tp_size=1,
                     ffn_hiden_size=13696,
                     layer_number=1,
                     eps=1e-5,
                     act_func='swiglu',
                     dtype=trt.float16,
                     quant_mode=QuantMode(0)):
            super(ChatGLM2Block, self).__init__()
            self.layer_number = layer_number
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.dtype = dtype
            self.ffn_hiden_size = ffn_hiden_size
            self.apply_residual_connection_post_layernorm = False
            self.fp32_residual_connection = False
    
            LayerNormFunc = RmsNorm
            # Layernorm on the input data.
            self.input_layernorm = LayerNormFunc(self.hidden_size,
                                                 eps=eps,
                                                 dtype=dtype)
    
            # Self attention.
            self.self_attention = ChatGLM2Attention(
                hidden_size, num_attention_heads, layer_number, kv_channels,
                multi_query_group_num, apply_query_key_layer_scaling,
                attention_mask_type, qkv_bias, linear_bias, dtype,
                use_int8_kv_cache, tp_group, tp_size)
            self.hidden_dropout = 0.0
    
            # Layernorm on the attention output
            self.post_attention_layernorm = LayerNormFunc(self.hidden_size,
                                                          eps=eps,
                                                          dtype=dtype)
    
            self.mlp = MLP(self.hidden_size, ffn_hiden_size, act_func, linear_bias,
                           dtype)
    
        def forward(self,
                    hidden_states,
                    rotary_pos_emb,
                    use_cache=True,
                    kv_cache_params=None,
                    attention_params=None):
            # hidden_states: [s, b, h]
    
            # Layer norm at the beginning of the transformer layer.
            layernorm_output = self.input_layernorm(hidden_states)
            # Self attention.
    
            attention_output, kv_cache = self.self_attention(
                layernorm_output,
                rotary_pos_emb,
                use_cache=use_cache,
                kv_cache_params=kv_cache_params,
                attention_params=attention_params)
    
            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = hidden_states
    
            layernorm_input = hidden_states + attention_output
    
            # Layer norm post the self attention.
            layernorm_output = self.post_attention_layernorm(layernorm_input)
    
            # MLP.
            mlp_output = self.mlp(layernorm_output)
    
            # Second residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input
    
            output = residual + mlp_output
    
            return output, kv_cache
    
    
    class ChatGLM2Transformer(Module):
        """Transformer class."""
    
        def __init__(self,
                     hidden_size,
                     num_attention_heads,
                     kv_channels=128,
                     multi_query_group_num=2,
                     apply_query_key_layer_scaling=False,
                     attention_mask_type=AttentionMaskType.causal,
                     qkv_bias=True,
                     linear_bias=False,
                     use_int8_kv_cache=False,
                     tp_group=None,
                     tp_size=1,
                     ffn_hiden_size=13696,
                     num_layers=28,
                     eps=1e-5,
                     act_func='swiglu',
                     dtype=trt.float16,
                     quant_mode=QuantMode(0)):
            super(ChatGLM2Transformer, self).__init__()
    
            self.fp32_residual_connection = False
            self.post_layer_norm = True
    
            # Number of layers.
            self.num_layers = num_layers
    
            # Transformer layers.
            def build_layer(layer_number):
                return ChatGLM2Block(hidden_size, num_attention_heads, kv_channels,
                                     multi_query_group_num,
                                     apply_query_key_layer_scaling,
                                     attention_mask_type, qkv_bias, linear_bias,
                                     use_int8_kv_cache, tp_group, tp_size,
                                     ffn_hiden_size, layer_number, eps, act_func,
                                     dtype, quant_mode)
    
            self.layers = ModuleList(
                build_layer(i + 1) for i in range(self.num_layers))
    
            if self.post_layer_norm:
                self.final_layernorm = RmsNorm(hidden_size, eps=eps, dtype=dtype)
    
            self.gradient_checkpointing = False
    
        def _get_layer(self, layer_number):
            return self.layers[layer_number]
    
        def forward(self,
                    hidden_states,
                    rotary_pos_emb,
                    use_cache=True,
                    kv_cache_params=None,
                    attention_params=None):
    
            presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                hidden_states, kv_cache = layer(
                    hidden_states,
                    rotary_pos_emb,
                    use_cache=use_cache,
                    kv_cache_params=KeyValueCacheParams(
                        past_key_value=[kv_cache_params.past_key_value[index]],
                        kv_cache_block_pointers=[
                            kv_cache_params.kv_cache_block_pointers[index]
                        ],
                        host_past_key_value_lengths=kv_cache_params.
                        host_past_key_value_lengths,
                        cache_indirection=kv_cache_params.cache_indirection),
                    attention_params=attention_params)
                presents.append(kv_cache)
    
            if self.post_layer_norm:
                hidden_states = self.final_layernorm(hidden_states)
    
            return hidden_states, presents
    
    
    class ChatGLM2Model(Module):
    
        def __init__(self,
                     hidden_size,
                     num_attention_heads,
                     kv_channels=128,
                     multi_query_group_num=2,
                     apply_query_key_layer_scaling=False,
                     attention_mask_type=AttentionMaskType.causal,
                     qkv_bias=True,
                     linear_bias=False,
                     use_int8_kv_cache=False,
                     mapping=Mapping(),
                     ffn_hiden_size=13696,
                     num_layers=28,
                     eps=1e-5,
                     act_func='swiglu',
                     dtype=trt.float16,
                     quant_mode=QuantMode(0),
                     max_seq_length=32768,
                     vocab_size=65024):
            super(ChatGLM2Model, self).__init__()
    
            self.dtype = dtype
            self.embedding = Embedding(vocab_size, hidden_size, dtype=dtype)
            self.num_layers = num_layers
            self.multi_query_group_num = multi_query_group_num
            self.kv_channels = kv_channels
    
            # Rotary positional embeddings
            self.max_seq_length = max_seq_length
            rotary_dim = kv_channels
            self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, )
            self.encoder = ChatGLM2Transformer(
                hidden_size, num_attention_heads, kv_channels,
                multi_query_group_num, apply_query_key_layer_scaling,
                attention_mask_type, qkv_bias, linear_bias, use_int8_kv_cache,
                mapping.tp_group, mapping.tp_size, ffn_hiden_size, num_layers, eps,
                act_func, dtype, quant_mode)
    
        def forward(
            self,
            input_ids: Tensor,
            position_ids,
            use_cache=True,
            kv_cache_params=None,
            attention_params=None,
        ):
    
            inputs_embeds = self.embedding(input_ids)
            # Rotary positional embeddings
            # generate 32768 pos embeddings
            # max_seq_length,head_dim/4,2
            rotary_pos_emb = self.rotary_pos_emb(self.max_seq_length)
            flat_position = position_ids.view(
                concat([shape(position_ids, 0) * shape(position_ids, 1)]))
            selected_pos_emb = index_select(rotary_pos_emb, 0, flat_position)
            # selected batch,seq from rotary_pos_emb
            selected_pos_emb = selected_pos_emb.view(
                concat([
                    shape(position_ids, 0),
                    shape(position_ids, 1),
                    shape(rotary_pos_emb, 1),
                    shape(rotary_pos_emb, 2)
                ]))
            # seq,batch
            selected_pos_emb = selected_pos_emb.permute((1, 0, 2, 3))
            # return inputs_embeds,selected_pos_emb
            # Run encoder.
    
            hidden_states, presents = self.encoder(
                inputs_embeds,
                selected_pos_emb,
                use_cache=use_cache,
                kv_cache_params=kv_cache_params,
                attention_params=attention_params,
            )
            return hidden_states, presents
    
    
    class ChatGLM2HeadModel(ChatGLM2Model, GenerationMixin):
    
        def __init__(self,
                     hidden_size,
                     num_attention_heads,
                     kv_channels=128,
                     multi_query_group_num=2,
                     apply_query_key_layer_scaling=False,
                     attention_mask_type=AttentionMaskType.causal,
                     qkv_bias=True,
                     linear_bias=False,
                     use_int8_kv_cache=False,
                     mapping=Mapping(),
                     ffn_hiden_size=13696,
                     num_layers=28,
                     eps=1e-5,
                     act_func='swiglu',
                     dtype=trt.float16,
                     quant_mode=QuantMode(0),
                     max_seq_length=32768,
                     vocab_size=65024,
                     use_cache=True,
                     kv_cache_block_pointers=None):
            if isinstance(dtype, str):
                self._kv_dtype = str_dtype_to_trt(dtype)
            else:
                assert isinstance(dtype, trt.DataType)
                self._kv_dtype = dtype
            self._dtype = self._kv_dtype
            if quant_mode.has_int8_kv_cache():
                self._kv_dtype = str_dtype_to_trt('int8')
            elif quant_mode.has_fp8_kv_cache():
                self._kv_dtype = str_dtype_to_trt('fp8')
            self.use_cache = use_cache
            self.kv_cache_block_pointers = kv_cache_block_pointers
            self.quant_mode = quant_mode
            self._num_layers = num_layers
            self._num_heads = num_attention_heads
            self._hidden_size = hidden_size
            self._vocab_size = vocab_size
            self._tp_size = mapping.tp_size
            self.mapping = mapping
            super().__init__(hidden_size, num_attention_heads, kv_channels,
                             multi_query_group_num, apply_query_key_layer_scaling,
                             attention_mask_type, qkv_bias, linear_bias,
                             use_int8_kv_cache, mapping, ffn_hiden_size, num_layers,
                             eps, act_func, dtype, quant_mode, max_seq_length,
                             vocab_size)
            vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
            self.lm_head = ColumnLinear(hidden_size,
                                        vocab_size_padded,
                                        bias=False,
                                        dtype=dtype,
                                        tp_group=mapping.tp_group,
                                        tp_size=mapping.tp_size,
                                        gather_output=True)
    
        def forward(self,
                    input_ids=None,
                    position_ids=None,
                    last_token_ids=None,
                    kv_cache_params=None,
                    attention_params=None):
    
            hidden_states = super().forward(input_ids, position_ids, self.use_cache,
                                            kv_cache_params, attention_params)
    
            if self.use_cache:
                hidden_states, presents = hidden_states
    
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)
    
            lm_logits = self.lm_head(hidden_states)
            lm_logits.mark_output('logits', self._dtype)
    
            if default_net().plugin_config.paged_kv_cache == False:
                for i, present in enumerate(presents):
                    present.mark_output(f'present_key_value_{i}', self._kv_dtype)
                return (lm_logits, presents)
            return lm_logits
    
        def prepare_inputs(self,
                           max_batch_size,
                           max_input_len,
                           max_new_tokens,
                           use_cache,
                           max_beam_width: int = 1):
            '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the
                ranges of the dimensions of when using TRT dynamic shapes.
    
                @return: a list contains values which can be fed into the self.forward()
            '''
            # Prepare inputs
            head_size = self._hidden_size // self._num_heads
            num_heads = self._num_heads // self._tp_size
            remove_input_padding = default_net().plugin_config.remove_input_padding
            use_gpt_attention_plugin = default_net(
            ).plugin_config.gpt_attention_plugin
            use_gemm_plugin = default_net().plugin_config.gemm_plugin
    
            # Modified below to support inflight_batching (by shaun xiong 231024)
            paged_kv_cache = default_net().plugin_config.paged_kv_cache
            tokens_per_block = default_net().plugin_config.tokens_per_block
            use_custom_all_reduce = default_net().plugin_config.use_custom_all_reduce
    
            model_inputs = self.prepare_basic_inputs(
                max_batch_size=max_batch_size,
                max_beam_width=max_beam_width,
                max_input_len=max_input_len,
                max_new_tokens=max_new_tokens,
                num_kv_heads = num_heads,
                num_heads = num_heads,
                head_size=head_size,
                num_layers=self.num_layers,
                kv_dtype=self._kv_dtype,
                remove_input_padding=remove_input_padding,
                use_gpt_attention_plugin=use_gpt_attention_plugin,
                use_gemm_plugin=use_gemm_plugin,
                use_custom_all_reduce=use_custom_all_reduce,
                paged_kv_cache=paged_kv_cache, 
                tokens_per_block=tokens_per_block,
                mapping=self.mapping
                )
            ##################################
            
    
            return (model_inputs['input_ids'], model_inputs['position_ids'],
                    model_inputs['last_token_ids'],
                    KeyValueCacheParams(
                        past_key_value=model_inputs['past_key_value'],
                        host_past_key_value_lengths=model_inputs[
                            'host_past_key_value_lengths'],
                        kv_cache_block_pointers=model_inputs[
                            'kv_cache_block_pointers_list'],
                        cache_indirection=model_inputs['cache_indirection'],
                    ),
                    AttentionParams(
                        sequence_length=model_inputs['sequence_length'],
                        context_lengths=model_inputs['context_lengths'],
                        host_context_lengths=model_inputs['host_context_lengths'],
                        max_context_length=max_input_len,
                        host_request_types=model_inputs['host_request_types']))
    

    After making the code changes above, run the examples/chatglm2-6b/build.py script to convert the model:

    python3 build.py --model_dir=./pyTorchModel \
                     --dtype float16 \
                     --use_gpt_attention_plugin float16 \
                     --use_gemm_plugin float16 \
                     --use_inflight_batching \
                     --use_weight_only

    Adding --use_weight_only applies INT8 weight-only quantization to the weights. The converted TensorRT-LLM engine file chatglm2-6b_float16_tp1_rank0.engine ends up in the trtModel directory:

    trtModel/
    ├── [6.3G]  chatglm2-6b_float16_tp1_rank0.engine
    ├── [1.3K]  config.json
    └── [424K]  model.cache
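
    Before leaving the container, a quick check that the inflight-batching-related plugin settings actually made it into the build can save a failed Triton launch later. A minimal sketch, assuming config.json keeps the builder_config/plugin_config layout that builder.save_config() writes in this version of TensorRT-LLM:

    # check_engine_config.py -- hypothetical helper name; run it next to trtModel/
    import json

    with open("trtModel/config.json") as f:
        cfg = json.load(f)

    plugin_cfg = cfg.get("plugin_config", {})
    # These should all be enabled for an engine built with --use_inflight_batching.
    print("gpt_attention_plugin:", plugin_cfg.get("gpt_attention_plugin"))
    print("remove_input_padding:", plugin_cfg.get("remove_input_padding"))
    print("paged_kv_cache:      ", plugin_cfg.get("paged_kv_cache"))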
  • Exit the container
    exit

3. Adding the TensorRT-LLM Backend to Triton Inference Server

The latest Triton Inference Server release at the time of writing (23.09) does not yet support TensorRT-LLM; the official repo says support will arrive in 23.10 (watch the official release notes). The steps below therefore build a Triton Inference Server image containing the TensorRT-LLM backend by hand.

  • Pull the tensorrtllm_backend repository

    git clone https://github.com/triton-inference-server/tensorrtllm_backend.git

     

  • Build the TensorRT-LLM backend image manually
    cd tensorrtllm_backend
    git submodule update --init --recursive
    git lfs install
    git lfs pull
    DOCKER_BUILDKIT=1 docker build -t triton_trt_llm -f dockerfile/Dockerfile.trt_llm_backend .

    Running the docker build command may hit the following problems:

    a. Cannot reach pypi.org, cannot wget the cmake package, and similar network failures

    Fix: edit tensorrtllm_backend/dockerfile/Dockerfile.trt_llm_backend to switch the pip index and install cmake manually:

    # Switch the pip index
    RUN python -m pip install --upgrade pip
    RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
    
    # CMake
    # COPY tensorrt_llm/docker/common/install_cmake.sh /tmp/
    # RUN bash /tmp/install_cmake.sh && rm /tmp/install_cmake.sh
    # ENV PATH="/usr/local/cmake/bin:${PATH}"
    
    # Install cmake manually instead
    COPY tensorrt_llm/docker/common/cmake-3.27.7-linux-x86_64.tar.gz /tmp
    RUN tar -xf /tmp/cmake-3.27.7-linux-x86_64.tar.gz -C /usr/local/
    RUN ln -s /usr/local/cmake-3.27.7-linux-x86_64 /usr/local/cmake
    ENV PATH="/usr/local/cmake/bin:${PATH}"
    RUN rm -rf /tmp/cmake-3.27.7-linux-x86_64.tar.gz

    b. The build fails at the build.sh step with: fatal: unable to access 'https://github.com/triton-inference-server/common.git/': GnuTLS recv error (-110): The TLS connection was non-properly terminated.

    Fix: add the following two commands before the cmake step in build.sh to raise git's HTTP post buffer:

    git config --global http.postBuffer 1048576000
    git config --global https.postBuffer 1048576000

    Once the commands above finish, you get a Triton Inference Server image with TensorRT-LLM support, "triton_trt_llm:latest".

4. Deploying the Triton Inference Service

  • Prepare the model files and configuration files:

    Following the official tensorrtllm_backend documentation, prepare the model files and edit the configuration files. Configure them for inflight batching as the official docs describe: if inflight batching is disabled, Triton Inference Server still starts fine, but the provided inference test script relies on inflight batching, so inference cannot be verified (see the config fragment sketched below).
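
    The fragment below illustrates the relevant setting in triton_model_repo/tensorrt_llm/config.pbtxt; the parameter name and value follow the template shipped with tensorrtllm_backend at the time of writing, so check your own version before copying:

    parameters: {
      key: "gpt_model_type"
      value: {
        string_value: "inflight_fused_batching"
      }
    }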

  • Launch Triton Inference Server

    # Launch the Triton container
    docker run --rm -it --net host --shm-size=2g --ulimit memlock=-1 --ulimit stack=67108864 --gpus all -v /root/xiongx7/tensorrtllm_backend:/tensorrtllm_backend triton_trt_llm bash
    
    cd /tensorrtllm_backend
    # --world_size is the number of GPUs you want to use for serving
    python3 scripts/launch_triton_server.py --world_size=1 --model_repo=/tensorrtllm_backend/triton_model_repo

     

  • Client inference test

      python3 tools/inflight_batcher_llm/end_to_end_streaming_client.py -p "什么是机器学习?" -S -o 245

      The output comes back as a pile of raw encoded bytes rather than readable text.

      To decode the output into Chinese and join it into a single sentence without line breaks, tools/inflight_batcher_llm/end_to_end_streaming_client.py needs some changes. Specifically, modify the test() function in end_to_end_streaming_client.py:
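      # Note: the changes below also rely on `import sys` and `from io import StringIO`;
      # add these at the top of the script if they are not already imported there.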
      def test(triton_client, prompt):
          model_name = "ensemble"
      
          input0 = [[prompt]]
          input0_data = np.array(input0).astype(object)
          output0_len = np.ones_like(input0).astype(np.uint32) * FLAGS.output_len
          bad_words_list = np.array([[""]], dtype=object)
          stop_words_list = np.array([[""]], dtype=object)
          streaming = [[FLAGS.streaming]]
          streaming_data = np.array(streaming, dtype=bool)
      
          inputs = [
              utils.prepare_tensor("text_input", input0_data, FLAGS.protocol),
              utils.prepare_tensor("max_tokens", output0_len, FLAGS.protocol),
              utils.prepare_tensor("bad_words", bad_words_list, FLAGS.protocol),
              utils.prepare_tensor("stop_words", stop_words_list, FLAGS.protocol),
              utils.prepare_tensor("stream", streaming_data, FLAGS.protocol),
          ]
      
          user_data = UserData()
          # Establish stream
          triton_client.start_stream(callback=partial(callback, user_data))
          
          # Create a StringIO object to capture the output
          output_catcher = StringIO()
      
          # Redirect stdout to the StringIO so the raw encodings are not printed straight to the terminal
          sys.stdout = output_catcher
          triton_client.async_stream_infer(model_name, inputs) # Send request
          
      
          #Wait for server to close the stream
          triton_client.stop_stream()
      
          # Restore stdout
          sys.stdout = sys.__stdout__
      
          # Parse the responses
          response_text = []
          while True:
              try:
                  result = user_data._completed_requests.get(block=False)
              except Exception:
                  break
      
              if type(result) == InferenceServerException:
                  print("Received an error from server:")
                  print(result)
              else:
                  text_output = result.as_numpy('text_output')
                  #将result解码成中文字符
                  for item in text_output:
                      chinese_text = item.decode('utf-8')
                      response_text.append(chinese_text)
          # 将所有token连接成一句话
          chinese_sentence = ''.join(response_text)
          print(chinese_sentence)

      再次运行client推理测试脚本可以正常返回:

      python3 tools/inflight_batcher_llm/end_to_end_streaming_client.py -p "什么是机器学习?" -S -o 245

      至此,成功完成基于TensorRT-LLM和Triton的ChatGLM2-6B模型推理验证!


    • 最后用以下命令终止Triton inference server
      pgrep tritonserver | xargs kill -9

5. 踩坑记录

  • config文件里设置了不启用inflight batching,Triton服务可以正常起来,但运行client测试脚本时报错维度不匹配

引起报错的原因:config文件中禁用了inflight batching,而官方给出的client测试脚本是按inflight batching方式发送请求的。

  • config文件里设置了启用inflight batching,启动triton inference server的时候报错:TrtGptModelInflightBatching requires GPT attention plugin with packed input and paged KV cache. 这是因为前面模型转换时没有按支持inflight batching的方式构建engine

以上两个问题的解决方法:按照前文步骤修改build.py和model.py,并且在config文件中设置为启用inflight batching。
