# Copyright 2023 The OPRO Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The utility functions for prompting GPT and Google Cloud models."""

import time
from pathlib import Path

import openai

try:
    import google.generativeai as palm
except Exception:
    palm = None

try:
    from vllm import LLM, SamplingParams
except Exception:
    LLM = None
    SamplingParams = None

# Cache the vLLM instance so the model is loaded only once per process.
_llm_instance = None


def get_llm(local_model_path):
    """Returns the cached vLLM engine, loading it on first use."""
    if LLM is None:
        raise RuntimeError("vLLM not available")
    global _llm_instance
    if _llm_instance is None:
        assert local_model_path is not None, "model_path cannot be None"
        local_model_path = str(Path(local_model_path).resolve())
        _llm_instance = LLM(
            model=local_model_path,
            dtype="bfloat16",
            tensor_parallel_size=8,
            max_num_batched_tokens=8192,
            max_num_seqs=64,
            gpu_memory_utilization=0.7,
            enforce_eager=True,
            block_size=16,
            enable_chunked_prefill=True,
            trust_remote_code=True,
        )
    return _llm_instance


def call_local_server_single_prompt(
    prompt, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Generates a response to a single prompt with a local vLLM model,
    replacing the original OpenAI API call."""
    llm = get_llm(local_model_path)
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_decode_steps,
        skip_special_tokens=True,  # avoid special tokens triggering protocol errors
    )
    outputs = llm.generate([prompt], sampling_params)
    return outputs[0].outputs[0].text


def call_local_server_func(
    inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs
):
    """Processes a batch of input prompts with the local vLLM model."""
    assert local_model_path is not None, "local_model_path must be provided"
    # Normalize the input type: decode bytes, then wrap a single string in a
    # list so the loop below iterates over prompts rather than characters.
    if isinstance(inputs, bytes):
        inputs = inputs.decode("utf-8")
    if isinstance(inputs, str):
        inputs = [inputs]
    outputs = []
    for input_str in inputs:
        output = call_local_server_single_prompt(
            input_str,
            local_model_path=local_model_path,
            temperature=temperature,
            max_decode_steps=max_decode_steps,
        )
        outputs.append(output)
    return outputs
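
# A minimal usage sketch for the local vLLM path, kept behind a function so
# nothing runs at import time. The checkpoint path "/models/my-model" is a
# hypothetical placeholder, not something this repo ships.
def _demo_local_batch():
    """Sketch: batch-generate with the local model (assumed checkpoint path)."""
    responses = call_local_server_func(
        ["What is 2 + 3?", "Rewrite this question: how do magnets work?"],
        local_model_path="/models/my-model",  # hypothetical path; replace before running
        temperature=0.0,  # near-greedy decoding for reproducible demo output
        max_decode_steps=64,
    )
    for text in responses:  # one generated string per input prompt
        print(text)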

def call_openai_server_single_prompt(
    prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """The function to call OpenAI server with an input string."""
    try:
        completion = openai.ChatCompletion.create(
            model=model,
            temperature=temperature,
            max_tokens=max_decode_steps,
            messages=[
                {"role": "user", "content": prompt},
            ],
        )
        return completion.choices[0].message.content

    # The handlers below catch six common error classes and retry on each:
    except openai.error.Timeout as e:  # the API request timed out
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"Timeout error occurred. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )

    except openai.error.RateLimitError as e:  # the request rate limit was exceeded
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"Rate limit exceeded. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )

    except openai.error.APIError as e:  # a generic API error, e.g. a server-side error
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"API error occurred. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )

    except openai.error.APIConnectionError as e:  # a connection error, e.g. a network problem
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"API connection error occurred. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )

    except openai.error.ServiceUnavailableError as e:  # the service is unavailable, e.g. maintenance
        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
        print(f"Service unavailable. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )

    except OSError as e:  # an OS-level connection error, e.g. the network dropped
        retry_time = 5  # adjust the retry time as needed
        print(f"Connection error occurred: {e}. Retrying in {retry_time} seconds...")
        time.sleep(retry_time)
        return call_openai_server_single_prompt(
            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
        )


def call_openai_server_func(
    inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
):
    """The function to call OpenAI server with a list of input strings.

    Sends the prompts sequentially and returns one generation per prompt.
    """
    if isinstance(inputs, str):  # wrap a single string in a list for uniform handling
        inputs = [inputs]
    outputs = []
    for input_str in inputs:
        output = call_openai_server_single_prompt(
            input_str,
            model=model,
            max_decode_steps=max_decode_steps,
            temperature=temperature,
        )
        outputs.append(output)
    return outputs


def call_palm_server_from_cloud(
    input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8
):
    """Generates text with the text-bison model via the Google PaLM (Cloud) API,
    with basic error handling and an automatic retry."""
    if palm is None:
        raise RuntimeError("google.generativeai not available")
    assert isinstance(input_text, str)
    assert model == "text-bison-001"
    all_model_names = [
        m
        for m in palm.list_models()
        if "generateText" in m.supported_generation_methods
    ]
    model_name = all_model_names[0].name
    try:
        completion = palm.generate_text(
            model=model_name,
            prompt=input_text,
            temperature=temperature,
            max_output_tokens=max_decode_steps,
        )
        output_text = completion.result
        return [output_text]
    except Exception:
        retry_time = 10
        time.sleep(retry_time)
        return call_palm_server_from_cloud(
            input_text, max_decode_steps=max_decode_steps, temperature=temperature
        )


def refine_instruction(query: str) -> str:
    """Builds a prompt asking the model to rewrite the user's question."""
    return f"""
You are a "question clarification and rewriting assistant".
Given the user's original question:
[{query}]
generate at least 20 multi-angle, directly actionable rewrites of the question, one per line.
"""


def refine_instruction_with_history(query: str, rejected_list: list) -> str:
    """Builds a rewrite prompt that also lists previously rejected rewrites."""
    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
    return f"""
You are a "question clarification and rewriting assistant".
Original question:
{query}
The following rewrites have already been rejected:
{rejected_text}
Please generate at least 20 different rewrites from new angles, one per line.
"""
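
# A minimal end-to-end sketch of the rewrite/reject loop built from the two
# prompt builders above, assuming `openai.api_key` has already been set. The
# query and the "rejected" candidates are illustrative placeholders only.
def _demo_refine_loop():
    """Sketch: generate rewrites, then regenerate after rejecting some."""
    query = "How do I speed up my Python code?"
    rewrites = call_openai_server_func(
        refine_instruction(query), max_decode_steps=512
    )[0]
    # Suppose a downstream filter rejected the first two candidate rewrites:
    rejected = [line for line in rewrites.splitlines() if line.strip()][:2]
    retry_prompt = refine_instruction_with_history(query, rejected)
    return call_openai_server_func(retry_prompt, max_decode_steps=512)[0]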