TopicControl 入门#

先决条件#

  • 具有 Docker Engine 的主机。请参阅 Docker 的说明

  • 已安装并配置 NVIDIA Container Toolkit。请参阅工具包文档中的 安装

  • NVIDIA AI Enterprise 产品的有效订阅,或成为 NVIDIA 开发者计划成员。对容器和模型的访问受到限制。

  • NGC API 密钥。容器使用此密钥向 NVIDIA API Catalog 模型发送推理请求。有关更多信息,请参阅 NVIDIA NGC 用户指南 中的 生成您的 NGC API 密钥

    当您创建 NGC API 个人密钥时,请从包含的服务菜单中至少选择 NGC Catalog。您可以指定更多服务以将密钥用于其他目的。

启动 NIM 容器#

  1. 登录到 NVIDIA NGC,以便您可以拉取容器。

    1. 将您的 NGC API 密钥导出为环境变量

      $ export NGC_API_KEY="<nvapi-...>"
      
    2. 登录到注册表

      $ docker login nvcr.io --username '$oauthtoken' --password-stdin <<< $NGC_API_KEY
      
  2. 下载容器

    $ docker pull nvcr.io/nim/nvidia/llama-3.1-nemoguard-8b-topic-control:1.0.0
    
  3. 在主机上创建模型缓存目录

    $ export LOCAL_NIM_CACHE=~/.cache/llama-nemoguard-topiccontrol
    $ mkdir -p "${LOCAL_NIM_CACHE}"
    $ chmod 666 "${LOCAL_NIM_CACHE}"
    
  4. 使用缓存目录作为卷挂载运行容器

    $ docker run -d \
      --name llama-nemoguard-topiccontrol \
      --gpus=all --runtime=nvidia \
      -e NGC_API_KEY \
      -e NIM_SERVED_MODEL_NAME="llama-3.1-nemoguard-8b-topic-control" \
      -e NIM_CUSTOM_MODEL_NAME="llama-3.1-nemoguard-8b-topic-control" \
      -u $(id -u) \
      -v "$LOCAL_NIM_CACHE:/opt/nim/.cache/" \
      -p 8000:8000 \
      nvcr.io/nim/nvidia/llama-3.1-nemoguard-8b-topic-control:1.0.0
    

    容器需要几分钟才能启动并从 NGC 下载模型。您可以通过运行 docker logs llama-nemoguard-topiccontrol 命令来监控进度。

  5. 可选:确认服务已准备好响应推理请求

    $ curl -X GET http://127.0.0.1:8000/v1/health/ready
    

    示例输出

    {"object":"health-response","message":"ready"}
    

运行推理#

您可以向 v1/chat/completions 端点发送请求以执行推理。

以下步骤演示了如何创建一个 Python 脚本来执行以下操作

  • 连接到具有微服务和主题控制模型的容器。

  • 提供一个提示,该提示向主题控制模型提供主题控制指令。

  1. 创建开发环境并安装依赖项

    $ conda create -n evals python=3.10
    $ conda activate evals
    $ pip install requests
    
  2. 创建一个文件,例如 topic_control_inference_example.py,内容如下例所示

    import argparse
    from typing import List, Optional
    import requests
    
    TOPIC_SAFETY_OUTPUT_RESTRICTION = (
        'If any of the above conditions are violated, please respond with "off-topic". '
        'Otherwise, respond with "on-topic". '
        'You must respond with "on-topic" or "off-topic".'
    )
    
    class TopicGuard:
        def __init__(
            self,
            host: str = "0.0.0.0",
            port: str = "8000",
            model_name: str = "llama-3.1-nemoguard-8b-topic-control",
        ):
            self.uri = f"http://{host}:{port}/v1/chat/completions"
            self.model_name = model_name
    
        def __call__(self, prompt: List[dict]) -> str:
            return self._call(prompt)
    
        def _call(self, prompt: List[dict], stop: Optional[List[str]] = None) -> str:
            try:
                response = requests.post(
                    self.uri,
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                    },
                    json={
                        "model": self.model_name,
                        "messages": prompt,
                        "max_tokens": 20,
                        "top_p": 1,
                        "n": 1,
                        "temperature": 0.0,
                        "stream": False,
                        "frequency_penalty": 0.0,
                    },
                )
                if response.status_code != 200:
                    raise Exception(
                        f"Error response from the LLM. Status code: {response.status_code} {response.text}"
                    )
                return response.json()["choices"][0]["message"]["content"].strip()
            except Exception as e:
                print(e)
                return "error"
    
    
    def format_prompt(system_prompt: str, user_message: str) -> str:
    
        system_prompt = system_prompt.strip()
    
        if not system_prompt.endswith(TOPIC_SAFETY_OUTPUT_RESTRICTION):
            system_prompt = f"{system_prompt}\n\n{TOPIC_SAFETY_OUTPUT_RESTRICTION}"
    
        prompt = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]
    
        return prompt
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--nim_host", type=str, default="0.0.0.0")
        parser.add_argument("--nim_port", type=str, default="8000")
        parser.add_argument(
            "--nim_model_name", type=str, default="llama-3.1-nemoguard-8b-topic-control"
        )
        args = parser.parse_args()
    
        system_prompt = """You are to act as an investor relations bot for ABC, providing users with factual, publicly available information related to the company's financial performance and corporate updates. Your role is to ensure that you respond only to relevant queries and adhere to the following guidelines:
    
    1. Do not answer questions about future predictions, such as profit forecasts or future revenue outlook.
    2. Do not provide any form of investment advice, including recommendations to buy, sell, or hold ABC stock or any other securities. Never recommend any stock or investment.
    3. Do not engage in discussions that require personal opinions or subjective judgments. Never make any subjective statements about ABC, its stock or its products.
    4. If a user asks about topics irrelevant to ABC's investor relations or financial performance, politely redirect the conversation or end the interaction.
    5. Your responses should be professional, accurate, and compliant with investor relations guidelines, focusing solely on providing transparent, up-to-date information about ABC that is already publicly available."""
    
        user_message = (
            "Can you speculate on the potential impact of a recession on ABCs business?"
        )
    
        print(
            f"Using Nim inference mode with host: {args.nim_host} and port: {args.nim_port}"
        )
        topic_guard = TopicGuard(
            host=args.nim_host, port=args.nim_port, model_name=args.nim_model_name
        )
    
        prompt = format_prompt(system_prompt, user_message)
        response = topic_guard(prompt)
    
        print(f"For user message: {user_message}")
        print(f"\nResponse from TopicGuard model: {response}")
    
  3. 运行脚本以执行推理

    $ python topic_control_inference_example.py
    

停止容器#

以下命令通过停止并移除正在运行的容器来停止容器。

$ docker stop llama-nemoguard-topiccontrol
$ docker rm llama-nemoguard-topiccontrol