在Triton中部署半精度的TensorRT模型
使用 TensorRT 的API或者命令行工具可以将 ONNX 模型转换成支持 TensorRT 的 engine 模型文件,然后使用 Triton Inference Server 进行部署可以获得2~4倍的吞吐量的提升。但是,当在转换过程中设置了量化参数 — fp16或者 — best,同时待转换的模型中包含像 layernorm 这类需要以全精度运算的层时, TensorRT 会把使用 fp32 的权重参数强制转换成了 fp16 ,导致推理结果发生偏移。
...
[08/30/2023-09:33:05] [W] [TRT] TensorRT encountered issues when converting weights between types and that could affect accuracy.
[08/30/2023-09:33:05] [W] [TRT] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
[08/30/2023-09:33:05] [W] [TRT] Check verbose logs for the list of affected weights.
[08/30/2023-09:33:05] [W] [TRT] - 41 weights are affected by this issue: Detected subnormal FP16 values.
...
[08/30/2023-09:33:05] [W] [TRT] - 23 weights are affected by this issue: Detected values less than smallest positive FP16 subnormal value and converted them to the FP16 minimum subnormalized value.
[08/30/2023-09:33:05] [V] [TRT] List of affected weights: /model/embeddings/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 105) [Shuffle], /model/encoder/layer.0/attention/output/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 217) [Shuffle], /model/encoder/layer.0/output/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 258) [Shuffle], /model/encoder/layer.1/attention/output/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 370) [Shuffle], /model/encoder/layer.1/output/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 411) [Shuffle], /model/encoder/layer.2/attention/output/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 523) [Shuffle], /model/encoder/layer.2/output/LayerNorm/Constant_1_output_0 + (Unnamed Layer* 564) [Shuffle], model_embeddings_position_embeddings_weight_constant, model_embeddings_word_embeddings_weight_constant, model_encoder_layer_0_attention_self_key_bias _ (Unnamed Layer_ 125) [Shuffle]_constant, model_encoder_layer_1_attention_self_key_bias _ (Unnamed Layer_ 278) [Shuffle]_constant, model_encoder_layer_2_attention_self_key_bias _ (Unnamed Layer_ 431) [Shuffle]_constant, onnx__MatMul_465 _ (Unnamed Layer_ 140) [Shuffle]_constant, onnx__MatMul_471 _ (Unnamed Layer_ 204) [Shuffle]_constant, onnx__MatMul_472 _ (Unnamed Layer_ 228) [Shuffle]_constant, onnx__MatMul_473 _ (Unnamed Layer_ 245) [Shuffle]_constant, onnx__MatMul_474 _ (Unnamed Layer_ 269) [Shuffle]_constant, onnx__MatMul_484 _ (Unnamed Layer_ 357) [Shuffle]_constant, onnx__MatMul_486 _ (Unnamed Layer_ 398) [Shuffle]_constant, onnx__MatMul_487 _ (Unnamed Layer_ 422) [Shuffle]_constant, onnx__MatMul_491 _ (Unnamed Layer_ 446) [Shuffle]_constant, onnx__MatMul_498 _ (Unnamed Layer_ 534) [Shuffle]_constant, onnx__MatMul_499 _ (Unnamed Layer_ 551) [Shuffle]_constant
可以看出来在 layernorm 中使用 fp16 会导致有效精度位的溢出,带来的误差在网络中传播和累积,最终影响推理的结果。事实证明结果的偏差可能会非常的大,针对相同的推理文本,原文本分类模型的输出和使用 fp16 优化之后的模型输出对比,可以说是毫不相干。
| # | 模型输出的概率分布 |
| -------------- | ------------------------------------------------------------ |
| model.onnx | 3.540327E-4, 5.651282E-4, 5.4661214E-4, 0.01783814, 0.9699924, 6.4886943E-4, 7.39405E-4, 6.7994324E-4, 0.0048044445, 4.3710953E-4, 0.0033940033 |
| model.opt.plan | 0.0038471222, 0.0011329651, 5.207062E-4, 0.0011777878, 4.4202805E-4, 5.1641464E-4, 0.0015239716, 0.0019893646, 9.765625E-4, 0.0014209747, 0.9863281 |
总结一下,当前的问题是由于在模型中的 layernorm 层使用 fp16 精度运算时有效精度溢出,导致模型推理精度损失过大。解决思路如下:
- 将 pytorch 模型以半精度的方法导出为 onnx 模型,但 layernorm 层的权重数据维持 fp32 不变;
- 转换成功之后,需要改造 Triton Inference Client 来支持 fp16 推理请求和返回;
1. 解决精度损失问题
通过查询资料发现 torch >= 1.13.0 支持 ONNX opset=17,在转换的时候使用 INormalizationLayer 强制将 layernorm 层使用 fp32 的精度运算。于是回到将 pytorch 模型保存为 onnx 模型时,使用推荐的参数 opset=17 。这里可以一并将模型以半精度的方式保存。
model = PretrainedModelClassification().to("cuda:0")
model = model.half()
# ...
dummy_input = torch.ones([1, max_sequence_length], dtype=torch.int32).to("cuda:0")
torch.onnx.export(model,
dummy_input,
f=onnx_path,
opset_version=17,
input_names=['input'],
output_names=['logits', 'probability'],
do_constant_folding=True,
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}, 'probability': {0: 'batch_size'}})
⚠️ 在实验中发现,在转换的过程中还会有如下的地方需要注意 ⚠️ :
- 导出的时候,输入和模型都需要放在相同的 device 中,而且需要是在 cuda 中,因为 cpu 中不支持半精度计算;
- dummy_input = torch.ones([1, max_sequence_length]) 默认的数据类型是 float,建议将其手动设置为 int32,否则后面在推理的时候只能采用binary的形式,还要手动实现将半精度的浮点数用 short 表示,再转成 bytes 才可以进行推理;
- 模型的输出 logits,probability,在量化后都是 fp16 的精度,这就没办法了,只能在后面对 triton client 进行改造;
得到 model.onnx 之后,再使用 tensorrt api工具便可以顺利地得到优化模型 model.opt.plan。通过 Triton Inference Server 可以看到优化后的模型结构,按照预期已经输出为半精度的浮点数了,同时模型的体积也减少了一半。现在可以在 Triton 中部署半精度的 TensorRT 模型了。
{
"name": "model.opt.plan",
"platform": "tensorrt_plan",
"backend": "tensorrt",
"version_policy": {
"latest": {
"num_versions": 1
}
},
"max_batch_size": 40,
"input": [
{
"name": "input",
"data_type": "TYPE_INT32",
"format": "FORMAT_NONE",
"dims": [
300
],
"is_shape_tensor": false,
"allow_ragged_batch": false,
"optional": false
}
],
"output": [
{
"name": "logits",
"data_type": "TYPE_FP16",
"dims": [
11
],
"label_filename": "",
"is_shape_tensor": false
},
{
"name": "probability",
"data_type": "TYPE_FP16",
"dims": [
11
],
"label_filename": "labels.txt",
"is_shape_tensor": false
}
],
// ... ...
}
使用 Triton Client (Java) 构造请求对优化后的模型进行精度验证,发现之前通过编译获取的 Java Api 并不兼容半精度的浮点数,导致拿到的推理结果为 null,下面再来解决这个问题。
2. 解决Triton Client 不支持 FP16 的问题
2.1 半精度浮点数的构成
首先半精度的浮点数的组成可以参考 wiki。 在采用小端字节序下,最左边的高地址位(15位)是符号位 sign,14~10位是指数位 exponent ,9~0是小数位,也就是wiki中说的有效精度 fraction。可以看出有效精度的位数为10,换算成十进制就是1024,所以十进制精度只有3 ~ 4位。后面可以据此验证一下推理结果的误差范围是否合理。
例如:对于圆周率 3.14 的二进制表示为 0100 0010 0100 1000,对于这个二进制小数,我们带入计算公式:
计算得出:
sign = 0 (2)
exponent = 1 0000 (2)
fraction = 10 0100 1000 (2)
fp16 = ((-1)⁰) * (2^(16–15)) * (1 + 584/1024) = 3.140625
【例外的情况】:
- 如果 exponent 为全 0,那么计算公式为:((-1)^sign) * (2^(-14)) * (fraction / 1024)
0 00001 0000000000 = 2^−14 ≈ 6.10352E−5 (最小正指数)
0 00000 1111111111 = 2^−14–2^−24 ≈ 6.09756E−5 (最大尾数)
0 00000 0000000001 = 2^−24 ≈ 5.96046E−8 (最小正尾数)
- 如果 exponent 为全 1 时, fraction=0,表示 ±inf;fraction<>0,表示NaN
0 11111 0000000000 = infinity
1 11111 0000000000 = −infinity
0 11111 0000000001 = NaN (非有效数字)
2.2 构造 FP16 的推理请求
在构造 InferInputTensor 时,发现支持的数值型的数据类型有:int32、int64、uint32、unit64、float(fp32)、double(fp64)。经验证使用 triton-inference-server/client 编译的 Java Api 没有兼容半精度的模型推理。在转换的时候,我将模型的输入类型设置为 int32,从而规避了自行实现支持 fp16 的推理输入。
如果在其他需要输入浮点数模型中,无法避免地需要使用 fp16 的推理输入时,则需要手动转换(以 Java 为例):
// jdk20 引入了新的函数 floatToFloat16
// https://docs.oracle.com/en/java/javase/20/docs/api/java.base/java/lang/Float.html#floatToFloat16(float)
// jdk8 只能自己手动转了
public static short floatToHalfFloatInt(float f) {
int floatInt = Float.floatToIntBits(f);
return (short)(((floatInt >> 16) & 0x8000) | ((((floatInt >> 23) - 127 + 15) & 0x1f) << 10 )| ((floatInt >> 13) & 0x3ff));
}
// ...
ByteBuffer bb = ByteBuffer.allocate(floats.length * 2);
for (float f32 : floats) {
bb.putShort(floatToHalfFloatInt(f32));
}
InferInput inferInput = new InferInput("inputs", shape, DataType.BYTES);
inferInput.setData(bb.array(), isBinaryData);
2.3 获取 FP16 的推理结果
刚刚解决了模型输入为单精度浮点数的问题,接下来,模型的推理结果类型也是 fp16 。由于 Java 原始类型也没有半精度浮点数的实现,所以无法直接接收 fp16 的模型输出。针对这种情况 http 和 grpc 的 triton client 处理起来稍有不同。
【gRPC Inference Result】
通过 Triton Client 的 gRPC 接口得到的模型推理结果是一个 stream,如果是单精度的模型输出,可以直接将 stream 数据转换成一个 FloatBuffer。在半精度的模型输出时,需要将其指定 output 的输出内容取出来,剩下的工作就是要将 fp16 的 stream 数据转换成 fp32 的结果。
GrpcService.ModelInferResponse inferResponse = this.grpcClient.infer(modelName, version, inferInputs, inferOutputs);
// 单精度的模型输出,解析方法
FloatBuffer floatBuffer = inferResponse.getRawOutputContents(outputIndex).asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
float[] opFP32 = new float[floatBuffer.remaining()];
floatBuffer.get(opFP32);
// 半精度的模型输出
ByteBuffer byteBuffer = inferResponse.getRawOutputContents(outputIndex).asReadOnlyByteBuffer();
// TODO: 将 fp16 的 stream 数据转换成 fp32 的结果 float[]
【HTTP Inference Result】
通过 Triton Http 接口获取的模型推理结果会被解析到 InferenceResponse 中。但由于 Java 中没有实现半精度的浮点数,导致数据类型转换失败。解决的办法是创建 InferRequestedOutput 的时候将指定的输出设置为 binary。
boolean isBinary = true;
for (TritonModelConfig.IO io : config.getOutput()) {
if (io.getName().equals(outputName)) {
inferOutputs.add(new InferRequestedOutput(io.getName(), isBinary));
}
}
这样便可以在模型推断结果 InferResult 的 binaryData 中获取推理结果。剩下的工作和前面一样,就是手动将 fp16 的 binaryData 转换成 FP32。
【fp16 的二进制数据转换成 float】
有了前面半精度浮点数的计算过程,就可以写出转换的函数。github gist / Fp16Utils.java
2.4 其他解决方法
当前的解决方案中使用 INormalizationLayer 强制将 layernorm 层使用 fp32 的精度运算。其中就无法避免有类型转换的操作,如果模型中这类操作比较频繁还会影响性能。还有一种方案也可以解决 fp16 精度问题,效果甚至会更好。那就是使用 tensorrt 的 plugin。根据文章:TensorRT Plugin的实现、调试与验证:以实现Layernorm为例 可以得知,这个 tensorrt 命令行工具的插件对 fp32/fp16 都是支持的。
如果在转换的过程中使用文中插件中实现的 layernorm 替换掉 onnx 的 LayerNorm,便可以在推理过程中省去 fp16/fp32 之间的转换操作(未经验证)。
3. 误差验证和性能测试
3.1 误差验证
最后,对转换后的模型 model.opt.plan 的推断结果和原模型 model.onnx 进行对比。计算两个推理结果的绝对误差总和SAD = np.sum(np.abs(opt_result — org_result))为:0.00030123752999996144,合理误差,符合预期。
3.2 性能测试
使用 perf_analyzer 工具进行测试,配置如下表所示,对原始的单精度模型 model.onnx 和优化后的半精度模型 model.opt.plan 进行对比。衡量性能的指标主要有:吞吐量 qps,及请求的平均时延 latency。
分别在并发为 1,11,21,31,41,优化前后的模型进行测试,结果如下:
通过 Triton Inference Server 的 Metrics 指标也可以观察出,和 model.opt.plan(黄色线)相比 model.onnx (绿色线)的请求队列的整体等待时间要长很多。
[全文完]
参考
- https://developer.nvidia.com/zh-cn/blog/tensorrt-trtexec-cn/
- https://www.hbblog.cn/%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2/2022%E5%B9%B408%E6%9C%8802%E6%97%A5%2023%E6%97%B628%E5%88%8628%E7%A7%92/#layernormplugin
- https://docs.nvidia.com/deeplearning/triton-inference-server/archives/triton_inference_server_1140/user-guide/docs/optimization.html
- https://en.wikipedia.org/wiki/Half-precision_floating-point_format
- https://www.ruanyifeng.com/blog/2022/06/endianness-analysis.html
- https://stackoverflow.com/a/6162687
- https://github.com/triton-inference-server/client/blob/main/src/c++/perf_analyzer/README.md