Files
opro_demo/prompt_utils.py

221 lines
7.8 KiB
Python
Raw Normal View History

2025-12-05 07:11:25 +00:00
# Copyright 2023 The OPRO Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The utility functions for prompting GPT and Google Cloud models."""
import openai
import time
try:
import google.generativeai as palm
except Exception:
palm = None
try:
from vllm import LLM, SamplingParams
except Exception:
LLM = None
SamplingParams = None
from pathlib import Path
# Process-wide cache for the vLLM engine so the model weights are loaded only
# once. `_llm_model_path` records the resolved path the cached engine was
# built from, so a request for a different model is detected instead of being
# silently served by the already-loaded one.
_llm_instance = None
_llm_model_path = None


def get_llm(local_model_path):
    """Return the cached vLLM engine, constructing it on the first call.

    Args:
        local_model_path: Filesystem path of the local model. Required on the
            first call; later calls may pass None to reuse the cached engine.
            The path is resolved to an absolute path before use/comparison.

    Returns:
        The process-wide ``vllm.LLM`` instance.

    Raises:
        RuntimeError: If vLLM could not be imported, or if a later call asks
            for a different model path than the one already loaded (this
            function never builds a second engine).
        ValueError: If no engine is cached yet and ``local_model_path`` is None.
    """
    global _llm_instance, _llm_model_path
    if LLM is None:
        raise RuntimeError("vLLM not available")

    if _llm_instance is None:
        if local_model_path is None:
            raise ValueError("model_path cannot be None")
        resolved_path = str(Path(local_model_path).resolve())
        # Hard-coded engine settings; tensor_parallel_size=8 assumes an 8-GPU
        # host — TODO confirm/adjust for other hardware.
        _llm_instance = LLM(
            model=resolved_path,
            dtype="bfloat16",
            tensor_parallel_size=8,
            max_num_batched_tokens=8192,
            max_num_seqs=64,
            gpu_memory_utilization=0.7,
            enforce_eager=True,
            block_size=16,
            enable_chunked_prefill=True,
            trust_remote_code=True,
        )
        _llm_model_path = resolved_path
    elif local_model_path is not None and _llm_model_path is not None:
        requested_path = str(Path(local_model_path).resolve())
        if requested_path != _llm_model_path:
            # Returning the cached engine here would answer with the wrong model.
            raise RuntimeError(
                f"vLLM engine already loaded from {_llm_model_path!r}; "
                f"cannot serve a different model {requested_path!r}"
            )
    return _llm_instance
def call_local_server_single_prompt(
    prompt, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Generate one completion for ``prompt`` with the local vLLM engine.

    Local stand-in for the OpenAI single-prompt call. Sampling uses nucleus
    sampling with a fixed ``top_p=0.9``.

    Args:
        prompt: Prompt text sent to the model.
        local_model_path: Model path forwarded to ``get_llm``.
        temperature: Sampling temperature.
        max_decode_steps: Maximum number of new tokens (vLLM ``max_tokens``).
        **kwargs: Accepted and ignored.

    Returns:
        The text of the first completion vLLM returns for the prompt.
    """
    engine = get_llm(local_model_path)
    # skip_special_tokens keeps special tokens out of the decoded text
    # (intent per the original author: avoid special characters triggering
    # protocol errors downstream).
    sampling = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_decode_steps,
        skip_special_tokens=True,
    )
    request_outputs = engine.generate([prompt], sampling)
    first_request = request_outputs[0]
    first_completion = first_request.outputs[0]
    return first_completion.text
def call_local_server_func(
    inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Generate completions for a batch of prompts with the local vLLM model.

    A single prompt (``str`` or UTF-8 ``bytes``) is treated as a one-element
    batch, the same way ``call_openai_server_func`` handles a bare string.
    Prompts are generated one after another, in input order.

    Args:
        inputs: A prompt string, UTF-8 encoded prompt bytes, or an iterable of
            prompt strings.
        local_model_path: Path to the local model; required.
        temperature: Sampling temperature passed to each generation.
        max_decode_steps: Maximum number of tokens generated per prompt.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one generated string per prompt, in input order.

    Raises:
        ValueError: If ``local_model_path`` is None.
    """
    if local_model_path is None:
        raise ValueError("local_model_path must be provided")
    if isinstance(inputs, bytes):
        inputs = inputs.decode("utf-8")
    # A bare string is one prompt; iterating it directly would send every
    # character to the model as a separate prompt.
    if isinstance(inputs, str):
        inputs = [inputs]
    return [
        call_local_server_single_prompt(
            prompt,
            local_model_path=local_model_path,
            temperature=temperature,
            max_decode_steps=max_decode_steps,
        )
        for prompt in inputs
    ]
def call_openai_server_single_prompt(
    prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """Send one prompt to the OpenAI chat-completion API and return the reply.

    Transient failures are retried indefinitely. On each failure the function
    prints a message, sleeps, and resends the identical request, keeping the
    same ``model``, ``temperature`` and ``max_decode_steps``. Retries run in a
    loop rather than by recursion, so a long outage cannot exhaust the Python
    stack.

    Args:
        prompt: The user message content.
        model: OpenAI chat model name.
        max_decode_steps: Passed to the API as ``max_tokens``.
        temperature: Sampling temperature.

    Returns:
        The message content of the first returned choice.

    Raises:
        Any exception not listed in the ``except`` clauses below propagates
        unchanged.
    """
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model=model,
                temperature=temperature,
                max_tokens=max_decode_steps,
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            return completion.choices[0].message.content
        # Six common error types are caught and retried. The delay is the
        # error's `retry_after` attribute when it has one, else a default.
        except openai.error.Timeout as e:  # API request timed out
            retry_time = getattr(e, "retry_after", 30)
            print(f"Timeout error occurred. Retrying in {retry_time} seconds...")
        except openai.error.RateLimitError as e:  # request rate limit exceeded
            retry_time = getattr(e, "retry_after", 30)
            print(f"Rate limit exceeded. Retrying in {retry_time} seconds...")
        except openai.error.APIError as e:  # API error (e.g. server error)
            retry_time = getattr(e, "retry_after", 30)
            print(f"API error occurred. Retrying in {retry_time} seconds...")
        except openai.error.APIConnectionError as e:  # e.g. network problems
            retry_time = getattr(e, "retry_after", 30)
            print(f"API connection error occurred. Retrying in {retry_time} seconds...")
        except openai.error.ServiceUnavailableError as e:  # e.g. maintenance
            retry_time = getattr(e, "retry_after", 30)
            print(f"Service unavailable. Retrying in {retry_time} seconds...")
        except OSError as e:  # OS-level connection error (e.g. network drop)
            retry_time = 5  # Adjust the retry time as needed
            print(
                f"Connection error occurred: {e}. Retrying in {retry_time} seconds..."
            )
        time.sleep(retry_time)
def call_openai_server_func(
    inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """Query the OpenAI server for every prompt in ``inputs``, sequentially.

    Args:
        inputs: A single prompt string, or an iterable of prompt strings. A
            bare string is treated as a one-element batch.
        model: OpenAI chat model name.
        max_decode_steps: Maximum tokens generated per prompt.
        temperature: Sampling temperature.

    Returns:
        A list with one reply per prompt, in input order.
    """
    prompts = [inputs] if isinstance(inputs, str) else inputs
    return [
        call_openai_server_single_prompt(
            prompt,
            model=model,
            max_decode_steps=max_decode_steps,
            temperature=temperature,
        )
        for prompt in prompts
    ]
# Generate text through the Google PaLM (cloud) API with the text-bison model,
# retrying on failure.
def call_palm_server_from_cloud(
    input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8
):
    """Generate a completion for ``input_text`` with the PaLM text API.

    The concrete model is the first entry from ``palm.list_models()`` that
    supports ``generateText``. The model list is fetched once per call. Any
    exception raised by ``palm.generate_text`` is printed and the request is
    retried every 10 seconds, without limit.

    Args:
        input_text: The prompt string.
        model: Must be ``"text-bison-001"``.
        max_decode_steps: Passed as ``max_output_tokens``.
        temperature: Sampling temperature.

    Returns:
        A one-element list containing ``completion.result``.

    Raises:
        RuntimeError: If ``google.generativeai`` could not be imported, or if
            no listed model supports ``generateText``.
        TypeError: If ``input_text`` is not a string.
        ValueError: If ``model`` is not ``"text-bison-001"``.
    """
    if palm is None:
        raise RuntimeError("google.generativeai not available")
    if not isinstance(input_text, str):
        raise TypeError(
            f"input_text must be a str, got {type(input_text).__name__}"
        )
    if model != "text-bison-001":
        raise ValueError(f"unsupported PaLM model: {model!r}")

    text_models = [
        m
        for m in palm.list_models()
        if "generateText" in m.supported_generation_methods
    ]
    if not text_models:
        raise RuntimeError("no PaLM model supports generateText")
    model_name = text_models[0].name

    retry_time = 10
    while True:
        try:
            completion = palm.generate_text(
                model=model_name,
                prompt=input_text,
                temperature=temperature,
                max_output_tokens=max_decode_steps,
            )
            return [completion.result]
        except Exception as e:  # any API failure is retried
            # Report the failure so a persistent error (e.g. bad credentials)
            # is visible instead of the loop retrying silently.
            print(f"PaLM API error occurred: {e}. Retrying in {retry_time} seconds...")
            time.sleep(retry_time)
def refine_instruction(query: str) -> str:
    """Build a Chinese-language prompt asking an LLM to rewrite ``query``.

    The prompt casts the model as a question-clarification and rewriting
    assistant. It asks for at least 20 multi-angle, directly actionable
    rewrites of the user's original question, one per line.

    Args:
        query: The user's original question, inserted verbatim into the prompt.

    Returns:
        The prompt text; it begins and ends with a newline.
    """
    return f"""
你是一个问题澄清与重写助手
请根据用户的原始问题
{query}
生成不少于20条多角度可直接执行的问题改写每行一条
"""
def refine_instruction_with_history(query: str, rejected_list: list) -> str:
    """Build a rewrite prompt that also lists previously rejected rewrites.

    The Chinese-language prompt casts the model as a question-clarification
    and rewriting assistant. It shows the original question and the rewrites
    that were already rejected, then asks for at least 20 different rewrites
    from new angles, each on its own line.

    Args:
        query: The user's original question, inserted verbatim.
        rejected_list: Rewrites already rejected. Each item is rendered on its
            own line as ``- <item>``.

    Returns:
        The prompt text; it begins and ends with a newline.
    """
    # A falsy rejected_list (empty, or None) yields an empty rejected section.
    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
    return f"""
你是一个问题澄清与重写助手
原始问题
{query}
以下改写已被否定
{rejected_text}
请从新的角度重新生成至少20条不同的改写问题每条单独一行
"""