opro_demo/prompt_utils.py

# Copyright 2023 The OPRO Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The utility functions for prompting GPT and Google Cloud models."""
import time
from pathlib import Path

import openai

try:
    import google.generativeai as palm
except Exception:  # optional dependency; PaLM calls fail fast without it
    palm = None

try:
    from vllm import LLM, SamplingParams
except Exception:  # optional dependency; local-model calls fail fast without it
    LLM = None
    SamplingParams = None

# Cache the vLLM instance so the model is loaded only once per process.
_llm_instance = None


def get_llm(local_model_path):
    """Return the cached vLLM engine, loading it on the first call.

    Note: the first path wins; later calls ignore a different path.
    """
    if LLM is None:
        raise RuntimeError("vLLM not available")
    global _llm_instance
    if _llm_instance is None:
        assert local_model_path is not None, "model_path cannot be None"
        local_model_path = str(Path(local_model_path).resolve())
        _llm_instance = LLM(
            model=local_model_path,
            dtype="bfloat16",
            tensor_parallel_size=8,
            max_num_batched_tokens=8192,
            max_num_seqs=64,
            gpu_memory_utilization=0.7,
            enforce_eager=True,
            block_size=16,
            enable_chunked_prefill=True,
            trust_remote_code=True,
        )
    return _llm_instance
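

def _demo_llm_cache(local_model_path):
    """Illustrative only, never called: shows the module-level caching above.

    A minimal sketch assuming local_model_path points at a real model
    directory; the second get_llm call reuses the cached engine instead of
    reloading the weights.
    """
    first = get_llm(local_model_path)
    second = get_llm(local_model_path)
    assert first is second  # same vLLM engine, loaded only once per process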


def call_local_server_single_prompt(
    prompt, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Generate a response to a single prompt with a local vLLM model.

    This replaces the original OpenAI API call.
    """
    llm = get_llm(local_model_path)
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_decode_steps,
        skip_special_tokens=True,  # avoid special tokens triggering protocol errors
    )
    outputs = llm.generate([prompt], sampling_params)
    return outputs[0].outputs[0].text


def call_local_server_func(
    inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Generate responses for a batch of input prompts with the local model."""
    assert local_model_path is not None, "local_model_path must be provided"
    # Normalize the input so a single prompt (str or bytes) becomes a
    # one-element list; otherwise a bare string would be iterated char by char.
    if isinstance(inputs, bytes):
        inputs = inputs.decode("utf-8")
    if isinstance(inputs, str):
        inputs = [inputs]
    outputs = []
    for input_str in inputs:
        output = call_local_server_single_prompt(
            input_str,
            local_model_path=local_model_path,
            temperature=temperature,
            max_decode_steps=max_decode_steps,
        )
        outputs.append(output)
    return outputs
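

def _demo_local_batch(local_model_path):
    """Illustrative only, never called: batch generation with the local model.

    A minimal sketch; the prompts and decoding settings are assumptions,
    not part of this repo.
    """
    answers = call_local_server_func(
        ["What is 2 + 2?", "Name a prime number greater than 10."],
        local_model_path=local_model_path,
        temperature=0.0,  # near-greedy decoding for short factual answers
        max_decode_steps=32,
    )
    for answer in answers:
        print(answer.strip())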


def call_openai_server_single_prompt(
    prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """The function to call OpenAI server with an input string."""
    try:
        completion = openai.ChatCompletion.create(
            model=model,
            temperature=temperature,
            max_tokens=max_decode_steps,
            messages=[
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content
    # The handlers below catch six common exception classes and retry on error:
    except openai.error.Timeout as e:  # API request timed out
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"Timeout error occurred. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )
    except openai.error.RateLimitError as e:  # request rate limit exceeded
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"Rate limit exceeded. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )
    except openai.error.APIError as e:  # API error (e.g. a server-side error)
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"API error occurred. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )
    except openai.error.APIConnectionError as e:  # connection error (e.g. network issues)
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"API connection error occurred. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )
    except openai.error.ServiceUnavailableError as e:  # service unavailable (e.g. maintenance)
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"Service unavailable. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )
    except OSError as e:  # OS-level connection error (e.g. network interruption)
        retry_time = 5  # Adjust the retry time as needed
        print(f"Connection error occurred: {e}. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )
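

def _retry_sketch(fn, exceptions, default_wait=30):
    """Illustrative only, never called: the retry shape shared by the handlers above.

    A hypothetical helper, not used by this module; it waits for
    e.retry_after when the exception carries one, else default_wait,
    then retries indefinitely, mirroring the recursive handlers above.
    """
    while True:
        try:
            return fn()
        except exceptions as e:
            wait = getattr(e, "retry_after", default_wait)
            print(f"{type(e).__name__} occurred. Retrying in {wait} seconds...")
            time.sleep(wait)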


# Process a batch of input prompts, fetching generations from the OpenAI API
# one prompt at a time.
def call_openai_server_func(
    inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """The function to call OpenAI server with a list of input strings."""
    if isinstance(inputs, str):  # normalize a single string into a list
        inputs = [inputs]
    outputs = []
    for input_str in inputs:
        output = call_openai_server_single_prompt(
            input_str,
            model=model,
            max_decode_steps=max_decode_steps,
            temperature=temperature,
        )
        outputs.append(output)
    return outputs
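

def _demo_openai_batch():
    """Illustrative only, never called: querying several prompts in one call.

    A minimal sketch assuming openai.api_key is already configured; the
    prompts are placeholders.
    """
    outputs = call_openai_server_func(
        ["Say hello.", "Say goodbye."],
        model="gpt-3.5-turbo",
        max_decode_steps=20,
        temperature=0.8,
    )
    print(outputs)  # a list with one completion string per input prompt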


# Call the text-bison model through the Google PaLM API (Cloud version),
# with basic error handling and automatic retry.
def call_palm_server_from_cloud(
    input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8
):
    if palm is None:
        raise RuntimeError("google.generativeai not available")
    assert isinstance(input_text, str)
    assert model == "text-bison-001"
    all_model_names = [
        m
        for m in palm.list_models()
        if "generateText" in m.supported_generation_methods
    ]
    model_name = all_model_names[0].name
    try:
        completion = palm.generate_text(
            model=model_name,
            prompt=input_text,
            temperature=temperature,
            max_output_tokens=max_decode_steps,
        )
        output_text = completion.result
        return [output_text]
    except Exception:  # retry on any API error after a short pause
        retry_time = 10
        time.sleep(retry_time)
        return call_palm_server_from_cloud(
            input_text, max_decode_steps=max_decode_steps, temperature=temperature
        )
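

def _demo_palm_call():
    """Illustrative only, never called: a single text-bison generation.

    A minimal sketch assuming palm.configure(api_key=...) was run beforehand;
    the function above returns a one-element list, unpacked here.
    """
    [text] = call_palm_server_from_cloud(
        "Write a one-line greeting.", max_decode_steps=20, temperature=0.8
    )
    print(text)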


def refine_instruction(query: str) -> str:
    """Build a prompt asking the model to rewrite the user's question."""
    return f"""
You are a "question clarification and rewriting assistant".
Based on the user's original question:
{query}
Generate at least 20 multi-angle, directly actionable rewrites of the question, one per line.
"""


def refine_instruction_with_history(query: str, rejected_list: list) -> str:
    """Build a rewriting prompt that also lists previously rejected rewrites."""
    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
    return f"""
You are a "question clarification and rewriting assistant".
Original question:
{query}
The following rewrites have been rejected:
{rejected_text}
Please generate at least 20 different rewrites from new angles, one per line.
"""