Llama Stack API (Experimental)

Warning

Support for the Llama Stack API in NIM is experimental!

The Llama Stack API is a comprehensive set of interfaces developed by Meta for ML developers building applications on top of Llama foundation models. The API aims to standardize interaction with Llama models, simplify the developer experience, and foster innovation across the Llama ecosystem. Llama Stack spans components of the model lifecycle, including inference, fine-tuning, evaluation, and synthetic data generation.

With the Llama Stack API, developers can easily integrate Llama models into their applications, take advantage of tool-calling capabilities, and build sophisticated AI systems. This document outlines how to use the Python bindings for the Llama Stack API, focusing on chat completion and tool use.

For complete API documentation and source code, visit the Llama Stack GitHub repository.

Installation

To get started with the Llama Stack API, you need to install the necessary packages. You can do this with pip:

pip install llama-toolchain llama-models llama-agentic-system

These packages provide the core functionality for working with the Llama Stack API.
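
To verify the installation, a quick check with the Python standard library works (a minimal sketch; the distribution names match the pip command above):

from importlib.metadata import version

# Print the installed version of each Llama Stack package.
for pkg in ("llama-toolchain", "llama-models", "llama-agentic-system"):
    print(pkg, version(pkg))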

Common Components

The following examples keep shared components in a file named inference.py. This file contains the InferenceClient class and utility functions that are reused across the examples. Here are the contents of inference.py:

import json
from typing import Union, Generator
import requests
from llama_toolchain.inference.api import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk
)

class InferenceClient:
    def __init__(self, base_url: str):
        self.base_url = base_url

    def chat_completion(self, request: ChatCompletionRequest) -> Generator[Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk], None, None]:
        url = f"{self.base_url}/inference/chat_completion"
        # Serialize the pydantic request into a plain dict for the JSON body.
        payload = json.loads(request.json())

        response = requests.post(
            url,
            json=payload,
            headers={"Content-Type": "application/json"},
            stream=request.stream
        )

        if response.status_code != 200:
            raise Exception(f"Error: HTTP {response.status_code} {response.text}")

        if request.stream:
            # Streaming responses arrive as server-sent events: each non-empty
            # line is prefixed with "data: " and carries one JSON-encoded chunk.
            for line in response.iter_lines():
                if line:
                    line = line.decode('utf-8')
                    if line.startswith('data: '):
                        data = json.loads(line[6:])
                        yield ChatCompletionResponseStreamChunk(**data)
        else:
            response_data = response.json()
            # Handle potential None values in tool_calls
            if 'completion_message' in response_data and 'tool_calls' in response_data['completion_message']:
                tool_calls = response_data['completion_message']['tool_calls']
                if tool_calls is not None:
                    for tool_call in tool_calls:
                        if 'arguments' in tool_call and tool_call['arguments'] is None:
                            tool_call['arguments'] = ''  # Replace None with empty string
            yield ChatCompletionResponse(**response_data)

def process_chat_completion(response: Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]):
    # Print a complete (non-streaming) response, or one chunk of a streaming response.
    if isinstance(response, ChatCompletionResponse):
        print("Response content:", response.completion_message.content)
        if response.completion_message.tool_calls:
            print("Tool calls:")
            for tool_call in response.completion_message.tool_calls:
                print(f"  Tool: {tool_call.tool_name}")
                print(f"  Arguments: {tool_call.arguments}")
    elif isinstance(response, ChatCompletionResponseStreamChunk):
        print(response.event.delta, end='', flush=True)
        if response.event.stop_reason:
            print(f"\nStop reason: {response.event.stop_reason}")

Basic Usage

The following basic usage example builds on these common components:

from inference import InferenceClient, process_chat_completion
from llama_toolchain.inference.api import ChatCompletionRequest, UserMessage
from llama_models.llama3.api.datatypes import SamplingParams

def chat():
    client = InferenceClient("http://0.0.0.0:8000/experimental/ls")
    
    message = UserMessage(content="Explain the concept of recursion in programming.")
    request = ChatCompletionRequest(
        model="meta/llama-3.1-70b-instruct",
        messages=[message],
        stream=False,
        sampling_params=SamplingParams(
            max_tokens=1024
        )
    )
    
    for response in client.chat_completion(request):
        process_chat_completion(response)

if __name__ == "__main__":
    chat()

Streaming Responses

For streaming responses, use the same structure:

from inference import InferenceClient, process_chat_completion
from llama_toolchain.inference.api import ChatCompletionRequest, UserMessage
from llama_models.llama3.api.datatypes import SamplingParams

def stream_chat():
    client = InferenceClient("http://0.0.0.0:8000/experimental/ls")
    
    message = UserMessage(content="Write a short story about a time-traveling scientist.")
    request = ChatCompletionRequest(
        model="meta/llama-3.1-70b-instruct",
        messages=[message],
        stream=True,
        sampling_params=SamplingParams(
            max_tokens=1024
        )
    )
    
    for response in client.chat_completion(request):
        process_chat_completion(response)

if __name__ == "__main__":
    stream_chat()
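
Under the hood, the streaming branch of InferenceClient consumes server-sent events. Each event line has roughly the following shape (illustrative only; the exact schema is defined by ChatCompletionResponseStreamChunk, but the delta and stop_reason fields are the ones process_chat_completion reads):

data: {"event": {"delta": "Once upon a time", "stop_reason": null}}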

Tool Calling

The Llama Stack API supports tool calling, which allows models to interact with external functions.

Important

Unlike the OpenAI API, the Llama Stack API only supports a tool choice of "auto", "required", or None.
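
The three options look like this (a minimal sketch; ToolChoice is the same enum imported in the full example below):

from llama_models.llama3.api.datatypes import ToolChoice

tool_choice = ToolChoice.auto      # the model decides whether to call a tool
tool_choice = ToolChoice.required  # the model must call one of the provided tools
tool_choice = None                 # no tool-choice preference is sent

The following complete example uses ToolChoice.auto: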

from inference import InferenceClient, process_chat_completion
from llama_toolchain.inference.api import ChatCompletionRequest, UserMessage, ToolDefinition, ToolParamDefinition
from llama_models.llama3.api.datatypes import SamplingParams, ToolChoice

weather_tool = ToolDefinition(
    tool_name="get_current_weather",
    description="Get the current weather for a location",
    parameters={
        "location": ToolParamDefinition(
            param_type="string",
            description="The city and state, e.g. San Francisco, CA",
            required=True
        ),
        "unit": ToolParamDefinition(
            param_type="string",
            description="The temperature unit (celsius or fahrenheit)",
            required=True
        )
    }
)

def tool_calling_example():
    client = InferenceClient("http://0.0.0.0:8000/experimental/ls")
    
    message = UserMessage(content="Get me the weather in New York City, NY.")
    request = ChatCompletionRequest(
        model="meta/llama-3.1-8b-instruct",
        messages=[message],
        tools=[weather_tool],
        tool_choice=ToolChoice.auto,
        sampling_params=SamplingParams(
            max_tokens=200
        )
    )
    
    for response in client.chat_completion(request):
        process_chat_completion(response)

if __name__ == "__main__":
    tool_calling_example()
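
After the model returns a tool call, your application is responsible for executing the corresponding function and feeding the result back to the model in a follow-up request. The sketch below shows one way to dispatch a returned tool call to a local Python function; get_current_weather is a hypothetical stand-in implementation, and the tool_call attributes (tool_name, arguments) are the same ones process_chat_completion reads above:

from llama_toolchain.inference.api import ChatCompletionResponse

# Hypothetical local implementation of the tool declared in weather_tool.
def get_current_weather(location: str, unit: str) -> str:
    # A real implementation would query a weather service here.
    return f"It is 22 degrees {unit} in {location}."

# Map tool names to local callables so responses can be dispatched generically.
TOOL_REGISTRY = {"get_current_weather": get_current_weather}

def execute_tool_calls(response: ChatCompletionResponse) -> None:
    for tool_call in response.completion_message.tool_calls or []:
        handler = TOOL_REGISTRY.get(tool_call.tool_name)
        if handler is None:
            print(f"No local handler for tool: {tool_call.tool_name}")
            continue
        # arguments is a mapping of parameter name to value.
        result = handler(**tool_call.arguments)
        print(f"{tool_call.tool_name} -> {result}")

In a full loop you would append the result to messages as a tool response message and call chat_completion again so the model can produce its final answer; see the Llama Stack repository for the exact message types.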