为了使TensorRT Plugin生效，需要实现以下功能：
- 通过`ctypes.cdll.LoadLibrary`动态加载编译好的`plugin.so`。
- 在`build_retinanet_decode`中通过TensorRT Python API构建后处理网络，并将其构建（Build）为Engine。

示例代码如下。
import os
import numpy as np
import tensorrt as trt
import ctypes
# Load the compiled TensorRT plugin shared library so that the plugins it
# defines become available, initialise TensorRT's bundled plugins, and take a
# snapshot of every plugin creator currently in the global registry.
codebase = "retinanet-examples"
ctypes.cdll.LoadLibrary(os.path.join(codebase, "plugin.so"))

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, "")

# Read once at import time; get_trt_plugin() searches this list by name.
PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
# Instantiate a RetinaNet TensorRT plugin from the registered plugin creators.
def get_trt_plugin(plugin_name, field_collection):
    """Create a plugin instance for ``plugin_name`` from ``PLUGIN_CREATORS``.

    Every registered creator whose ``name`` equals ``plugin_name`` is asked to
    create a plugin; when several creators match, the instance produced by the
    last one is returned. Only ``"RetinaNetDecode"`` and ``"RetinaNetNMS"`` are
    handled — any other name never produces a plugin.

    Args:
        plugin_name: Creator name to look up; also used as the plugin
            instance name passed to ``create_plugin``.
        field_collection: Plugin fields forwarded unchanged to
            ``create_plugin``.

    Returns:
        The plugin object returned by the matching creator.

    Raises:
        AssertionError: If no plugin was created (unsupported name, no
            registered creator with that name, or the creator returned
            ``None``). This is an ``assert``, so it is skipped under
            ``python -O``.
    """
    supported_names = ("RetinaNetDecode", "RetinaNetNMS")
    result = None
    for creator in PLUGIN_CREATORS:
        wanted = creator.name == plugin_name and plugin_name in supported_names
        if wanted:
            result = creator.create_plugin(
                name=plugin_name, field_collection=field_collection
            )
    assert result is not None, "plugin not found"
    return result
# Build the TensorRT post-processing network (per-level decode + NMS) and
# compile it into an engine.
def build_retinanet_decode(example_outputs,
                           input_image_shape,
                           anchors_list,
                           test_score_thresh=0.05,
                           test_nms_thresh=0.5,
                           test_topk_candidates=1000,
                           max_detections_per_image=100,
                           max_workspace_size=3 ** 20,
                           ):
    """Build a TensorRT engine running RetinaNet decoding followed by NMS.

    For every feature level ``i`` the network declares two float32 inputs with
    a dynamic batch dimension, ``cls_head{i}`` and ``box_head{i}``, and runs a
    ``RetinaNetDecode`` plugin on them. The per-level decode outputs are
    concatenated and fed to a single ``RetinaNetNMS`` plugin, whose three
    outputs become the engine outputs ``scores``, ``boxes`` and ``classes``.

    Args:
        example_outputs: Pair ``(cls_heads, box_heads)`` of per-level head
            tensors. Only their trailing three dimensions are used, to
            declare the network input shapes.
        input_image_shape: Shape of the model's input image. Its last element
            divided by a cls head's last dimension gives that level's stride.
        anchors_list: One anchor set per feature level. Each entry must
            support ``numel()``, ``contiguous()``, ``cpu()`` and ``numpy()``
            (presumably torch tensors — confirm against the caller).
        test_score_thresh: ``score_thresh`` field of every decode plugin.
        test_nms_thresh: ``nms_thresh`` field of the NMS plugin.
        test_topk_candidates: ``topk_candidates`` field of every decode
            plugin.
        max_detections_per_image: ``max_detections_per_image`` field of the
            NMS plugin.
        max_workspace_size: Builder workspace size in bytes. Defaults to the
            previously hard-coded ``3 ** 20`` (about 3.25 GiB).

    Returns:
        The engine returned by ``builder.build_engine``.

    Raises:
        ValueError: If there are no feature levels, or the numbers of cls
            heads, box heads and anchor sets differ.
        RuntimeError: If TensorRT fails to build the engine.
    """
    cls_heads, box_heads = example_outputs
    cls_heads = list(cls_heads)
    box_heads = list(box_heads)
    anchors_list = list(anchors_list)
    # Every level needs exactly one cls head, one box head and one anchor set.
    # Fail early, instead of hitting an IndexError mid-build or leaving
    # network inputs that nothing consumes.
    num_levels = len(anchors_list)
    if num_levels == 0 or not len(cls_heads) == len(box_heads) == num_levels:
        raise ValueError(
            "expected at least one feature level with one cls head, one box "
            f"head and one anchor set per level; got {len(cls_heads)} cls "
            f"heads, {len(box_heads)} box heads and {num_levels} anchor sets"
        )

    builder = trt.Builder(TRT_LOGGER)
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit_batch)
    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size
    profile = builder.create_optimization_profile()

    def _add_input(head_tensor, head_name):
        # Declare a float32 input whose batch dimension is dynamic (-1). The
        # trailing three dims are copied from the example head tensor.
        head_shape = list(head_tensor.shape)[-3:]
        # Optimization-profile batch range: min 1, opt 20, max 1000.
        profile.set_shape(
            head_name, [1] + head_shape, [20] + head_shape, [1000] + head_shape)
        return network.add_input(
            name=head_name, dtype=trt.float32, shape=[-1] + head_shape
        )

    # Network inputs: all cls heads first, then all box heads. The original
    # declaration order is preserved, since it presumably fixes the engine's
    # input binding order.
    cls_head_strides = [
        input_image_shape[-1] // cls_head.shape[-1] for cls_head in cls_heads
    ]
    cls_head_inputs = [
        _add_input(cls_head, "cls_head" + str(idx))
        for idx, cls_head in enumerate(cls_heads)
    ]
    box_head_inputs = [
        _add_input(box_head, "box_head" + str(idx))
        for idx, box_head in enumerate(box_heads)
    ]

    # Decode: one RetinaNetDecode plugin per feature level.
    decode_scores = []
    decode_boxes = []
    decode_classes = []
    for cls_input, box_input, stride, anchors in zip(
            cls_head_inputs, box_head_inputs, cls_head_strides, anchors_list):
        decode_fields = trt.PluginFieldCollection([
            trt.PluginField("topk_candidates", np.array([test_topk_candidates], dtype=np.int32), trt.PluginFieldType.INT32),
            trt.PluginField("score_thresh", np.array([test_score_thresh], dtype=np.float32), trt.PluginFieldType.FLOAT32),
            trt.PluginField("stride", np.array([stride], dtype=np.int32), trt.PluginFieldType.INT32),
            # NOTE(review): numel() is the total element count of the anchors
            # object, not necessarily the number of anchor boxes. Confirm this
            # is what the RetinaNetDecode creator expects for "num_anchors".
            trt.PluginField("num_anchors", np.array([anchors.numel()], dtype=np.int32), trt.PluginFieldType.INT32),
            trt.PluginField("anchors", anchors.contiguous().cpu().numpy().astype(np.float32), trt.PluginFieldType.FLOAT32),
        ])
        decode_layer = network.add_plugin_v2(
            inputs=[cls_input, box_input],
            plugin=get_trt_plugin("RetinaNetDecode", decode_fields),
        )
        decode_scores.append(decode_layer.get_output(0))
        decode_boxes.append(decode_layer.get_output(1))
        decode_classes.append(decode_layer.get_output(2))

    # NMS: concatenate the per-level decode outputs and run one NMS plugin.
    scores_layer = network.add_concatenation(decode_scores)
    boxes_layer = network.add_concatenation(decode_boxes)
    classes_layer = network.add_concatenation(decode_classes)
    nms_fields = trt.PluginFieldCollection([
        trt.PluginField("nms_thresh", np.array([test_nms_thresh], dtype=np.float32), trt.PluginFieldType.FLOAT32),
        trt.PluginField("max_detections_per_image", np.array([max_detections_per_image], dtype=np.int32), trt.PluginFieldType.INT32),
    ])
    nms_layer = network.add_plugin_v2(
        inputs=[
            scores_layer.get_output(0),
            boxes_layer.get_output(0),
            classes_layer.get_output(0),
        ],
        plugin=get_trt_plugin("RetinaNetNMS", nms_fields),
    )

    # Name and mark the engine outputs in a fixed order: scores, boxes, classes.
    for output_index, output_name in enumerate(("scores", "boxes", "classes")):
        output_tensor = nms_layer.get_output(output_index)
        output_tensor.name = output_name
        network.mark_output(output_tensor)

    config.add_optimization_profile(profile)
    cuda_engine = builder.build_engine(network, config)
    # Explicit check: an `assert` would be stripped under `python -O` and
    # silently return None.
    if cuda_engine is None:
        raise RuntimeError(
            "TensorRT failed to build the RetinaNet decode/NMS engine")
    return cuda_engine