221 lines
7.8 KiB
Python
221 lines
7.8 KiB
Python
# Copyright 2023 The OPRO Authors
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""The utility functions for prompting GPT and Google Cloud models."""
|
||
import openai
|
||
import time
|
||
try:
|
||
import google.generativeai as palm
|
||
except Exception:
|
||
palm = None
|
||
try:
|
||
from vllm import LLM, SamplingParams
|
||
except Exception:
|
||
LLM = None
|
||
SamplingParams = None
|
||
from pathlib import Path
|
||
|
||
|
||
# Process-wide cache of the vLLM engine, so the expensive model load happens
# at most once per process instead of on every call.
_llm_instance = None

# Resolved path of the model held in ``_llm_instance``; used only to warn when
# a caller asks for a different model than the one already loaded.
_llm_model_path = None


def get_llm(local_model_path, **llm_overrides):
    """Return the cached vLLM engine, creating it on the first call.

    Only one engine is kept per process. Once it exists, later calls reuse it
    regardless of ``local_model_path``; a different path triggers a
    ``RuntimeWarning`` instead of silently serving the wrong model.

    Args:
      local_model_path: Path to the local model directory. Required when no
        engine has been created yet; it is resolved to an absolute path before
        being handed to vLLM. May be None on later calls.
      **llm_overrides: Extra keyword arguments for ``vllm.LLM`` that override
        the default engine settings below (for example
        ``tensor_parallel_size=1``). They only take effect on the call that
        creates the engine.

    Returns:
      The process-wide ``vllm.LLM`` instance.

    Raises:
      RuntimeError: if vLLM could not be imported.
      ValueError: if no engine exists yet and ``local_model_path`` is None.
    """
    global _llm_instance, _llm_model_path

    if LLM is None:
        raise RuntimeError("vLLM not available")

    if _llm_instance is None:
        # Explicit check rather than ``assert`` so it still fires under -O.
        if local_model_path is None:
            raise ValueError("model_path cannot be None")
        # Defaults assume an 8-GPU node (tensor_parallel_size=8); callers on
        # other hardware can override any of them via ``llm_overrides``.
        engine_kwargs = {
            "model": str(Path(local_model_path).resolve()),
            "dtype": "bfloat16",
            "tensor_parallel_size": 8,
            "max_num_batched_tokens": 8192,
            "max_num_seqs": 64,
            "gpu_memory_utilization": 0.7,
            "enforce_eager": True,
            "block_size": 16,
            "enable_chunked_prefill": True,
            "trust_remote_code": True,
        }
        engine_kwargs.update(llm_overrides)
        _llm_instance = LLM(**engine_kwargs)
        _llm_model_path = engine_kwargs["model"]
    elif local_model_path is not None:
        requested_path = str(Path(local_model_path).resolve())
        if requested_path != _llm_model_path:
            import warnings

            # Only one engine fits in memory with these settings, so keep the
            # loaded one but make the mismatch visible instead of silent.
            warnings.warn(
                f"get_llm: requested model {requested_path!r} but "
                f"{_llm_model_path!r} is already loaded; reusing the loaded "
                "engine.",
                RuntimeWarning,
                stacklevel=2,
            )
    return _llm_instance
|
||
|
||
def call_local_server_single_prompt(
    prompt,
    local_model_path=None,
    temperature=0.8,
    max_decode_steps=512,
    top_p=0.9,
    **kwargs,
):
    """Generate a completion for one prompt with the local vLLM engine.

    Local replacement for the OpenAI-backed single-prompt call.

    Args:
      prompt: The input prompt string.
      local_model_path: Model directory passed to ``get_llm``; needed when the
        engine has not been loaded yet.
      temperature: Sampling temperature.
      max_decode_steps: Maximum number of tokens to generate.
      top_p: Nucleus-sampling probability mass (previously fixed at 0.9).
      **kwargs: Accepted for call-signature compatibility with the other
        backends; ignored.

    Returns:
      The generated text of the first completion for ``prompt``.
    """
    llm = get_llm(local_model_path)
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_decode_steps,
        # Strip special tokens from the decoded text; the original author
        # notes this avoids special characters triggering protocol errors.
        skip_special_tokens=True,
    )
    outputs = llm.generate([prompt], sampling_params)
    return outputs[0].outputs[0].text
|
||
|
||
def call_local_server_func(
    inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Generate completions for a batch of prompts with the local vLLM engine.

    Mirrors ``call_openai_server_func``: a single prompt given as ``str`` (or
    UTF-8 ``bytes``) is treated as a one-element batch.

    Args:
      inputs: An iterable of prompt strings, or a single ``str`` / ``bytes``.
      local_model_path: Model directory for the local engine; required.
      temperature: Sampling temperature.
      max_decode_steps: Maximum number of tokens to generate per prompt.
      **kwargs: Forwarded to ``call_local_server_single_prompt``.

    Returns:
      A list of generated strings, one per prompt, in input order.

    Raises:
      ValueError: if ``local_model_path`` is None.
    """
    if local_model_path is None:
        raise ValueError("local_model_path must be provided")

    if isinstance(inputs, bytes):
        inputs = inputs.decode("utf-8")
    # A bare string must be wrapped: iterating it directly would send every
    # character to the model as a separate prompt.
    if isinstance(inputs, str):
        inputs = [inputs]

    return [
        call_local_server_single_prompt(
            prompt,
            local_model_path=local_model_path,
            temperature=temperature,
            max_decode_steps=max_decode_steps,
            **kwargs,
        )
        for prompt in inputs
    ]
|
||
|
||
|
||
def _openai_retry_delay(error, default=30):
    """Return ``error.retry_after`` when set, otherwise ``default`` seconds."""
    delay = getattr(error, "retry_after", None)
    return default if delay is None else delay


def call_openai_server_single_prompt(
    prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """Call the OpenAI chat-completion endpoint with one user message.

    Transient failures are retried indefinitely after a pause:
      * ``openai.error.Timeout``, ``RateLimitError``, ``APIError``,
        ``APIConnectionError`` and ``ServiceUnavailableError`` wait the
        error's ``retry_after`` when it is set, otherwise 30 seconds;
      * ``OSError`` (low-level connection failures) waits 5 seconds.
    Any other exception propagates to the caller.

    Args:
      prompt: The user message content.
      model: OpenAI chat model name; used for the first attempt and every retry.
      max_decode_steps: ``max_tokens`` for the completion.
      temperature: Sampling temperature.

    Returns:
      The message content of the first returned choice.
    """
    # Retry in a loop rather than by recursion: a long outage cannot exhaust
    # the recursion limit, and every attempt keeps the caller's ``model``.
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model=model,
                temperature=temperature,
                max_tokens=max_decode_steps,
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            return completion.choices[0].message.content

        except openai.error.Timeout as e:  # request timed out
            retry_time = _openai_retry_delay(e)
            print(f"Timeout error occurred. Retrying in {retry_time} seconds...")

        except openai.error.RateLimitError as e:  # rate limit exceeded
            retry_time = _openai_retry_delay(e)
            print(f"Rate limit exceeded. Retrying in {retry_time} seconds...")

        except openai.error.APIError as e:  # server-side API error
            retry_time = _openai_retry_delay(e)
            print(f"API error occurred. Retrying in {retry_time} seconds...")

        except openai.error.APIConnectionError as e:  # network / connection
            retry_time = _openai_retry_delay(e)
            print(f"API connection error occurred. Retrying in {retry_time} seconds...")

        except openai.error.ServiceUnavailableError as e:  # e.g. maintenance
            retry_time = _openai_retry_delay(e)
            print(f"Service unavailable. Retrying in {retry_time} seconds...")

        except OSError as e:  # OS-level connection failure
            retry_time = 5  # Adjust the retry time as needed
            print(
                f"Connection error occurred: {e}. Retrying in {retry_time} seconds..."
            )

        time.sleep(retry_time)
|
||
|
||
|
||
def call_openai_server_func(
    inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """Query the OpenAI chat endpoint once per prompt, sequentially.

    Args:
      inputs: An iterable of prompt strings, or a single string (treated as a
        one-element batch).
      model: OpenAI chat model name.
      max_decode_steps: ``max_tokens`` for each completion.
      temperature: Sampling temperature.

    Returns:
      A list of completion strings in the same order as the prompts.
    """
    prompts = [inputs] if isinstance(inputs, str) else inputs
    return [
        call_openai_server_single_prompt(
            prompt,
            model=model,
            max_decode_steps=max_decode_steps,
            temperature=temperature,
        )
        for prompt in prompts
    ]
|
||
|
||
# Generate text through the Google PaLM cloud API (text-bison), retrying on
# failure.
def call_palm_server_from_cloud(
    input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8
):
    """Generate text for ``input_text`` with the Google PaLM cloud API.

    The request is sent to the first model returned by ``palm.list_models()``
    that supports ``generateText``. Failed generation requests are logged and
    retried every 10 seconds until one succeeds.

    Args:
      input_text: The prompt string.
      model: Must be "text-bison-001", the only accepted value.
      max_decode_steps: ``max_output_tokens`` for the request.
      temperature: Sampling temperature.

    Returns:
      A one-element list containing ``completion.result``.

    Raises:
      RuntimeError: if google.generativeai is unavailable, or no listed model
        supports ``generateText``.
      TypeError: if ``input_text`` is not a string.
      ValueError: if ``model`` is not "text-bison-001".
    """
    if palm is None:
        raise RuntimeError("google.generativeai not available")
    # Explicit checks instead of ``assert`` so they still run under -O.
    if not isinstance(input_text, str):
        raise TypeError(f"input_text must be str, got {type(input_text).__name__}")
    if model != "text-bison-001":
        raise ValueError(f"unsupported PaLM model: {model!r}")

    text_models = [
        m
        for m in palm.list_models()
        if "generateText" in m.supported_generation_methods
    ]
    if not text_models:
        raise RuntimeError("no PaLM model supports generateText")
    model_name = text_models[0].name

    retry_time = 10
    # Loop instead of recursing so a long outage cannot hit the recursion limit.
    while True:
        try:
            completion = palm.generate_text(
                model=model_name,
                prompt=input_text,
                temperature=temperature,
                max_output_tokens=max_decode_steps,
            )
            return [completion.result]
        # Deliberately broad (best-effort retry), but no longer silent.
        except Exception as e:
            print(f"PaLM request failed: {e}. Retrying in {retry_time} seconds...")
            time.sleep(retry_time)
|
||
|
||
def refine_instruction(query: str) -> str:
    """Build the prompt asking the model to rewrite ``query``.

    The (Chinese) prompt casts the model as a question-clarification and
    rewriting assistant and asks for at least 20 multi-angle, directly
    actionable rewrites of ``query``, one per line.
    """
    prompt_lines = (
        "",
        "你是一个“问题澄清与重写助手”。",
        "请根据用户的原始问题:",
        f"【{query}】",
        "生成不少于20条多角度、可直接执行的问题改写,每行一条。",
        "",
    )
    return "\n".join(prompt_lines)
|
||
|
||
def refine_instruction_with_history(query: str, rejected_list: list) -> str:
    """Build a rewrite prompt for ``query`` that lists already-rejected rewrites.

    Each entry of ``rejected_list`` is rendered as a ``- `` bullet line; a
    falsy ``rejected_list`` (empty or None) leaves that section blank. The
    (Chinese) prompt asks for at least 20 new, different rewrites from new
    angles, one per line.
    """
    bullets = [f"- {item}" for item in (rejected_list or [])]
    rejected_text = "\n".join(bullets)

    sections = (
        "",
        "你是一个“问题澄清与重写助手”。",
        "原始问题:",
        f"{query}",
        "",
        "以下改写已被否定:",
        rejected_text,
        "",
        "请从新的角度重新生成至少20条不同的改写问题,每条单独一行。",
        "",
    )
    return "\n".join(sections)
|