ほとんどのPyTorchユーザーは、TensorRTプラグインを使用して検出モデルの後処理ネットワークを構築し、モデルをTensorRTにエクスポートできるようにします。 AI (PAI) 向け機械学習プラットフォーム-ブレードは優れたスケーラビリティを備えています。 独自のTensorRTプラグインを開発している場合は、共同モデルの最適化にPAI-BladeプラグインとTensorRTプラグインを使用できます。 このトピックでは、PAI-Bladeを使用して、TensorRTプラグインを使用して後処理ネットワークを構築する検出モデルを最適化する方法について説明します。
背景情報
TensorRTは、NVIDIA GPUの推論最適化のための強力なツールです。 PAI − Bladeは、下層でTensorRTの最適化方法を深く統合する。 さらに、PAI-Bladeは、グラフ最適化、TensorRTやoneDNNなどの最適化ライブラリ、AIコンパイル最適化、最適化演算子ライブラリ、混合精度、EasyCompressionなど、複数の最適化テクノロジーを統合しています。
RetinaNetは、1ステージ領域ベースの畳み込みニューラルネットワーク (R-CNN) タイプの検出ネットワークです。 RetinaNetの基本構造は、バックボーン、複数のサブネットワーク、および非最大抑制 (NMS) で構成されています。 NMSは後処理アルゴリズムである。 RetinaNetは多くのトレーニングフレームワークで実装されています。 Detectron2は、RetinaNetを使用する典型的なトレーニングフレームワークです。 Detectron2のscripting_with_instances
メソッドを呼び出してRetinaNetモデルをエクスポートし、PAI-Bladeを使用してモデルを最適化できます。 詳細については、「PAI-Bladeを使用したDetectron2フレームワークのRetinaNetモデルの最適化」をご参照ください。
ほとんどのPyTorchユーザーは通常、Open Neural Network Exchange (ONNX) 形式でモデルをエクスポートし、TensorRTを使用してモデルをデプロイします。 ただし、ONNXモデルとTensorRTの両方で、ONNXオプセットのサポートは制限されています。 その結果、ONNXモデルをエクスポートし、TensorRTを使用してモデルを最適化するプロセスは、多くの場合、堅牢性に欠けます。 特に、検出モデルの後処理ネットワークをONNXモデルに直接エクスポートしてTensorRTを使用して最適化することはできません。 さらに、コードは、実際のシナリオにおける検出モデルの後処理ネットワークに対して非効率的な方法で実装される。 そのため、多くのユーザーはTensorRTプラグインを使用して検出モデルの後処理ネットワークを構築し、モデルをTensorRTにエクスポートできます。
PAI-BladeおよびTorchScriptのカスタムC ++ 演算子を使用して、モデルを最適化することもできます。 この方法は、TensorRTプラグインを使用して後処理ネットワークを構築する方法よりも使いやすいです。 PAI刃はよいスケーラビリティを特色にします。 独自のTensorRTプラグインを開発している場合は、共同モデルの最適化にPAI-BladeプラグインとTensorRTプラグインを使用できます。
制限事項
このトピックの手順で使用する環境は、次のバージョン要件を満たす必要があります。
システム環境: LinuxのPython 3.6以降、GCC 5.4以降、NVIDIA Tesla T4、CUDA 10.2、cuDNN 8.0.5.39、およびTensorRT 7.2.2.3
フレームワーク: PyTorch 1.8.1以降、およびDetectron2 0.4.1以降
推論最適化ツール: TensorRTをサポートするPAI-Blade V3.16.0以降
手順
PAI-BladeおよびTensorRTプラグインを使用してRetinaNetモデルを最適化するには、次の手順を実行します。
手順1: TensorRTプラグインを使用したPyTorchモデルの作成
TensorRTプラグインを使用して、RetinaNetモデルの後処理ネットワークを構築します。
blade.optimize
メソッドを呼び出してモデルを最適化し、最適化されたモデルを保存します。最適化されたモデルがパフォーマンステストに合格し、期待を満たす場合は、推論のために最適化されたモデルを読み込みます。
手順1: TensorRTプラグインを使用してPyTorchモデルを作成する
PAI-Bladeは、モデル最適化のためにTensorRTプラグインと連携できます。 この手順では、TensorRTプラグインを使用して、RetinaNetモデルの後処理ネットワークを構築する方法について説明します。 TensorRTプラグインを開発およびコンパイルする方法の詳細については、「NVIDIA Deep Learning TensorRTドキュメント」をご参照ください。 このトピックでは、RetinaNetモデルの後処理ネットワークのプログラムロジックは、NVIDIAのオープンソースコミュニティから提供されています。 詳細については、「retinanet-examples」をご参照ください。 この例では、コアコードを使用して、カスタム演算子を開発および実装する方法を示します。
サンプルコードをダウンロードし、ダウンロードしたパッケージを解凍します。
wget -nv https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/tutorials/retinanet_example/retinanet-examples.tar.gz -O retinanet-examples.tar.gz tar xvfz retinanet-examples.tar.gz 1>/dev/null
コンパイルTensorRTプラグイン。
サンプルコードには、RetinaNetモデルの後処理ネットワークのTensorRTプラグイン
decode
およびnms
を実装および登録するために使用できるコードが含まれています。 PyTorchは、カスタム演算子をコンパイルするための3つのメソッドを提供します。CMakeによるビルド、JITコンパイルによるビルド、setuptoolsによるビルドです。 詳細については、「カスタムC ++ オペレーターによるTORCHSCRIPTの拡張」をご参照ください。 これらの3つのコンパイル方法は、さまざまなシナリオに適しています。 必要に応じてメソッドを選択できます。 この例では、操作を簡素化するために、JITコンパイル方法を使用しています。 次のサンプルコードに例を示します。説明コンパイルする前に、TensorRT、CUDA、cuDNNなどの依存関係ライブラリを構成する必要があります。
import torch.utils.cpp_extension import os codebase="retinanet-examples" sources=['csrc/plugins/plugin.cpp', 'csrc/cuda/decode.cu', 'csrc/cuda/nms.cu',] sources = [os.path.join(codebase,src) for src in sources] torch.utils.cpp_extension.load( name="plugin", sources=sources, build_directory=codebase, extra_include_paths=['/usr/local/TensorRT/include/', '/usr/local/cuda/include/', '/usr/local/cuda/include/thrust/system/cuda/detail'], extra_cflags=['-std=c++14', '-O2', '-Wall'], extra_ldflags=['-L/usr/local/TensorRT/lib/', '-lnvinfer'], extra_cuda_cflags=[ '-std=c++14', '--expt-extended-lambda', '--use_fast_math', '-Xcompiler', '-Wall,-fno-gnu-unique', '-gencode=arch=compute_75,code=sm_75',], is_python_module=False, with_cuda=True, verbose=False, )
RetinaNetモデルの畳み込み部分をカプセル化します。
RetinaNetモデルの畳み込み部分を
RetinaNetBackboneAndHeads
オブジェクトにカプセル化します。import torch from typing import List from torch import Tensor from torch.testing import assert_allclose from detectron2 import model_zoo # This class encapsulates the backbone and region proposal network (RPN) heads parts of the RetinaNet model. class RetinaNetBackboneAndHeads(torch.nn.Module): def __init__(self, model): super().__init__() self.model = model def preprocess(self, img): batched_inputs = [{"image": img}] images = self.model.preprocess_image(batched_inputs) return images.tensor def forward(self, images): features = self.model.backbone(images) features = [features[f] for f in self.model.head_in_features] cls_heads, box_heads = self.model.head(features) cls_heads = [cls.sigmoid() for cls in cls_heads] box_heads = [b.contiguous() for b in box_heads] return cls_heads, box_heads retinanet_model = model_zoo.get("COCO-Detection/retinanet_R_50_FPN_3x.yaml", trained=True).eval() retinanet_bacbone_heads = RetinaNetBackboneAndHeads(retinanet_model)
開発したTensorRTプラグインを使用して、RetinaNetモデルの後処理ネットワークを構築します。 TensorRTエンジンを作成した場合は、この手順をスキップしてください。
TensorRTエンジンを作成します。
TensorRTプラグインを有効にするには、次の機能を実装する必要があります。
ctypes.cdll.LoadLibrary
メソッドを呼び出して、コンパイルされたplugin.soライブラリを動的に読み込みます。TensorRT Python APIの
build_retinanet_decode
メソッドを呼び出して、後処理ネットワークを構築し、TensorRTエンジンに組み込みます。
次のサンプルコードでは、例を示します。
import os import numpy as np import tensorrt as trt import ctypes # Load the plugin.so library. codebase="retinanet-examples" ctypes.cdll.LoadLibrary(os.path.join(codebase, 'plugin.so')) TRT_LOGGER = trt.Logger() trt.init_libnvinfer_plugins(TRT_LOGGER, "") PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list # Obtain the developed TensorRT plug-ins. def get_trt_plugin(plugin_name, field_collection): plugin = None for plugin_creator in PLUGIN_CREATORS: if plugin_creator.name != plugin_name: continue if plugin_name == "RetinaNetDecode": plugin = plugin_creator.create_plugin( name=plugin_name, field_collection=field_collection ) if plugin_name == "RetinaNetNMS": plugin = plugin_creator.create_plugin( name=plugin_name, field_collection=field_collection ) assert plugin is not None, "plugin not found" return plugin # Build a post-processing network and build it into a TensorRT engine. def build_retinanet_decode(example_outputs, input_image_shape, anchors_list, test_score_thresh = 0.05, test_nms_thresh = 0.5, test_topk_candidates = 1000, max_detections_per_image = 100, ): builder = trt.Builder(TRT_LOGGER) EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) network = builder.create_network(EXPLICIT_BATCH) config = builder.create_builder_config() config.max_workspace_size = 3 ** 20 cls_heads, box_heads = example_outputs profile = builder.create_optimization_profile() decode_scores = [] decode_boxes = [] decode_class = [] input_blob_names = [] input_blob_types = [] def _add_input(head_tensor, head_name): input_blob_names.append(head_name) input_blob_types.append("Float") head_shape = list(head_tensor.shape)[-3:] profile.set_shape( head_name, [1] + head_shape, [20] + head_shape, [1000] + head_shape) return network.add_input( name=head_name, dtype=trt.float32, shape=[-1] + head_shape ) # Build network inputs. cls_head_inputs = [] cls_head_strides = [input_image_shape[-1] // cls_head.shape[-1] for cls_head in cls_heads] for idx, cls_head in enumerate(cls_heads): cls_head_name = "cls_head" + str(idx) cls_head_inputs.append(_add_input(cls_head, cls_head_name)) box_head_inputs = [] for idx, box_head in enumerate(box_heads): box_head_name = "box_head" + str(idx) box_head_inputs.append(_add_input(box_head, box_head_name)) output_blob_names = [] output_blob_types = [] # Build decode network. for idx, anchors in enumerate(anchors_list): field_coll = trt.PluginFieldCollection([ trt.PluginField("topk_candidates", np.array([test_topk_candidates], dtype=np.int32), trt.PluginFieldType.INT32), trt.PluginField("score_thresh", np.array([test_score_thresh], dtype=np.float32), trt.PluginFieldType.FLOAT32), trt.PluginField("stride", np.array([cls_head_strides[idx]], dtype=np.int32), trt.PluginFieldType.INT32), trt.PluginField("num_anchors", np.array([anchors.numel()], dtype=np.int32), trt.PluginFieldType.INT32), trt.PluginField("anchors", anchors.contiguous().cpu().numpy().astype(np.float32), trt.PluginFieldType.FLOAT32),] ) decode_layer = network.add_plugin_v2( inputs=[cls_head_inputs[idx], box_head_inputs[idx]], plugin=get_trt_plugin("RetinaNetDecode", field_coll), ) decode_scores.append(decode_layer.get_output(0)) decode_boxes.append(decode_layer.get_output(1)) decode_class.append(decode_layer.get_output(2)) # Build NMS network. scores_layer = network.add_concatenation(decode_scores) boxes_layer = network.add_concatenation(decode_boxes) class_layer = network.add_concatenation(decode_class) field_coll = trt.PluginFieldCollection([ trt.PluginField("nms_thresh", np.array([test_nms_thresh], dtype=np.float32), trt.PluginFieldType.FLOAT32), trt.PluginField("max_detections_per_image", np.array([max_detections_per_image], dtype=np.int32), trt.PluginFieldType.INT32),] ) nms_layer = network.add_plugin_v2( inputs=[scores_layer.get_output(0), boxes_layer.get_output(0), class_layer.get_output(0)], plugin=get_trt_plugin("RetinaNetNMS", field_coll), ) nms_layer.get_output(0).name = "scores" nms_layer.get_output(1).name = "boxes" nms_layer.get_output(2).name = "classes" nms_outputs = [network.mark_output(nms_layer.get_output(k)) for k in range(3)] config.add_optimization_profile(profile) cuda_engine = builder.build_engine(network, config) assert cuda_engine is not None return cuda_engine
を返す
RetinaNetBackboneAndHeads
オブジェクトの出力数、出力タイプ、および出力形状に基づいてTensorRTエンジンを構築します。import numpy as np from detectron2.data.detection_utils import read_image # wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg img = read_image('./input.jpg') img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))) example_inputs = retinanet_bacbone_heads.preprocess(img) example_outputs = retinanet_bacbone_heads(example_inputs) cell_anchors = [c.contiguous() for c in retinanet_model.anchor_generator.cell_anchors] cuda_engine = build_retinanet_decode( example_outputs, example_inputs.shape, cell_anchors)
RetinaNetモデルを再構築して、PyTorchとTensorRTエンジンの両方を使用できるようにします。
次のサンプルコードでは、
RetinaNetWrapper
、RetinaNetBackboneAndHeads
、およびRetinaNetPostProcess
クラスを使用して、バックボーンおよびRPNヘッドパーツ、およびRetinaNetモデルの後処理ネットワークを再構築する方法の例を示します。import blade.torch # Reassemble the post-processing network that is built by using the TensorRT plug-ins. class RetinaNetPostProcess(torch.nn.Module): def __init__(self, cuda_engine): super().__init__() blob_names = [cuda_engine.get_binding_name(idx) for idx in range(cuda_engine.num_bindings)] input_blob_names = blob_names[:-3] input_blob_types = ["Float"] * len(input_blob_names) output_blob_names = blob_names[-3:] output_blob_types = ["Float"] * len(output_blob_names) self.trt_ext_plugin = torch.classes.torch_addons.TRTEngineExtension( bytes(cuda_engine.serialize()), (input_blob_names, output_blob_names, input_blob_types, output_blob_types), ) def forward(self, inputs: List[Tensor]): return self.trt_ext_plugin.forward(inputs) # Reassemble the RetinaNet model to use both PyTorch and the TensorRT engine. class RetinaNetWrapper(torch.nn.Module): def __init__(self, model, trt_postproc): super().__init__() self.backbone_and_heads = model self.trt_postproc = torch.jit.script(trt_postproc) def forward(self, images): cls_heads, box_heads = self.backbone_and_heads(images) return self.trt_postproc(cls_heads + box_heads) trt_postproc = RetinaNetPostProcess(cuda_engine) retinanet_mix_trt = RetinaNetWrapper(retinanet_bacbone_heads, trt_postproc) # You can export and save the reassembled model as a TorchScript model. retinanet_script = torch.jit.trace(retinanet_mix_trt, (example_inputs, ), check_trace=False) torch.jit.save(retinanet_script, 'retinanet_script.pt') torch.save(example_inputs, 'example_inputs.pth') outputs = retinanet_script(example_inputs)
再構築された
torch.nn.Module
オブジェクトには、次の特性があります。TensorRTプラグインに基づく
torch.classes.torch_addons.TRTEngineExtension
クラスをサポートします。モデルをTorchScript形式でエクスポートできます。 この例では、
torch.jit.trace
メソッドを使用してモデルをエクスポートします。モデルをTorchScript形式で保存できます。
ステップ2: PAI-Bladeを使用してモデルを最適化する
PAI-bladeのBlade. optimizeメソッドを呼び出します。
モデルを最適化するには、
blade.optimize
メソッドを呼び出します。 次のサンプルコードに例を示します。blade.optimize
メソッドの詳細については、「PyTorchモデルの最適化」をご参照ください。import blade import blade.torch import ctypes import torch import os codebase="retinanet-examples" ctypes.cdll.LoadLibrary(os.path.join(codebase, 'plugin.so')) blade_config = blade.Config() blade_config.gpu_config.disable_fp16_accuracy_check = True script_model = torch.jit.load('retinanet_script.pt') example_inputs = torch.load('example_inputs.pth') test_data = [(example_inputs,)] # The test data used for a PyTorch model is a list of tuples of tensors. with blade_config: optimized_model, opt_spec, report = blade.optimize( script_model, # The TorchScript model exported in the previous step. 'o1', # The optimization level of PAI-Blade. In this example, the optimization level is o1. device_type='gpu', # The type of the device on which the model is run. In this example, the device is GPU. test_data=test_data, # The given set of test data, which facilitates optimization and testing. )
最適化レポートを表示し、最適化モデルを保存します。
PAI-Bladeを使用して最適化されたモデルは、依然としてTorchScriptモデルです。 最適化が完了したら、次のコードを実行して最適化レポートを表示し、最適化モデルを保存します
# Display the optimization report. print("Report: {}".format(report)) # Save the optimized model. torch.jit.save(optimized_model, 'optimized.pt')
最適化レポートの例を次に示します。 レポートのパラメーターの詳細については、「最適化レポート」をご参照ください。
Report: { "software_context": [ { "software": "pytorch", "version": "1.8.1+cu102" }, { "software": "cuda", "version": "10.2.0" } ], "hardware_context": { "device_type": "gpu", "microarchitecture": "T4" }, "user_config": "", "diagnosis": { "model": "unnamed.pt", "test_data_source": "user provided", "shape_variation": "undefined", "message": "Unable to deduce model inputs information (data type, shape, value range, etc.)", "test_data_info": "0 shape: (1, 3, 480, 640) data type: float32" }, "optimizations": [ { "name": "PtTrtPassFp16", "status": "effective", "speedup": "4.37", "pre_run": "40.59 ms", "post_run": "9.28 ms" } ], "overall": { "baseline": "40.02 ms", "optimized": "9.27 ms", "speedup": "4.32" }, "model_info": { "input_format": "torch_script" }, "compatibility_list": [ { "device_type": "gpu", "microarchitecture": "T4" } ], "model_sdk": {} }
元のモデルと最適化されたモデルのパフォーマンスをテストします。
次のサンプルコードは、モデルのパフォーマンスをテストする方法の例を示しています。
import time @torch.no_grad() def benchmark(model, inp): for i in range(100): model(inp) torch.cuda.synchronize() start = time.time() for i in range(200): model(inp) torch.cuda.synchronize() elapsed_ms = (time.time() - start) * 1000 print("Latency: {:.2f}".format(elapsed_ms / 200)) # Measure the speed of the original model. benchmark(script_model, example_inputs) # Measure the speed of the optimized model. benchmark(optimized_model, example_inputs)
このパフォーマンステストの次の結果は参考のためのものです:
Latency: 40.71 Latency: 9.35
上記の結果は、両方のモデルが200回実行された後、元のモデルの平均待ち時間は40.71ミリ秒であり、最適化されたモデルの平均待ち時間は9.35ミリ秒であることを示している。
ステップ3: 最適化されたモデルをロードして実行する
オプション: 試用期間中に、次の環境変数設定を追加して、認証の失敗によるプログラムの予期しない停止を防止します。
export BLADE_AUTH_USE_COUNTING=1
PAI-Bladeを使用するように認証されます。
export BLADE_REGION=<region> export BLADE_TOKEN=<token>
ビジネス要件に基づいて次のパラメーターを設定します。
<region>: PAI-Bladeを使用するリージョンです。 PAI-BladeユーザーのDingTalkグループに参加して、PAI-Bladeを使用できるリージョンを取得できます。
<token>: PAI-Bladeを使用するために必要な認証トークン。 PAI-BladeユーザーのDingTalkグループに参加して、認証トークンを取得できます。
最適化されたモデルをロードして実行します。
PAI-Bladeを使用して最適化されたモデルは、依然としてTorchScriptモデルです。 したがって、環境を変更せずに最適化モデルをロードできます。
import blade.runtime.torch import torch from torch.testing import assert_allclose import ctypes import os codebase="retinanet-examples" ctypes.cdll.LoadLibrary(os.path.join(codebase, 'plugin.so')) optimized_model = torch.jit.load('optimized.pt') example_inputs = torch.load('example_inputs.pth') with torch.no_grad(): pred = optimized_model(example_inputs)