Many PyTorch users build the post-processing network of a detection model with TensorRT plug-ins so that the model can be exported to TensorRT. Machine Learning Platform for AI (PAI)-Blade is highly extensible: if you have developed your own TensorRT plug-ins, you can use PAI-Blade and the plug-ins for collaborative model optimization. This article describes how to use PAI-Blade to optimize a detection model whose post-processing network is built by using TensorRT plug-ins.
TensorRT is a powerful tool for inference optimization on NVIDIA GPUs. PAI-Blade deeply integrates the optimization methods of TensorRT at the underlying layer. In addition, PAI-Blade integrates multiple optimization technologies, including graph optimization, optimization libraries such as TensorRT and oneDNN, AI compilation optimization, an optimization operator library, mixed precision, and EasyCompression.
RetinaNet is a one-stage detection network. Its basic structure consists of a backbone, multiple subnetworks, and Non-Maximum Suppression (NMS), where NMS is a post-processing algorithm. RetinaNet is implemented in many training frameworks, and Detectron2 is a typical one. You can call the scripting_with_instances method of Detectron2 to export a RetinaNet model and then use PAI-Blade to optimize it. For more information, see Use PAI-Blade to optimize a RetinaNet model that is in the Detectron2 framework.
Most PyTorch users export models in the Open Neural Network Exchange (ONNX) format and then deploy them by using TensorRT. However, both the ONNX exporter and TensorRT support only a subset of ONNX opsets. As a result, the workflow of exporting an ONNX model and optimizing it with TensorRT lacks robustness in many cases. In particular, the post-processing network of a detection model usually cannot be directly exported to ONNX and optimized by TensorRT. In addition, in real-world scenarios the post-processing network of a detection model is often implemented inefficiently. Therefore, many users use TensorRT plug-ins to build the post-processing network of a detection model so that the model can be exported to TensorRT.
You can also use PAI-Blade together with TorchScript custom C++ operators to optimize a model, which is easier than building a post-processing network with TensorRT plug-ins. However, PAI-Blade is highly extensible: if you have already developed your own TensorRT plug-ins, you can use PAI-Blade and the plug-ins for collaborative model optimization.
The environment used for the procedure in this topic must meet the following version requirements:
To use PAI-Blade and TensorRT plug-ins to optimize a RetinaNet model, perform the following steps:
Step 1: Create a PyTorch model by using TensorRT plug-ins
Use TensorRT plug-ins to build a post-processing network for the RetinaNet model.
Step 2: Use PAI-Blade to optimize the model
Call the blade.optimize method to optimize the model, and save the optimized model.
Step 3: Load and run the optimized model
If the optimized model passes the performance testing and meets your expectations, load the optimized model for inference.
PAI-Blade can collaborate with TensorRT plug-ins for model optimization. This step describes how to use TensorRT plug-ins to build a post-processing network for the RetinaNet model. For more information about how to develop and compile TensorRT plug-ins, see NVIDIA Deep Learning TensorRT Documentation. In this article, the program logic for the post-processing network of the RetinaNet model comes from the open source community of NVIDIA. For more information, see retinanet-examples. The core code is used in this example to show you how to develop and implement custom operators.
1. Download the sample code and decompress the downloaded package.
wget -nv https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/tutorials/retinanet_example/retinanet-examples.tar.gz -O retinanet-examples.tar.gz
tar xvfz retinanet-examples.tar.gz 1>/dev/null
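As a quick check (not part of the original sample), you can confirm that the archive was extracted and contains the plug-in source files that are compiled in the next step:
import os
codebase = "retinanet-examples"
# These are the plug-in sources referenced by the JIT compilation step below.
for src in ['csrc/plugins/plugin.cpp', 'csrc/cuda/decode.cu', 'csrc/cuda/nms.cu']:
    assert os.path.exists(os.path.join(codebase, src)), src + " is missing"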
2. Compile TensorRT plug-ins.
The sample code contains the source files that implement and register the decode and nms TensorRT plug-ins for the post-processing network of the RetinaNet model. PyTorch provides three methods to compile custom operators: building with CMake, building with JIT compilation, and building with setuptools. For more information, see EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS. These three compilation methods are suitable for different scenarios. You can select a method based on your needs. In this example, JIT compilation is used to simplify operations. The following sample code provides an example:
Note: Before compilation, you must configure dependency libraries such as TensorRT, CUDA, and cuDNN.
import torch.utils.cpp_extension
import os

codebase="retinanet-examples"
sources=['csrc/plugins/plugin.cpp',
         'csrc/cuda/decode.cu',
         'csrc/cuda/nms.cu',]
sources = [os.path.join(codebase, src) for src in sources]
torch.utils.cpp_extension.load(
    name="plugin",
    sources=sources,
    build_directory=codebase,
    extra_include_paths=['/usr/local/TensorRT/include/', '/usr/local/cuda/include/', '/usr/local/cuda/include/thrust/system/cuda/detail'],
    extra_cflags=['-std=c++14', '-O2', '-Wall'],
    extra_ldflags=['-L/usr/local/TensorRT/lib/', '-lnvinfer'],
    extra_cuda_cflags=[
        '-std=c++14', '--expt-extended-lambda',
        '--use_fast_math', '-Xcompiler', '-Wall,-fno-gnu-unique',
        '-gencode=arch=compute_75,code=sm_75',],
    is_python_module=False,
    with_cuda=True,
    verbose=False,
)
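A quick sanity check (not part of the original sample) can confirm that the JIT build produced the plugin.so library that the later steps load:
import os
import ctypes

# Confirm that JIT compilation produced the shared library used in later steps.
plugin_path = os.path.join(codebase, 'plugin.so')
assert os.path.exists(plugin_path), "JIT compilation did not produce plugin.so"
# Loading the library here also surfaces missing runtime dependencies, such as libnvinfer, early.
ctypes.cdll.LoadLibrary(plugin_path)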
3. Encapsulate the convolution parts of the RetinaNet model.
Encapsulate the convolution parts of the RetinaNet model into a RetinaNetBackboneAndHeads object.
import torch
from typing import List
from torch import Tensor
from torch.testing import assert_allclose
from detectron2 import model_zoo

# This class encapsulates the backbone and the classification and box regression heads of the RetinaNet model.
class RetinaNetBackboneAndHeads(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def preprocess(self, img):
        batched_inputs = [{"image": img}]
        images = self.model.preprocess_image(batched_inputs)
        return images.tensor

    def forward(self, images):
        features = self.model.backbone(images)
        features = [features[f] for f in self.model.head_in_features]
        cls_heads, box_heads = self.model.head(features)
        cls_heads = [cls.sigmoid() for cls in cls_heads]
        box_heads = [b.contiguous() for b in box_heads]
        return cls_heads, box_heads

retinanet_model = model_zoo.get("COCO-Detection/retinanet_R_50_FPN_3x.yaml", trained=True).eval()
retinanet_bacbone_heads = RetinaNetBackboneAndHeads(retinanet_model)
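Before you build the TensorRT post-processing network, it can help to run the encapsulated model once and inspect the shapes of the classification and box outputs, because these shapes determine the network inputs created in the next step. The following sketch uses a dummy image and is not part of the original sample:
# Run the backbone and heads once with a dummy image to inspect the head outputs.
dummy_img = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
dummy_inputs = retinanet_bacbone_heads.preprocess(dummy_img)
with torch.no_grad():
    cls_heads, box_heads = retinanet_bacbone_heads(dummy_inputs)
for cls_head, box_head in zip(cls_heads, box_heads):
    # One pair of (classification, box regression) outputs per feature pyramid level.
    print(cls_head.shape, box_head.shape)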
4. Use the developed TensorRT plug-ins to build a post-processing network for the RetinaNet model. If you have created a TensorRT engine, skip this step.
a) Create a TensorRT engine.
To make the TensorRT plug-ins take effect, you must perform the following operations:
- Call the ctypes.cdll.LoadLibrary method to dynamically load the compiled plugin.so library.
- Use the TensorRT Python API to implement the build_retinanet_decode method, which builds a post-processing network and builds it into a TensorRT engine.
The following sample code provides an example:
import os
import numpy as np
import tensorrt as trt
import ctypes

# Load the plugin.so library.
codebase="retinanet-examples"
ctypes.cdll.LoadLibrary(os.path.join(codebase, 'plugin.so'))

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list

# Obtain the developed TensorRT plug-ins.
def get_trt_plugin(plugin_name, field_collection):
    plugin = None
    for plugin_creator in PLUGIN_CREATORS:
        if plugin_creator.name != plugin_name:
            continue
        if plugin_name == "RetinaNetDecode":
            plugin = plugin_creator.create_plugin(
                name=plugin_name, field_collection=field_collection
            )
        if plugin_name == "RetinaNetNMS":
            plugin = plugin_creator.create_plugin(
                name=plugin_name, field_collection=field_collection
            )
    assert plugin is not None, "plugin not found"
    return plugin

# Build a post-processing network and build it into a TensorRT engine.
def build_retinanet_decode(example_outputs,
                           input_image_shape,
                           anchors_list,
                           test_score_thresh = 0.05,
                           test_nms_thresh = 0.5,
                           test_topk_candidates = 1000,
                           max_detections_per_image = 100,
                           ):
    builder = trt.Builder(TRT_LOGGER)
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)
    config = builder.create_builder_config()
    config.max_workspace_size = 3 ** 20
    cls_heads, box_heads = example_outputs
    profile = builder.create_optimization_profile()
    decode_scores = []
    decode_boxes = []
    decode_class = []

    input_blob_names = []
    input_blob_types = []
    def _add_input(head_tensor, head_name):
        input_blob_names.append(head_name)
        input_blob_types.append("Float")
        head_shape = list(head_tensor.shape)[-3:]
        profile.set_shape(
            head_name, [1] + head_shape, [20] + head_shape, [1000] + head_shape)
        return network.add_input(
            name=head_name, dtype=trt.float32, shape=[-1] + head_shape
        )

    # Build network inputs.
    cls_head_inputs = []
    cls_head_strides = [input_image_shape[-1] // cls_head.shape[-1] for cls_head in cls_heads]
    for idx, cls_head in enumerate(cls_heads):
        cls_head_name = "cls_head" + str(idx)
        cls_head_inputs.append(_add_input(cls_head, cls_head_name))
    box_head_inputs = []
    for idx, box_head in enumerate(box_heads):
        box_head_name = "box_head" + str(idx)
        box_head_inputs.append(_add_input(box_head, box_head_name))

    output_blob_names = []
    output_blob_types = []
    # Build decode network.
    for idx, anchors in enumerate(anchors_list):
        field_coll = trt.PluginFieldCollection([
            trt.PluginField("topk_candidates", np.array([test_topk_candidates], dtype=np.int32), trt.PluginFieldType.INT32),
            trt.PluginField("score_thresh", np.array([test_score_thresh], dtype=np.float32), trt.PluginFieldType.FLOAT32),
            trt.PluginField("stride", np.array([cls_head_strides[idx]], dtype=np.int32), trt.PluginFieldType.INT32),
            trt.PluginField("num_anchors", np.array([anchors.numel()], dtype=np.int32), trt.PluginFieldType.INT32),
            trt.PluginField("anchors", anchors.contiguous().cpu().numpy().astype(np.float32), trt.PluginFieldType.FLOAT32),]
        )
        decode_layer = network.add_plugin_v2(
            inputs=[cls_head_inputs[idx], box_head_inputs[idx]],
            plugin=get_trt_plugin("RetinaNetDecode", field_coll),
        )
        decode_scores.append(decode_layer.get_output(0))
        decode_boxes.append(decode_layer.get_output(1))
        decode_class.append(decode_layer.get_output(2))

    # Build NMS network.
    scores_layer = network.add_concatenation(decode_scores)
    boxes_layer = network.add_concatenation(decode_boxes)
    class_layer = network.add_concatenation(decode_class)
    field_coll = trt.PluginFieldCollection([
        trt.PluginField("nms_thresh", np.array([test_nms_thresh], dtype=np.float32), trt.PluginFieldType.FLOAT32),
        trt.PluginField("max_detections_per_image", np.array([max_detections_per_image], dtype=np.int32), trt.PluginFieldType.INT32),]
    )
    nms_layer = network.add_plugin_v2(
        inputs=[scores_layer.get_output(0), boxes_layer.get_output(0), class_layer.get_output(0)],
        plugin=get_trt_plugin("RetinaNetNMS", field_coll),
    )
    nms_layer.get_output(0).name = "scores"
    nms_layer.get_output(1).name = "boxes"
    nms_layer.get_output(2).name = "classes"
    nms_outputs = [network.mark_output(nms_layer.get_output(k)) for k in range(3)]
    config.add_optimization_profile(profile)
    cuda_engine = builder.build_engine(network, config)
    assert cuda_engine is not None
    return cuda_engine
b) Build the TensorRT engine based on the number of outputs, output types, and output shapes of the RetinaNetBackboneAndHeads object.
import numpy as np
from detectron2.data.detection_utils import read_image
# wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
img = read_image('./input.jpg')
img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
example_inputs = retinanet_bacbone_heads.preprocess(img)
example_outputs = retinanet_bacbone_heads(example_inputs)
cell_anchors = [c.contiguous() for c in retinanet_model.anchor_generator.cell_anchors]
cuda_engine = build_retinanet_decode(
    example_outputs, example_inputs.shape, cell_anchors)
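If you want to verify the engine before wiring it into the model, you can list its bindings. This optional check (not part of the original sample) uses the same pre-8.5 TensorRT binding API as the rest of this example and should show one input per head tensor plus the scores, boxes, and classes outputs:
# Optional: list the engine bindings to confirm the expected inputs and outputs.
for idx in range(cuda_engine.num_bindings):
    kind = "input" if cuda_engine.binding_is_input(idx) else "output"
    print(idx, kind, cuda_engine.get_binding_name(idx), cuda_engine.get_binding_shape(idx))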
5. Reassemble the RetinaNet model so that the model can use both PyTorch and the TensorRT engine.
The following sample code provides an example on how to reassemble the backbone and heads parts and the post-processing network of the RetinaNet model by using the RetinaNetWrapper, RetinaNetBackboneAndHeads, and RetinaNetPostProcess classes.
import blade.torch

# Reassemble the post-processing network that is built by using the TensorRT plug-ins.
class RetinaNetPostProcess(torch.nn.Module):
    def __init__(self, cuda_engine):
        super().__init__()
        blob_names = [cuda_engine.get_binding_name(idx) for idx in range(cuda_engine.num_bindings)]
        input_blob_names = blob_names[:-3]
        input_blob_types = ["Float"] * len(input_blob_names)
        output_blob_names = blob_names[-3:]
        output_blob_types = ["Float"] * len(output_blob_names)
        self.trt_ext_plugin = torch.classes.torch_addons.TRTEngineExtension(
            bytes(cuda_engine.serialize()),
            (input_blob_names, output_blob_names, input_blob_types, output_blob_types),
        )

    def forward(self, inputs: List[Tensor]):
        return self.trt_ext_plugin.forward(inputs)

# Reassemble the RetinaNet model to use both PyTorch and the TensorRT engine.
class RetinaNetWrapper(torch.nn.Module):
    def __init__(self, model, trt_postproc):
        super().__init__()
        self.backbone_and_heads = model
        self.trt_postproc = torch.jit.script(trt_postproc)

    def forward(self, images):
        cls_heads, box_heads = self.backbone_and_heads(images)
        return self.trt_postproc(cls_heads + box_heads)

trt_postproc = RetinaNetPostProcess(cuda_engine)
retinanet_mix_trt = RetinaNetWrapper(retinanet_bacbone_heads, trt_postproc)

# You can export and save the reassembled model as a TorchScript model.
retinanet_script = torch.jit.trace(retinanet_mix_trt, (example_inputs, ), check_trace=False)
torch.jit.save(retinanet_script, 'retinanet_script.pt')
torch.save(example_inputs, 'example_inputs.pth')

outputs = retinanet_script(example_inputs)
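As a quick check (not part of the original sample), you can inspect the outputs returned by the traced model. Assuming the TRTEngineExtension returns the engine outputs in the scores, boxes, classes order in which they were marked, each tensor holds up to max_detections_per_image detections per image:
# Inspect the NMS plug-in outputs returned by the traced model.
scores, boxes, classes = outputs
print(scores.shape, boxes.shape, classes.shape)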
The reassembled torch.nn.Module object has the following characteristics:
- The post-processing network is implemented by the torch.classes.torch_addons.TRTEngineExtension class, which is based on the TensorRT plug-ins.
- The torch.jit.trace method is used to export the model.
1. Call the blade.optimize method of PAI-Blade.
Call the blade.optimize method to optimize the model. The following sample code provides an example. For more information about the blade.optimize method, see Optimize a PyTorch model.
import blade
import blade.torch
import ctypes
import torch
import os

codebase="retinanet-examples"
ctypes.cdll.LoadLibrary(os.path.join(codebase, 'plugin.so'))

blade_config = blade.Config()
blade_config.gpu_config.disable_fp16_accuracy_check = True

script_model = torch.jit.load('retinanet_script.pt')
example_inputs = torch.load('example_inputs.pth')
test_data = [(example_inputs,)]  # The test data used for a PyTorch model is a list of tuples of tensors.

with blade_config:
    optimized_model, opt_spec, report = blade.optimize(
        script_model,  # The TorchScript model exported in the previous step.
        'o1',  # The optimization level of PAI-Blade. In this example, the optimization level is o1.
        device_type='gpu',  # The type of the device on which the model is run. In this example, the device is GPU.
        test_data=test_data,  # The given set of test data, which facilitates optimization and testing.
    )
2. Display the optimization report and save the optimized model.
The model optimized by using PAI-Blade is still a TorchScript model. After the optimization is complete, you can run the following code to display the optimization report and save the optimized model:
# Display the optimization report.
print("Report: {}".format(report))
# Save the optimized model.
torch.jit.save(optimized_model, 'optimized.pt')
The following sample code provides an example of the optimization report. For more information about the parameters in the report, see Optimization report.
Report: {
  "software_context": [
    {
      "software": "pytorch",
      "version": "1.8.1+cu102"
    },
    {
      "software": "cuda",
      "version": "10.2.0"
    }
  ],
  "hardware_context": {
    "device_type": "gpu",
    "microarchitecture": "T4"
  },
  "user_config": "",
  "diagnosis": {
    "model": "unnamed.pt",
    "test_data_source": "user provided",
    "shape_variation": "undefined",
    "message": "Unable to deduce model inputs information (data type, shape, value range, etc.)",
    "test_data_info": "0 shape: (1, 3, 480, 640) data type: float32"
  },
  "optimizations": [
    {
      "name": "PtTrtPassFp16",
      "status": "effective",
      "speedup": "4.37",
      "pre_run": "40.59 ms",
      "post_run": "9.28 ms"
    }
  ],
  "overall": {
    "baseline": "40.02 ms",
    "optimized": "9.27 ms",
    "speedup": "4.32"
  },
  "model_info": {
    "input_format": "torch_script"
  },
  "compatibility_list": [
    {
      "device_type": "gpu",
      "microarchitecture": "T4"
    }
  ],
  "model_sdk": {}
}
3. Test the performance of the original model and the optimized model.
The following sample code provides an example on how to test the performance of the models:
import time

@torch.no_grad()
def benchmark(model, inp):
    for i in range(100):
        model(inp)
    torch.cuda.synchronize()
    start = time.time()
    for i in range(200):
        model(inp)
    torch.cuda.synchronize()
    elapsed_ms = (time.time() - start) * 1000
    print("Latency: {:.2f}".format(elapsed_ms / 200))

# Measure the speed of the original model.
benchmark(script_model, example_inputs)
# Measure the speed of the optimized model.
benchmark(optimized_model, example_inputs)
The following results of this performance testing are for your reference:
Latency: 40.71
Latency: 9.35
The preceding results show that after each model is run 200 times, the average latency of the original model is 40.71 ms and that of the optimized model is 9.35 ms.
1. Optional: During the trial period, add the following environment variable setting to prevent the program from quitting unexpectedly due to an authentication failure:
export BLADE_AUTH_USE_COUNTING=1
2. Get authenticated to use PAI-Blade.
export BLADE_REGION=<region>
export BLADE_TOKEN=<token>
Configure the following parameters based on your business requirements:
- <region>: the region where you use PAI-Blade. You can join the DingTalk group of PAI-Blade users to obtain the regions where PAI-Blade can be used.
- <token>: the authentication token that is required to use PAI-Blade. You can join the DingTalk group of PAI-Blade users to obtain the authentication token.
3. Load and run the optimized model.
The model optimized by using PAI-Blade is still a TorchScript model. Therefore, you can load the optimized model without changing the environment.
import blade.runtime.torch
import torch
from torch.testing import assert_allclose
import ctypes
import os

codebase="retinanet-examples"
ctypes.cdll.LoadLibrary(os.path.join(codebase, 'plugin.so'))

optimized_model = torch.jit.load('optimized.pt')
example_inputs = torch.load('example_inputs.pth')

with torch.no_grad():
    pred = optimized_model(example_inputs)
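Optionally, you can compare the optimized model against the TorchScript model that was saved in Step 1 to confirm that the FP16 optimization stays within an acceptable tolerance. This is a minimal sketch that assumes retinanet_script.pt is still in the working directory; the tolerances are illustrative only:
# Optional: compare the optimized model with the unoptimized TorchScript model.
script_model = torch.jit.load('retinanet_script.pt')
with torch.no_grad():
    ref = script_model(example_inputs)
for p, r in zip(pred, ref):
    assert_allclose(p.cpu(), r.cpu(), rtol=1e-2, atol=1e-2)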