JAX 示例#
在本节中,我们将演示一个端到端示例,说明如何在 Python 后端中使用 JAX。
创建 JAX AddSub 模型仓库#
我们将使用此示例附带的文件来创建模型仓库。
首先,将 client.py、config.pbtxt 和 model.py 下载到您的本地计算机。
接下来,在三个文件所在的目录中,使用以下命令创建模型仓库
mkdir -p models/jax/1
mv model.py models/jax/1
mv config.pbtxt models/jax
拉取 Triton Docker 镜像#
在继续之前,我们需要安装 Docker 和 NVIDIA Container Toolkit,请参阅安装步骤。
要拉取最新的容器,请运行以下命令
docker pull nvcr.io/nvidia/tritonserver:<yy.mm>-py3
docker pull nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk
有关 <yy.mm>
版本,请参阅上面的安装步骤。
在撰写本文时,最新版本是 23.04
,这转化为以下命令
docker pull nvcr.io/nvidia/tritonserver:23.04-py3
docker pull nvcr.io/nvidia/tritonserver:23.04-py3-sdk
请务必将 <yy.mm>
替换为您为本示例所有剩余部分拉取的版本。
启动 Triton 服务器#
在创建 JAX 模型的目录(“models”文件夹所在的目录)中,运行以下命令
docker run --gpus all -it --rm -p 8000:8000 -v `pwd`:/jax nvcr.io/nvidia/tritonserver:<yy.mm>-py3 /bin/bash
在容器内部,我们需要安装 JAX 才能运行此示例。
我们建议使用 JAX 文档中提到的 pip
方法。 确保 JAX 在与其他依赖项相同的 Python 环境中可用。
要为此示例安装,请运行以下命令
pip3 install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
最后,我们需要启动 Triton 服务器,运行以下命令
tritonserver --model-repository=/jax/models
要离开容器进行下一步,请按:CTRL + P + Q
。
测试推理#
在 client.py 所在的目录中,运行以下命令
docker run --rm --net=host -v `pwd`:/jax nvcr.io/nvidia/tritonserver:<yy.mm>-py3-sdk python3 /jax/client.py
成功的推理将在末尾打印以下内容
INPUT0 ([0.89262384 0.645457 0.18913145 0.17099917]) + INPUT1 ([0.5703733 0.21917151 0.22854741 0.97336507]) = OUTPUT0 ([1.4629972 0.86462855 0.41767886 1.1443642 ])
INPUT0 ([0.89262384 0.645457 0.18913145 0.17099917]) - INPUT1 ([0.5703733 0.21917151 0.22854741 0.97336507]) = OUTPUT0 ([ 0.32225055 0.4262855 -0.03941596 -0.8023659 ])
PASS: jax
注意:您的输入可能与上述内容不同,但输出始终与其输入相对应。