部署 TensorFlow 模型#
本 README 展示了如何在 Triton 推理服务器上部署一个简单的 ResNet 模型。
步骤 1:导出模型#
将 TensorFlow 模型导出为保存的模型。
# <xx.xx> is the yy:mm for the publishing tag for NVIDIA's Tensorflow
# container; eg. 22.04
docker run -it --gpus all -v ${PWD}:/workspace nvcr.io/nvidia/tensorflow:<xx.xx>-tf2-py3
python export.py
步骤 2:设置 Triton 推理服务器#
要使用 Triton,我们需要构建一个模型仓库。仓库的结构如下
model_repository
|
+-- resnet50
|
+-- config.pbtxt
+-- 1
|
+-- model.savedmodel
|
+-- saved_model.pb
+-- variables
|
+-- variables.data-00000-of-00001
+-- variables.index
此演示附带了一个模型示例配置,文件名为 config.pbtxt
。如果您是 Triton 新手,强烈建议查看概念指南的第 1 部分。
docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:<xx.yy>-py3 tritonserver --model-repository=/models --backend-config=tensorflow,version=2
步骤 3:使用 Triton 客户端查询服务器#
安装依赖项并下载示例图像以测试推理。
docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk bash
pip install --upgrade tensorflow
pip install image
wget -O img1.jpg "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"
构建客户端需要三个基本要点。首先,我们建立与 Triton 推理服务器的连接。
triton_client = httpclient.InferenceServerClient(url="localhost:8000")
其次,我们指定模型的输入和输出层名称。
inputs = httpclient.InferInput("input_1", transformed_img.shape, datatype="FP32")
inputs.set_data_from_numpy(transformed_img, binary_data=True)
output = httpclient.InferRequestedOutput("predictions", binary_data=True, class_count=1000)
最后,我们向 Triton 推理服务器发送推理请求。
# Querying the server
results = triton_client.infer(model_name="resnet50", inputs=[inputs], outputs=[output])
predictions = results.as_numpy('predictions')
print(predictions)
相同的输出应如下所示
[b'0.301167:90' b'0.169790:14' b'0.161309:92' b'0.093105:94'
b'0.058743:136' b'0.050185:11' b'0.033802:91' b'0.011760:88'
b'0.008309:989' b'0.004927:95' b'0.004905:13' b'0.004095:317'
b'0.004006:96' b'0.003694:12' b'0.003526:42' b'0.003390:313'
...
b'0.000001:751' b'0.000001:685' b'0.000001:408' b'0.000001:116'
b'0.000001:627' b'0.000001:933' b'0.000000:661' b'0.000000:148']
此处的输出格式为 <confidence_score>:<classification_index>
。要了解如何将这些映射到标签名称等,请参阅我们的文档。上面的客户端代码位于 client.py
中。