采样控制#

NIM for VLMs 公开了一套采样参数,使用户可以精细地控制 VLM 的生成行为。以下是如何在推理请求中配置采样参数的完整参考。

采样参数:OpenAI API#

参数

类型

默认值

注释

presence_penalty

float

0.0

根据新 token 是否出现在已生成的文本中来惩罚新 token。值 > 0 鼓励模型使用新 token,而值 < 0 鼓励模型重复 token。必须在 [-2, 2] 范围内。

frequency_penalty

float

0.0

根据新 token 在已生成的文本中出现的频率来惩罚新 token。值 > 0 鼓励模型使用新 token,而值 < 0 鼓励模型重复 token。必须在 [-2, 2] 范围内。

repetition_penalty

float

1.0

根据新 token 是否出现在提示和已生成的文本中来惩罚新 token。值 > 1 鼓励模型使用新 token,而值 < 1 鼓励模型重复 token。必须在 (0, 2] 范围内。

temperature

float

1.0

控制采样的随机性。较低的值使模型更具确定性,而较高的值使模型更随机。必须 >= 0。设置为 0 以进行贪婪采样。

top_p

float

1.0

控制要考虑的顶部 token 的累积概率。必须在 (0, 1] 范围内。设置为 1 以考虑所有 token。

top_k

int

-1

控制要考虑的顶部 token 的数量。设置为 -1 以考虑所有 token。否则必须 >= 1。

min_p

float

0.0

表示 token 被考虑的最小概率,相对于最有可能的 token 的概率。必须在 [0, 1] 范围内。设置为 0 以禁用 min_p

seed

int

None

用于生成的随机种子。

stop

str 或 List[str]

None

当生成字符串或字符串列表时,停止生成。返回的输出将不包含停止字符串。

ignore_eos

bool

False

是否忽略 EOS token 并在生成 EOS token 后继续生成 token。对于性能基准测试很有用。

max_tokens

int

16

每个输出序列要生成的最大 token 数。必须 >= 1。

min_tokens

int

0

在生成 EOS 或 stop_token_ids 之前,每个输出序列要生成的最小 token 数。必须 >= 0。

logprobs

int

None

每个输出 token 要返回的对数概率数。当设置为 None 时,不返回概率。如果设置为非 None 值,则结果包括指定数量的最有可能 token 的对数概率,以及所选 token 的对数概率。请注意,此实现遵循 OpenAI API:API 将始终返回采样 token 的对数概率,因此响应中最多可能包含 logprob + 1 个元素。必须 >= 0。

prompt_logprobs

int

None

每个提示 token 要返回的对数概率数。必须 >= 0。

response_format

Dict[str, str]

None

指定模型必须输出的格式。设置为 {'type': 'json_object'} 以启用 JSON 模式,这将保证模型生成的输出是有效的 JSON。请参阅 结构化生成

示例#

从命令行

curl -X 'POST' \
'http://0.0.0.0:8000/v1/chat/completions' \
    -H 'Accept: application/json' \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "meta/llama-3.2-11b-vision-instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What is in this image?"
                    },
                    {
                        "type": "image_url",
                        "image_url":
                            {
                                "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                            }
                    }
                ]
            }
        ],
        "temperature": 0.2,
        "top_p": 0.7,
        "max_tokens": 256
    }'

使用 OpenAI Python API 库

from openai import OpenAI
client = OpenAI(base_url="http://0.0.0.0:8000/v1", api_key="not-used")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What is in this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                }
            }
        ]
    }
]
chat_response = client.chat.completions.create(
    model="meta/llama-3.2-11b-vision-instruct",
    messages=messages,
    temperature=0.2,
    top_p=0.7,
    max_tokens=256,
    stream=False
)
assistant_message = chat_response.choices[0].message
print(assistant_message)

高级:Guided Decoding#

NIM for VLMs 额外支持通过 nvext 进行 Guided Decoding 以实现结构化生成。有关示例用例,请参阅 结构化生成

参数

类型

默认值

注释

guided_json

str、dict 或 Pydantic BaseModel

None

如果指定,输出将遵循 JSON 模式。

guided_regex

str

None

如果指定,输出将遵循正则表达式模式。

guided_choice

List[str]

None

如果指定,输出将恰好是选项之一。

guided_grammar

str

None

如果指定,输出将遵循上下文无关文法。

采样参数:Llama Stack API#

参数

类型

默认值

注释

strategy

str

“greedy”

生成的采样策略。必须是 greedytop_ptop_k 之一。

repetition_penalty

float

1.0

根据新 token 是否出现在提示和已生成的文本中来惩罚新 token。值 > 1 鼓励模型使用新 token,而值 < 1 鼓励模型重复 token。必须在 (0, 2] 范围内。

temperature

float

1.0

控制采样的随机性。较低的值使模型更具确定性,而较高的值使模型更随机。必须 >= 0。设置为 0 以进行贪婪采样。

top_p

float

1.0

控制要考虑的顶部 token 的累积概率。必须在 (0, 1] 范围内。设置为 1 以考虑所有 token。

top_k

int

-1

控制要考虑的顶部 token 的数量。设置为 -1 以考虑所有 token。否则必须 >= 1。

重要提示

Llama Stack API 目前不支持 Guided Decoding。

示例#

从命令行

curl -X 'POST' \
'http://0.0.0.0:8000/inference/chat_completion' \
    -H 'Accept: application/json' \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "meta/llama-3.2-11b-vision-instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "image":
                            {
                                "uri": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                            }
                    },
                    "What is in this image?"
                ]
            }
        ],
        "sampling_params": {
            "temperature": 0.2,
            "top_p": 0.7,
            "max_tokens": 256
        }
    }'

使用 Llama Stack Client Python 库

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url=f"http://0.0.0.0:8000")

messages = [
    {
        "role": "user",
        "content": [
            {
                "image": {
                    "uri": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                }
            },
            "What is in this image?"
        ]
    }
]

iterator = client.inference.chat_completion(
    model="meta/llama-3.2-11b-vision-instruct",
    messages=messages,
    sampling_params={
        "temperature": 0.2,
        "top_p": 0.7,
        "max_tokens": 256,
    },
    stream=True
)

for chunk in iterator:
    print(chunk.event.delta, end="", flush=True)