diff --git a/.trae/documents/修复预览访问与静态页面挂载.md b/.trae/documents/修复预览访问与静态页面挂载.md new file mode 100644 index 0000000..256d598 --- /dev/null +++ b/.trae/documents/修复预览访问与静态页面挂载.md @@ -0,0 +1,37 @@ +## 问题分析 +- 浏览器报错 `net::ERR_ABORTED http://0.0.0.0:8010/`,常见于预览环境对 `0.0.0.0` 的访问被中止或跳转流程未完成。 +- 现有后端已启动并挂载静态目录到 `/`,但预览器可能对根路径加载敏感,或端口/主机解析不一致。 +- 目前 API 路由与静态挂载并存,使用相对 `fetch('/query')` 正常;问题主要是根页面加载。 + +## 修复方案 +1. 主机与端口访问 +- 推荐使用 `http://127.0.0.1:8010/` 或 `http://localhost:8010/` 访问,而不是 `0.0.0.0`。 +- 新增 `/health` 路由用于快速诊断服务是否运行。 + +2. 静态页面挂载位置 +- 将静态目录从根 `/` 改为 `/ui` 挂载,降低与根路径的潜在冲突。 +- 新增 `GET /` 路由,返回 `frontend/index.html` 或 302 跳转到 `/ui/index.html`。 + +3. 前端请求与错误处理 +- 保持相对路径 `fetch('/query')`、`/select`、`/reject`,保持同源;增强错误提示(显示响应状态与内容)。 +- 若需要跨源(前端单独部署),补充 CORS:允许前端源访问后端 API。 + +4. 诊断与验证 +- 使用 `curl http://127.0.0.1:8010/health` 验证健康。 +- 使用 `curl` 端到端:`/query`(新会话与带 `session_id` 再生)、`/select`(生成答案)。 +- 浏览器打开 `/ui/` 页面,执行完整流程:开始生成 → 拒绝并再生 → 选择并出答案。 + +## 具体改动清单 +- `_qwen_xinference_demo/api.py` + - 添加 `GET /health` 路由返回 `{status:"ok"}`。 + - 将 `StaticFiles(directory="frontend", html=True)` 从 `/` 挂载到 `/ui`。 + - 添加 `GET /` 路由,返回 `index.html` 或重定向到 `/ui/index.html`。 +- `frontend/index.html` + - 增强错误显示:同时显示响应状态码与文本(提升诊断能力)。 + +## 后续增强(可选) +- 为 `/query` 与 `/select` 增加耗时、来源日志,便于问题排查。 +- 在页面上展示历史候选与拒绝原因列表,提升可观测性。 +- 提供配置项切换嵌入优先级(Xinference/Ollama)。 + +请确认是否按以上方案进行修改与验证,我将立即实施并完成端到端测试。 \ No newline at end of file diff --git a/API.md b/API.md new file mode 100644 index 0000000..bdc1ad8 --- /dev/null +++ b/API.md @@ -0,0 +1,360 @@ +# 项目 API 文档 + +本项目提供用于 OPRO 风格提示优化与会话交互的 REST API。所有接口均使用 `application/json`,无鉴权。示例以默认本地启动地址为例:`http://127.0.0.1:8010`。 + +- 基础路径:`/` +- 前端页面:`/ui/`(三栏界面),`/ui/react`(React 示例页面),`/ui/offline`(离线备份页面) +- 内容类型:`Content-Type: application/json` + +### 统一响应格式 + +所有 JSON 接口统一返回以下包装结构: + +```json +{ + "code": 0, + "msg": "ok", + "data": {} +} +``` + +- 成功:`code` 固定为 `0`,`msg` 为简要说明(默认 `ok`),业务数据在 `data` 字段中。 +- 失败:HTTP 状态码保持原值(如 400/404/500),`code` 同步为该状态码,`msg` 为错误信息,`data` 为 `null`。 + +错误处理位于 `_qwen_xinference_demo/api.py:23-31`(异常处理器),成功响应包装器为 `_qwen_xinference_demo/api.py:21-22` 的 `ok()`。 + +--- + +## 健康检查 + +- 方法与路径:`GET /health` +- 作用:服务可用性检查 +- 响应示例: +```json +{ + "code": 0, + "msg": "ok", + "data": { "status": "ok" } +} +``` + +--- + +## 模型管理 + +### 获取可用模型 +- 方法与路径:`GET /models` +- 作用:列出可用于推理的 Ollama 模型(过滤掉 embedding/reranker) +- 响应示例: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "models": ["qwen3:8b", "qwen3:14b", "qwen3:32b"] + } +} +``` + +### 设置当前会话模型 +- 方法与路径:`POST /set_model` +- 请求体: +```json +{ + "session_id": "", + "model_name": "qwen3:8b" +} +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "model_name": "qwen3:8b" + } +} +``` +- 说明:`model_name` 必须在 `/models` 返回列表中;否则返回 400 错误。 + +--- + +## 会话与候选生成(提示优化) + +提示优化由以下流程实现:根据用户问题或最近消息构造“改写/变异”指令 → 调用 Qwen 批量生成候选 → 通过 Xinference(失败回退到 Ollama embedding)做语义向量 → 聚类去重并选取 Top‑K(默认 5)。 + +### 首次生成候选(创建会话) +- 方法与路径:`POST /query` +- 请求体(新会话): +```json +{ "query": "我想买苹果" } +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "round": 0, + "candidates": ["...", "..."] + } +} +``` +- 说明: + - 新建会话并记录用户原始问题与首轮候选;`round` 会在候选入库后加 1。 + +### 继续优化(基于最近消息再生候选) +- 方法与路径:`POST /query_from_message` +- 请求体: +```json +{ "session_id": "" } +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "round": 1, + "candidates": ["...", "..."] + } +} +``` +- 说明: + - 从会话的最近一条“用户消息”或原始问题作为基线生成新候选。 + +### 选择候选并回答 +- 方法与路径:`POST /select` +- 
请求体: +```json +{ + "session_id": "", + "choice": "选中的提示词" +} +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "prompt": "选中的提示词", + "answer": "模型回答内容" + } +} +``` +- 说明: + - 将 `choice` 记录为当前会话的 `selected_prompt`,并用该提示词生成回答。 + - 会把用户选择与回答追加到 `outputs/user_feedback.jsonl`。 + +### 拒绝候选并再生成 +- 方法与路径:`POST /reject` +- 请求体: +```json +{ + "session_id": "", + "candidate": "不合适的候选", + "reason": "可选的拒绝理由" +} +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "round": 2, + "candidates": ["...", "..."] + } +} +``` +- 说明: + - 将被拒绝的候选加入会话历史,生成新一轮候选以“避撞并多样化”。 + +### 直接回答 + 候选(可选流程) +- 方法与路径:`POST /answer` +- 请求体: +```json +{ "query": "我想买苹果" } +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "answer": "直接回答内容", + "candidates": ["...", "..."] + } +} +``` +- 说明: + - 先对用户问题直接回答,再生成提示优化候选。该路由默认使用后端配置的模型。 + +### 再次生成(旧接口,含 MAX_ROUNDS) +- 方法与路径:`POST /next` +- 请求体: +```json +{ "session_id": "" } +``` +- 成功响应(达到最大轮次时): +```json +{ + "code": 0, + "msg": "ok", + "data": { "final": true, "answer": "最终回答" } +} +``` +- 成功响应(未达到最大轮次时): +```json +{ + "code": 0, + "msg": "ok", + "data": { "session_id": "", "round": 1, "candidates": ["...", "..."] } +} +``` +- 说明: + - `MAX_ROUNDS` 当前为 3,仅对该路由有效;前端默认不使用此路由。 + +--- + +## 会话聊天 + +### 发送消息并获取回答 +- 方法与路径:`POST /message` +- 请求体: +```json +{ + "session_id": "", + "message": "继续提问或补充说明" +} +``` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "answer": "模型回答", + "history": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ] + } +} +``` +- 说明: + - 回答会在已选提示词(如无则原始问题)基础上拼接本次消息生成。 + +--- + +## 会话管理 + +### 列出会话 +- 方法与路径:`GET /sessions` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "sessions": [ + { + "session_id": "", + "round": 2, + "selected_prompt": "...", + "original_query": "我想买苹果" + } + ] + } +} +``` + +### 会话详情 +- 方法与路径:`GET /session/{sid}` +- 成功响应: +```json +{ + "code": 0, + "msg": "ok", + "data": { + "session_id": "", + "round": 2, + "original_query": "我想买苹果", + "selected_prompt": "...", + "candidates": ["...", "..."], + "user_feedback": [{"round": 1, "choice": "..."}], + "rejected": ["...", "..."], + "history": [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."} + ] + } +} +``` + +--- + +## 静态页面与重定向 + +- `GET /` → 重定向到 `/ui/` +- `GET /ui/` → 前端三栏页面(由后端挂载静态目录 `frontend`) +- `GET /ui/react` → React 版本示例页面 +- `GET /ui/offline` → 离线页面(无 CDN 依赖) +- `GET /react` → 与 `/ui/react` 等价的页面入口 + +--- + +## 错误码与通用返回 + +- 错误包装: + - HTTP 404:`{"code": 404, "msg": "session not found", "data": null}` + - HTTP 400:`{"code": 400, "msg": "model not available: "|"ollama error: ", "data": null}` + - HTTP 500:`{"code": 500, "msg": "internal error", "data": null}` + +--- + +## 调用示例(curl) + +```bash +# 创建会话并生成首轮候选 +curl -X POST http://127.0.0.1:8010/query \ + -H 'Content-Type: application/json' \ + -d '{"query": "我想买苹果"}' + +# 选择某个候选并回答 +curl -X POST http://127.0.0.1:8010/select \ + -H 'Content-Type: application/json' \ + -d '{"session_id": "", "choice": "选中的提示词"}' + +# 拒绝某个候选并再生成 +curl -X POST http://127.0.0.1:8010/reject \ + -H 'Content-Type: application/json' \ + -d '{"session_id": "", "candidate": "不合适的候选", "reason": "太笼统"}' + +# 基于最近消息继续优化 +curl -X POST http://127.0.0.1:8010/query_from_message \ + -H 'Content-Type: application/json' \ + -d '{"session_id": ""}' + +# 普通聊天 +curl -X POST http://127.0.0.1:8010/message \ + -H 'Content-Type: application/json' \ + -d '{"session_id": "", 
"message": "有无更甜的品种?"}' + +# 获取会话详情 +curl http://127.0.0.1:8010/session/ +``` + +--- + +## 备注 +- 候选 Top‑K 默认 5,聚类阈值默认 `0.15`。 +- 向量优先使用 Xinference(`http://127.0.0.1:9997/...`),失败自动回退到 Ollama embedding(`qwen3-embedding:4b`)。 +- 回答默认使用 Ollama 中的 `qwen3:8b`,或通过 `/set_model` 设置当前会话模型。 diff --git a/__pycache__/config.cpython-310.pyc b/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..7ed0201 Binary files /dev/null and b/__pycache__/config.cpython-310.pyc differ diff --git a/__pycache__/config.cpython-313.pyc b/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000..82e4658 Binary files /dev/null and b/__pycache__/config.cpython-313.pyc differ diff --git a/__pycache__/ollama_client.cpython-310.pyc b/__pycache__/ollama_client.cpython-310.pyc new file mode 100644 index 0000000..a78a861 Binary files /dev/null and b/__pycache__/ollama_client.cpython-310.pyc differ diff --git a/__pycache__/prompt_utils.cpython-310.pyc b/__pycache__/prompt_utils.cpython-310.pyc new file mode 100644 index 0000000..fa55b45 Binary files /dev/null and b/__pycache__/prompt_utils.cpython-310.pyc differ diff --git a/__pycache__/session_state.cpython-310.pyc b/__pycache__/session_state.cpython-310.pyc new file mode 100644 index 0000000..5f74013 Binary files /dev/null and b/__pycache__/session_state.cpython-310.pyc differ diff --git a/__pycache__/user_prompt_optimizer.cpython-310.pyc b/__pycache__/user_prompt_optimizer.cpython-310.pyc new file mode 100644 index 0000000..eac9c0b Binary files /dev/null and b/__pycache__/user_prompt_optimizer.cpython-310.pyc differ diff --git a/__pycache__/xinference_client.cpython-310.pyc b/__pycache__/xinference_client.cpython-310.pyc new file mode 100644 index 0000000..75b9114 Binary files /dev/null and b/__pycache__/xinference_client.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/__pycache__/api.cpython-310.pyc b/_qwen_xinference_demo/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000..93a16cf Binary files /dev/null and b/_qwen_xinference_demo/__pycache__/api.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/__pycache__/api.cpython-313.pyc b/_qwen_xinference_demo/__pycache__/api.cpython-313.pyc new file mode 100644 index 0000000..5e08690 Binary files /dev/null and b/_qwen_xinference_demo/__pycache__/api.cpython-313.pyc differ diff --git a/_qwen_xinference_demo/api.py b/_qwen_xinference_demo/api.py new file mode 100644 index 0000000..5f982d7 --- /dev/null +++ b/_qwen_xinference_demo/api.py @@ -0,0 +1,284 @@ +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import RedirectResponse, FileResponse, JSONResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel +import config + +from .opro.session_state import create_session, get_session, update_session_add_candidates, log_user_choice +from .opro.session_state import log_user_reject +from .opro.session_state import set_selected_prompt, log_chat_message +from .opro.session_state import set_session_model +from .opro.session_state import USER_FEEDBACK_LOG +from .opro.user_prompt_optimizer import generate_candidates +from .opro.ollama_client import call_qwen +from .opro.ollama_client import list_models + +from fastapi.middleware.cors import CORSMiddleware + +app = FastAPI( + title=config.APP_TITLE, + description=config.APP_DESCRIPTION, + version=config.APP_VERSION, + contact=config.APP_CONTACT, + openapi_tags=[ + {"name": "health", "description": "健康检查"}, + {"name": "models", "description": "模型列表与设置"}, + {"name": "sessions", 
"description": "会话管理"}, + {"name": "opro", "description": "提示优化候选生成与选择/拒绝"}, + {"name": "chat", "description": "会话聊天"}, + {"name": "ui", "description": "静态页面"} + ] +) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +MAX_ROUNDS = 3 + +def ok(data=None): + return JSONResponse({"success": True, "data": data}) + +class AppException(HTTPException): + def __init__(self, status_code: int, detail: str, error_code: str): + super().__init__(status_code=status_code, detail=detail) + self.error_code = error_code + +@app.exception_handler(AppException) +def _app_exc_handler(request: Request, exc: AppException): + return JSONResponse(status_code=exc.status_code, content={ + "success": False, + "error": str(exc.detail), + "error_code": exc.error_code + }) + +@app.exception_handler(HTTPException) +def _http_exc_handler(request: Request, exc: HTTPException): + return JSONResponse(status_code=exc.status_code, content={ + "success": False, + "error": str(exc.detail), + "error_code": "HTTP_ERROR" + }) + +@app.exception_handler(Exception) +def _generic_exc_handler(request: Request, exc: Exception): + return JSONResponse(status_code=500, content={ + "success": False, + "error": "internal error", + "error_code": "INTERNAL_ERROR" + }) + +class StartReq(BaseModel): + query: str + +class NextReq(BaseModel): + session_id: str + +class SelectReq(BaseModel): + session_id: str + choice: str + +class RejectReq(BaseModel): + session_id: str + candidate: str + reason: str | None = None + +class SetModelReq(BaseModel): + session_id: str + model_name: str + +@app.post("/start", tags=["opro"]) +def start(req: StartReq): + sid = create_session(req.query) + cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name")) + update_session_add_candidates(sid, cands) + return ok({"session_id": sid, "round": 0, "candidates": cands}) + +@app.post("/next", tags=["opro"]) +def next_round(req: NextReq): + s = get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + + if s["round"] >= MAX_ROUNDS: + ans = call_qwen(s["original_query"], temperature=0.3, max_tokens=512) + return ok({"final": True, "answer": ans}) + + cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name")) + update_session_add_candidates(req.session_id, cands) + return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands}) + +@app.post("/select", tags=["opro"]) +def select(req: SelectReq): + s = get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + + log_user_choice(req.session_id, req.choice) + set_selected_prompt(req.session_id, req.choice) + log_chat_message(req.session_id, "system", req.choice) + try: + ans = call_qwen(req.choice, temperature=0.2, max_tokens=1024, model_name=s.get("model_name")) + except Exception as e: + raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") + log_chat_message(req.session_id, "assistant", ans) + try: + import os, json + os.makedirs("outputs", exist_ok=True) + with open("outputs/user_feedback.jsonl", "a", encoding="utf-8") as f: + f.write(json.dumps({ + "session_id": req.session_id, + "round": s["round"], + "choice": req.choice, + "answer": ans + }, ensure_ascii=False) + "\n") + except Exception: + pass + return ok({"prompt": req.choice, "answer": ans}) + +@app.post("/reject", tags=["opro"]) +def reject(req: RejectReq): + s = 
get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + log_user_reject(req.session_id, req.candidate, req.reason) + cands = generate_candidates(s["original_query"], s["history_candidates"] + [req.candidate], model_name=s.get("model_name")) + update_session_add_candidates(req.session_id, cands) + return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands}) +class QueryReq(BaseModel): + query: str + session_id: str | None = None + +@app.post("/query", tags=["opro"]) +def query(req: QueryReq): + if req.session_id: + s = get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name")) + update_session_add_candidates(req.session_id, cands) + return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands}) + else: + sid = create_session(req.query) + log_chat_message(sid, "user", req.query) + cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name")) + update_session_add_candidates(sid, cands) + return ok({"session_id": sid, "round": 0, "candidates": cands}) +app.mount("/ui", StaticFiles(directory="frontend", html=True), name="static") + +@app.get("/", tags=["ui"]) +def root(): + return RedirectResponse(url="/ui/") + +@app.get("/health", tags=["health"]) +def health(): + return ok({"status": "ok", "version": config.APP_VERSION}) + +@app.get("/version", tags=["health"]) +def version(): + return ok({"version": config.APP_VERSION}) + +# @app.get("/ui/react", tags=["ui"]) +# def ui_react(): +# return FileResponse("frontend/react/index.html") + +# @app.get("/ui/offline", tags=["ui"]) +# def ui_offline(): +# return FileResponse("frontend/ui_offline.html") + +@app.get("/react", tags=["ui"]) +def react_root(): + return FileResponse("frontend/react/index.html") + +@app.get("/sessions", tags=["sessions"]) +def sessions(): + from .opro.session_state import SESSIONS + return ok({"sessions": [{ + "session_id": sid, + "round": s.get("round", 0), + "selected_prompt": s.get("selected_prompt"), + "original_query": s.get("original_query") + } for sid, s in SESSIONS.items()]}) + +@app.get("/session/{sid}", tags=["sessions"]) +def session_detail(sid: str): + s = get_session(sid) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + return ok({ + "session_id": sid, + "round": s["round"], + "original_query": s["original_query"], + "selected_prompt": s["selected_prompt"], + "candidates": s["history_candidates"], + "user_feedback": s["user_feedback"], + "rejected": s["rejected"], + "history": s["chat_history"], + }) + +class MessageReq(BaseModel): + session_id: str + message: str + +@app.post("/message", tags=["chat"]) +def message(req: MessageReq): + s = get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + log_chat_message(req.session_id, "user", req.message) + base_prompt = s.get("selected_prompt") or s["original_query"] + full_prompt = base_prompt + "\n\n" + req.message + try: + ans = call_qwen(full_prompt, temperature=0.3, max_tokens=1024, model_name=s.get("model_name")) + except Exception as e: + raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") + log_chat_message(req.session_id, "assistant", ans) + return ok({"session_id": req.session_id, "answer": ans, "history": s["chat_history"]}) + +class QueryFromMsgReq(BaseModel): + session_id: str + 
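+# Note: /query_from_message (below) regenerates candidates from the most recent
+# "user" message in chat_history, falling back to the original query, so each new
+# optimization round tracks the latest user intent rather than the first question.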
+@app.post("/query_from_message", tags=["opro"]) +def query_from_message(req: QueryFromMsgReq): + s = get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + last_user = None + for m in reversed(s.get("chat_history", [])): + if m.get("role") == "user" and m.get("content"): + last_user = m["content"] + break + base = last_user or s["original_query"] + cands = generate_candidates(base, s["history_candidates"], model_name=s.get("model_name")) + update_session_add_candidates(req.session_id, cands) + return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands}) + +class AnswerReq(BaseModel): + query: str + +@app.post("/answer", tags=["opro"]) +def answer(req: AnswerReq): + sid = create_session(req.query) + log_chat_message(sid, "user", req.query) + ans = call_qwen(req.query, temperature=0.2, max_tokens=1024) + log_chat_message(sid, "assistant", ans) + cands = generate_candidates(req.query, []) + update_session_add_candidates(sid, cands) + return ok({"session_id": sid, "answer": ans, "candidates": cands}) + +@app.get("/models", tags=["models"]) +def models(): + return ok({"models": list_models()}) + +@app.post("/set_model", tags=["models"]) +def set_model(req: SetModelReq): + s = get_session(req.session_id) + if not s: + raise AppException(404, "session not found", "SESSION_NOT_FOUND") + avail = set(list_models() or []) + if req.model_name not in avail: + raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE") + set_session_model(req.session_id, req.model_name) + return ok({"session_id": req.session_id, "model_name": req.model_name}) diff --git a/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc new file mode 100644 index 0000000..bf092de Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc new file mode 100644 index 0000000..29fc91b Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc new file mode 100644 index 0000000..5aaa537 Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc new file mode 100644 index 0000000..dddc7d6 Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc new file mode 100644 index 0000000..fd1643b Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc new file mode 100644 index 0000000..2c6872b Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc differ diff --git 
a/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc new file mode 100644 index 0000000..bee2f1d Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc new file mode 100644 index 0000000..86a5be2 Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc new file mode 100644 index 0000000..468d6c1 Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc differ diff --git a/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc new file mode 100644 index 0000000..3b8eb3f Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc differ diff --git a/_qwen_xinference_demo/opro/ollama_client.py b/_qwen_xinference_demo/opro/ollama_client.py new file mode 100644 index 0000000..e5b6ba1 --- /dev/null +++ b/_qwen_xinference_demo/opro/ollama_client.py @@ -0,0 +1,52 @@ +import requests +import re +import config + +OLLAMA_URL = config.OLLAMA_GENERATE_URL +TAGS_URL = config.OLLAMA_TAGS_URL +MODEL_NAME = config.DEFAULT_CHAT_MODEL + +def call_qwen(prompt: str, temperature: float = 0.8, max_tokens: int = 512, model_name: str | None = None) -> str: + def _payload(m: str): + return { + "model": m, + "prompt": prompt, + "stream": False, + "options": { + "temperature": temperature, + "num_predict": max_tokens + } + } + primary = model_name or MODEL_NAME + try: + resp = requests.post(OLLAMA_URL, json=_payload(primary), timeout=60) + resp.raise_for_status() + data = resp.json() + return data.get("response", "") or data.get("text", "") + except requests.HTTPError as e: + # Try fallback to default when user-selected model fails + if model_name and model_name != MODEL_NAME: + try: + resp = requests.post(OLLAMA_URL, json=_payload(MODEL_NAME), timeout=60) + resp.raise_for_status() + data = resp.json() + return data.get("response", "") or data.get("text", "") + except Exception: + pass + raise + +def list_models() -> list[str]: + try: + r = requests.get(TAGS_URL, timeout=10) + r.raise_for_status() + data = r.json() or {} + items = data.get("models") or [] + names = [] + for m in items: + name = m.get("name") or m.get("model") + if name: + names.append(name) + names = [n for n in names if not re.search(r"embedding|rerank|reranker|bge", n, re.I)] + return names + except Exception: + return [MODEL_NAME] diff --git a/_qwen_xinference_demo/opro/prompt_utils.py b/_qwen_xinference_demo/opro/prompt_utils.py new file mode 100644 index 0000000..4d44486 --- /dev/null +++ b/_qwen_xinference_demo/opro/prompt_utils.py @@ -0,0 +1,20 @@ +def refine_instruction(query: str) -> str: + return f""" +你是一个“问题澄清与重写助手”。 +请根据用户的原始问题: +【{query}】 +生成不少于20条多角度、可直接执行的问题改写,每行一条。 +""" + +def refine_instruction_with_history(query: str, rejected_list: list) -> str: + rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else "" + return f""" +你是一个“问题澄清与重写助手”。 +原始问题: +{query} + +以下改写已被否定: +{rejected_text} + 
+请从新的角度重新生成至少20条不同的改写问题,每条单独一行。 +""" diff --git a/_qwen_xinference_demo/opro/session_state.py b/_qwen_xinference_demo/opro/session_state.py new file mode 100644 index 0000000..96c7c77 --- /dev/null +++ b/_qwen_xinference_demo/opro/session_state.py @@ -0,0 +1,56 @@ +import uuid + +SESSIONS = {} +USER_FEEDBACK_LOG = [] + +def create_session(query: str) -> str: + sid = uuid.uuid4().hex + SESSIONS[sid] = { + "original_query": query, + "round": 0, + "history_candidates": [], + "user_feedback": [], + "rejected": [], + "selected_prompt": None, + "chat_history": [], + "model_name": None + } + return sid + +def get_session(sid: str): + return SESSIONS.get(sid) + +def update_session_add_candidates(sid: str, candidates: list): + s = SESSIONS[sid] + s["round"] += 1 + s["history_candidates"].extend(candidates) + +def log_user_choice(sid: str, choice: str): + SESSIONS[sid]["user_feedback"].append( + {"round": SESSIONS[sid]["round"], "choice": choice} + ) + USER_FEEDBACK_LOG.append({ + "session_id": sid, + "round": SESSIONS[sid]["round"], + "choice": choice + }) + +def log_user_reject(sid: str, candidate: str, reason: str | None = None): + SESSIONS[sid]["rejected"].append(candidate) + USER_FEEDBACK_LOG.append({ + "session_id": sid, + "round": SESSIONS[sid]["round"], + "reject": candidate, + "reason": reason or "" + }) + +def set_selected_prompt(sid: str, prompt: str): + SESSIONS[sid]["selected_prompt"] = prompt + +def log_chat_message(sid: str, role: str, content: str): + SESSIONS[sid]["chat_history"].append({"role": role, "content": content}) + +def set_session_model(sid: str, model_name: str | None): + s = SESSIONS.get(sid) + if s is not None: + s["model_name"] = model_name diff --git a/_qwen_xinference_demo/opro/user_prompt_optimizer.py b/_qwen_xinference_demo/opro/user_prompt_optimizer.py new file mode 100644 index 0000000..d742e8f --- /dev/null +++ b/_qwen_xinference_demo/opro/user_prompt_optimizer.py @@ -0,0 +1,55 @@ +import re +import numpy as np +from sklearn.cluster import AgglomerativeClustering +from sklearn.metrics.pairwise import cosine_similarity +import config + +from .ollama_client import call_qwen +from .xinference_client import embed_texts +from .prompt_utils import refine_instruction, refine_instruction_with_history + +def parse_candidates(raw: str) -> list: + lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()] + cleaned = [] + for l in lines: + l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip() + if len(l) >= 6: + cleaned.append(l) + return list(dict.fromkeys(cleaned)) + +def cluster_and_select(candidates: list, top_k=config.TOP_K, distance_threshold=config.CLUSTER_DISTANCE_THRESHOLD): + if not candidates: + return [] + if len(candidates) <= top_k: + return candidates + vecs = embed_texts(candidates) + if not vecs or len(vecs) != len(candidates): + return candidates[:top_k] + X = np.array(vecs) + + clustering = AgglomerativeClustering(n_clusters=None, + distance_threshold=distance_threshold, + metric="cosine", + linkage="average") + labels = clustering.fit_predict(X) + + selected_idx = [] + for label in sorted(set(labels)): + idxs = [i for i,l in enumerate(labels) if l == label] + sims = cosine_similarity(X[idxs]).mean(axis=1) + rep = idxs[int(np.argmax(sims))] + selected_idx.append(rep) + + selected = [candidates[i] for i in sorted(selected_idx)] + return selected[:top_k] + +def generate_candidates(query: str, rejected=None, top_k=config.TOP_K, model_name=None): + rejected = rejected or [] + if rejected: + prompt = refine_instruction_with_history(query, rejected) + else: + 
prompt = refine_instruction(query) + + raw = call_qwen(prompt, temperature=0.9, max_tokens=1024, model_name=model_name) + all_candidates = parse_candidates(raw) + return cluster_and_select(all_candidates, top_k=top_k) diff --git a/_qwen_xinference_demo/opro/xinference_client.py b/_qwen_xinference_demo/opro/xinference_client.py new file mode 100644 index 0000000..c0d3a8b --- /dev/null +++ b/_qwen_xinference_demo/opro/xinference_client.py @@ -0,0 +1,29 @@ +import requests +from typing import List +import config + +XINFERENCE_EMBED_URL = config.XINFERENCE_EMBED_URL +OLLAMA_EMBED_URL = config.OLLAMA_HOST + "/api/embeddings" + +def embed_texts(texts: List[str]) -> List[List[float]]: + payload = {"inputs": texts} + try: + resp = requests.post(XINFERENCE_EMBED_URL, json=payload, timeout=10) + resp.raise_for_status() + data = resp.json() + embs = data.get("embeddings", []) + if embs: + return embs + except Exception: + pass + + try: + payload2 = {"model": config.DEFAULT_EMBED_MODEL, "input": texts} + resp2 = requests.post(OLLAMA_EMBED_URL, json=payload2, timeout=15) + resp2.raise_for_status() + data2 = resp2.json() + if isinstance(data2, dict) and "data" in data2: + return [item.get("embedding", []) for item in data2["data"]] + return data2.get("embeddings", []) + except Exception: + return [] diff --git a/config.py b/config.py new file mode 100644 index 0000000..2662c83 --- /dev/null +++ b/config.py @@ -0,0 +1,19 @@ +APP_TITLE = "OPRO Prompt Optimizer API" +APP_DESCRIPTION = "提供提示优化、候选生成、会话聊天与模型管理的接口" +APP_VERSION = "0.1.0" +APP_CONTACT = {"name": "OPRO Team", "url": "http://127.0.0.1:8010/ui/"} + +# Ollama endpoints +OLLAMA_HOST = "http://127.0.0.1:11434" +OLLAMA_GENERATE_URL = f"{OLLAMA_HOST}/api/generate" +OLLAMA_TAGS_URL = f"{OLLAMA_HOST}/api/tags" +DEFAULT_CHAT_MODEL = "qwen3:8b" +DEFAULT_EMBED_MODEL = "qwen3-embedding:4b" + +# Xinference +XINFERENCE_EMBED_URL = "http://127.0.0.1:9997/models/bge-base-zh/embed" + +# Clustering/selection +TOP_K = 5 +CLUSTER_DISTANCE_THRESHOLD = 0.15 + diff --git a/examples/client_demo.py b/examples/client_demo.py new file mode 100644 index 0000000..5d3303e --- /dev/null +++ b/examples/client_demo.py @@ -0,0 +1,55 @@ +import requests + +BASE = "http://127.0.0.1:8010" + +def _post(path, payload): + r = requests.post(BASE + path, json=payload, timeout=30) + r.raise_for_status() + j = r.json() + if "success" in j: + if not j.get("success"): + raise RuntimeError(f"api error: {j}") + return j.get("data") + return j + +def _get(path): + r = requests.get(BASE + path, timeout=15) + r.raise_for_status() + j = r.json() + if "success" in j: + if not j.get("success"): + raise RuntimeError(f"api error: {j}") + return j.get("data") + return j + +def main(): + print("health:", _get("/health")) + try: + print("version:", _get("/version")) + except Exception: + pass + + data = _post("/query", {"query": "我想买苹果"}) + sid = data["session_id"] + print("created session:", sid) + print("candidates:", data["candidates"]) + + # choose first candidate + if data["candidates"]: + choice = data["candidates"][0] + ans = _post("/select", {"session_id": sid, "choice": choice}) + print("answer:", ans["answer"][:200]) + + # continue optimization + more = _post("/query_from_message", {"session_id": sid}) + print("next candidates:", more["candidates"]) + + # chat + chat = _post("/message", {"session_id": sid, "message": "还有更甜的苹果吗?"}) + print("chat answer:", chat["answer"][:200]) + + # list sessions + print("sessions:", _get("/sessions")) + +if __name__ == "__main__": + main() diff --git 
a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000..d6de7ea --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,446 @@
[Markup not preserved in this extract; recoverable content only: page title "OPRO Prompt Optimizer"; three-column layout with a model selector ("选择模型"), a central chat/candidate area, and a session panel ("会话信息") showing the current round ("当前轮次": Round 0), the selected prompt ("已选提示词": 暂未选择), and usage hints ("操作提示"): enter a question to generate candidate prompts; click "选择" to use a prompt, "拒绝" to regenerate, "继续优化" for more options.]
diff --git a/frontend/react-app.html b/frontend/react-app.html new file mode 100644 index 0000000..0fb1649 --- /dev/null +++ b/frontend/react-app.html @@ -0,0 +1,164 @@
[Markup not preserved; recoverable content only: page title "OPRO React 界面" (React entry page).]
diff --git a/frontend/react/index.html b/frontend/react/index.html new file mode 100644 index 0000000..06bbb01 --- /dev/null +++ b/frontend/react/index.html @@ -0,0 +1,192 @@
[Markup not preserved; recoverable content only: page title "OPRO React 界面" (React entry page served by the `/react` route).]
diff --git a/frontend/ui_offline.html b/frontend/ui_offline.html new file mode 100644 index 0000000..48cb789 --- /dev/null +++ b/frontend/ui_offline.html @@ -0,0 +1,157 @@
[Markup not preserved; recoverable content only: page title "OPRO 三栏界面(离线版)", an offline three-column page with input controls and action buttons and no CDN dependencies.]
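Since the page markup is not preserved above, the quickest way to verify the static mounts and redirects that `api.py` defines (`GET /` → `/ui/`, `GET /react` → the React entry file, plus `/health`) is a short request-level smoke test. A minimal sketch, assuming the default local address `http://127.0.0.1:8010` from `config.py` and the `requests` library; the helper name is illustrative only:

```python
import requests

BASE = "http://127.0.0.1:8010"  # default local address used throughout this repo

def smoke_test_ui() -> None:
    """Hit the UI-related routes once and print what came back."""
    # Root issues a redirect to the mounted static UI; don't follow it so the target is visible.
    r = requests.get(BASE + "/", allow_redirects=False, timeout=10)
    print("GET /       ->", r.status_code, r.headers.get("location"))

    # The StaticFiles mount serves frontend/index.html at /ui/.
    r = requests.get(BASE + "/ui/", timeout=10)
    print("GET /ui/    ->", r.status_code, len(r.content), "bytes")

    # /react returns frontend/react/index.html as a FileResponse.
    r = requests.get(BASE + "/react", timeout=10)
    print("GET /react  ->", r.status_code)

    # /health returns the {"success": true, "data": {...}} envelope produced by ok().
    print("GET /health ->", requests.get(BASE + "/health", timeout=10).json())

if __name__ == "__main__":
    smoke_test_ui()
```

This complements `examples/client_demo.py`, which exercises the JSON API (`/query`, `/select`, `/message`) but not the static pages or redirects.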
+ + + diff --git a/ollama_client.py b/ollama_client.py new file mode 100644 index 0000000..a654934 --- /dev/null +++ b/ollama_client.py @@ -0,0 +1,19 @@ +import requests + +OLLAMA_URL = "http://127.0.0.1:11434/api/generate" +MODEL_NAME = "qwen3:8b" + +def call_qwen(prompt: str, temperature: float = 0.8, max_tokens: int = 512) -> str: + payload = { + "model": MODEL_NAME, + "prompt": prompt, + "stream": False, + "options": { + "temperature": temperature, + "num_predict": max_tokens + } + } + resp = requests.post(OLLAMA_URL, json=payload, timeout=60) + resp.raise_for_status() + data = resp.json() + return data.get("response", "") or data.get("text", "") diff --git a/opro_qwen_xinference_demo_package.zip b/opro_qwen_xinference_demo_package.zip new file mode 100644 index 0000000..187905f Binary files /dev/null and b/opro_qwen_xinference_demo_package.zip differ diff --git a/optimization/__pycache__/opt_utils.cpython-310.pyc b/optimization/__pycache__/opt_utils.cpython-310.pyc new file mode 100644 index 0000000..2c7a4a9 Binary files /dev/null and b/optimization/__pycache__/opt_utils.cpython-310.pyc differ diff --git a/optimization/opt_utils.py b/optimization/opt_utils.py new file mode 100644 index 0000000..36ad3bf --- /dev/null +++ b/optimization/opt_utils.py @@ -0,0 +1,1035 @@ +# Copyright 2023 The OPRO Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The utility functions for prompt optimization.""" + +import collections +import json +import os +import pickle +import re +import sys + +OPRO_ROOT_PATH = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +import numpy as np +from opro.evaluation import eval_utils +import pandas as pd + + +def extract_string_in_square_brackets(input_string): + raw_result = re.findall(r"\[.*?\]", input_string) + if raw_result: + return raw_result[0][1:-1] + else: + return "" #从输入的字符串 input_string中提取​​第一个被方括号 []包裹的内容​​,并返回去掉方括号后的结果。如果输入字符串中没有方括号,则返回空字符串 ""。 + + +def parse_tag_content(text, prefix="", suffix=""): + pattern = f"{prefix}(.*?){suffix}" + results = re.findall(pattern, text, re.DOTALL) + return results + + +def _bucketize_float(num, n_buckets=20): + assert num >= 0 and num <= 1, "The given number must be between 0 and 1." 
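+  # Maps a score in [0, 1] to an integer bucket id in [0, n_buckets]; e.g. 0.85 -> 17 with the default 20 buckets.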
+ return round(num * n_buckets) #此函数可能用于将​​指令的评分​​(如 0.85)转换为整数分桶 + + +def gen_ins_and_score_pairs_substr( + old_instructions_and_scores, + old_instruction_score_threshold=0.1, + max_num_instructions=1000, + return_str_only=False, + num_score_buckets=np.inf, +): + """Generate the string that includes instruction-score pairs.""" + assert num_score_buckets == np.inf or isinstance(num_score_buckets, int) + old_instructions_and_scores_str = "" + old_instructions_and_scores = sorted( + old_instructions_and_scores, key=lambda x: x[1] + )[-max_num_instructions:] + old_instructions_and_scores_in_meta_prompt = [] + for instruction, score, i_step in old_instructions_and_scores: + if ( + not old_instruction_score_threshold + or score >= old_instruction_score_threshold + ): + old_instructions_and_scores_in_meta_prompt.append( + (instruction, score, i_step) + ) + if num_score_buckets == np.inf: + score_to_show = round(score, 3) + else: + score_to_show = _bucketize_float(score, num_score_buckets) + old_instructions_and_scores_str += ( + f"\ntext:\n{instruction}\nscore:\n{score_to_show}\n" + ) + if return_str_only: + return old_instructions_and_scores_str + else: + return ( + old_instructions_and_scores_str, + old_instructions_and_scores_in_meta_prompt, + ) + + +def gen_meta_prompt( + old_instructions_and_scores, + instruction_pos, + optimizer_llm_name, + old_instruction_score_threshold=0.1, + max_num_instructions=1000, + meta_prompt_type="both_instructions_and_exemplars", + few_shot_qa_pairs=False, + include_qa=True, + data=None, + few_shot_index_list=None, + instructions_before_exemplars=True, + num_score_buckets=np.inf, + dataset_name="", + task_name="", +): + """Generate meta prompt for instruction rewriting. + + Args: + old_instructions_and_scores (list): a list of (instruction, score, i_step) + pairs. + instruction_pos (str): where to put the instruction, one of {'before_QA', + 'Q_begin', 'Q_end', 'A_begin'}. + optimizer_llm_name (str): the name of the LLM used for instruction editing. + old_instruction_score_threshold (float): only add old instructions with score + no less than this threshold. + max_num_instructions (int): the maximum number of instructions in the meta + prompt. + meta_prompt_type (str): the type of meta-prompt: whether to have both + previous instructions and dataset exemplars (often for fine-tuned + optimizers), or to have only previous instructions (often for pre-trained + optimizers). + few_shot_qa_pairs (bool): whether to have few-shot QA pairs in the meta + prompt. + include_qa (bool): whether to include "Q:" and "A:" formats in the prompt. + data (list or pd.DataFrame): the raw data. + few_shot_index_list (list): the list of indices of few-shot examples. + instructions_before_exemplars (bool): whether the instruction-score pairs are + before the exemplars from the dataset. + num_score_buckets (np.inf or int): the number of score buckets when we + convert float accuracies to integers. Default to np.inf for not + bucketizing. + dataset_name (str): the name of the current dataset. Only used when + generating task description when meta_prompt_type == "instructions_only". + task_name (str): the name of the current task. Only used when generating task + description when meta_prompt_type == "instructions_only". + + Returns: + meta_prompt (str): the generated meta prompt. 
+ """ + assert instruction_pos in { + "before_Q", + "Q_begin", + "Q_end", + "A_begin", + }, ( + "The instruction position should be either before the question, or at the" + " beginning of the question, at the end of the question, or at the" + " beginning of the answer." + ) + assert meta_prompt_type in { + "both_instructions_and_exemplars", + "instructions_only", + } + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, gsm8k." + assert num_score_buckets == np.inf or isinstance(num_score_buckets, int) + + meta_prompt = "" + if meta_prompt_type == "both_instructions_and_exemplars": + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4","local"}: # GPT模型的提示模板 + if instruction_pos == "A_begin":# 针对答案开头部分的描述 + meta_prompt_old_instruction_part = ( + "Your task is to generate the answer starting sentence ." + " Below are some previous starting sentences with their scores." + " The score ranges from 0 to 100.\n" + )#要求模型生成答案的起始句(如“The answer is...”),并参考历史评分数据。 + else:# 针对普通指令的描述 + meta_prompt_old_instruction_part = ( + "Your task is to generate the instruction ." + " Below are some previous instructions with their scores." + " The score ranges from 0 to 100.\n" + ) + else: # text-bison 模型的专用提示模板 + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt_old_instruction_part = ( + "I have some texts along with their corresponding scores." + " The texts are arranged in ascending order based on their scores," + " where higher scores indicate better quality.\n\n" + ) + # add old instructions + old_instructions_and_scores_str = gen_ins_and_score_pairs_substr( + old_instructions_and_scores=old_instructions_and_scores, + old_instruction_score_threshold=old_instruction_score_threshold, + max_num_instructions=max_num_instructions, + return_str_only=True, + num_score_buckets=num_score_buckets, + ) + meta_prompt_old_instruction_part += old_instructions_and_scores_str + # add QA pairs if few_shot_qa_pairs == True + meta_prompt_exemplar_part = "" + if few_shot_qa_pairs: + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + meta_prompt_exemplar_part += "Below are some problems.\n" + else: + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt_exemplar_part += ( + "The following exemplars show how to apply your text: you replace" + " in each input with your text, then read the input and give" + " an output. We say your output is wrong if your output is" + " different from the given output, and we say your output is" + " correct if they are the same. 
When replacing with an old" + " piece of text above, we get wrong outputs on the following" + " inputs.\n\n" + ) + for idx in few_shot_index_list: + if dataset_name == "mmlu": + question = eval_utils._format_mmlu_example(data, idx) # pylint: disable=protected-access + true_answer = data.iloc[idx, -1] + elif dataset_name == "bbh": + question = data[idx]["input"] + true_answer = data[idx]["target"] + else: + assert dataset_name == "gsm8k" + question = data.iloc[idx, 0] + true_answer = data.iloc[idx, 1] + + if include_qa: # when "Q:" and "A:" are present in the prompt + if instruction_pos == "before_Q": + meta_prompt_exemplar_part += f"\ninput:\n\nQ: {question}\nA:" + elif instruction_pos == "Q_begin": + meta_prompt_exemplar_part += f"\ninput:\nQ: \n{question}\nA:" + elif instruction_pos == "Q_end": + meta_prompt_exemplar_part += f"\ninput:\nQ: {question}\n\nA:" + else: # instruction_pos == "A_begin" + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + meta_prompt_exemplar_part += f"\nQ: {question}\nA: " + else: + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt_exemplar_part += f"\ninput:\nQ: {question}\nA: " + else: # when there're no "Q:" and "A:" in the prompt + assert instruction_pos in {"Q_begin", "Q_end"} + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + if instruction_pos == "Q_begin": + meta_prompt_exemplar_part += f"\nProblem:\n\n{question}\n" + elif instruction_pos == "Q_end": + meta_prompt_exemplar_part += f"\nProblem:\n{question}\n\n" + else: + assert optimizer_llm_name.lower() == "text-bison" + if instruction_pos == "Q_begin": + meta_prompt_exemplar_part += f"\ninput:\n\n{question}\n" + elif instruction_pos == "Q_end": + meta_prompt_exemplar_part += f"\ninput:\n{question}\n\n" + + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + meta_prompt_exemplar_part += ( + f"\nGround truth answer:\n{true_answer}\n" + ) + else: + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt_exemplar_part += f"\noutput:\n{true_answer}\n" + + if few_shot_qa_pairs: + if instructions_before_exemplars: + meta_prompt += ( + meta_prompt_old_instruction_part + + "\n\n" + + meta_prompt_exemplar_part + ) + else: + meta_prompt += ( + meta_prompt_exemplar_part + + "\n\n" + + meta_prompt_old_instruction_part + ) + else: + meta_prompt += meta_prompt_old_instruction_part + + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + if instruction_pos == "A_begin": + meta_prompt += ( + "\n\nGenerate a starting sentence that is different from all the" + " sentences above, and has a higher score than all the" + " sentences above. The starting sentence should begin with" + " and end with . The starting sentence should be" + " concise, effective, and generally applicable to all QA pairs" + " above." + ) + else: + meta_prompt += ( + "\n\nGenerate an instruction that" + " is different from all the instructions above," + " and has a higher score than all the instructions above." + " The instruction should begin with and end with ." + " The instruction should be concise, effective," + " and generally applicable to all problems above." + ) + else: + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt += ( + "\n\nWrite your new text that is different from the old ones and" + " has a score as high as possible. Write the text in square brackets." 
+ ) + else: + # when using a pre-trained model as optimizer + assert meta_prompt_type == "instructions_only" + + assert instruction_pos in {"Q_begin", "Q_end", "A_begin"} + if instruction_pos == "Q_begin": + instruction_pos_description = "at the beginning of the question" + elif instruction_pos == "Q_end": + instruction_pos_description = "at the end of the question" + else: + assert instruction_pos == "A_begin" + instruction_pos_description = "at the beginning of the answer" + + if dataset_name == "gsm8k": + instruction_task_description = "grade school math" + elif dataset_name == "mmlu": + instruction_task_description = task_name + else: + assert dataset_name == "bbh" + instruction_task_description = " ".join(task_name.split("_")) + + meta_instruction = ( + f"Create a piece of text {instruction_pos_description.strip()} to" + " enhance the precision in solving diverse" + f" {instruction_task_description.strip()} problems." + ) + old_instructions_and_scores = sorted( + old_instructions_and_scores, key=lambda x: x[1] + ) + old_instructions_and_scores_str = "" + for instruction, score, _ in old_instructions_and_scores: + if num_score_buckets == np.inf: + score_to_show = round(score, 2) + else: + score_to_show = _bucketize_float(score, num_score_buckets) + old_instructions_and_scores_str += ( + f"\n\nPrecision: {score_to_show} {instruction}" + ) + meta_prompt += meta_instruction + old_instructions_and_scores_str + return meta_prompt + + +def run_evolution(**kwargs): + """The function for evolution.""" + # ================= experiment configurations ============================= + num_search_steps = kwargs["num_search_steps"] + old_instruction_score_threshold = kwargs["old_instruction_score_threshold"] + scorer_llm_dict = kwargs["scorer_llm_dict"] + optimizer_llm_dict = kwargs["optimizer_llm_dict"] + extract_final_answer_by_prompting_again = kwargs[ + "extract_final_answer_by_prompting_again" + ] + include_qa = kwargs["include_qa"] + evaluate_in_parallel = kwargs["evaluate_in_parallel"] + tasks_all = kwargs["tasks_all"] + train_ratio = kwargs["train_ratio"] + eval_ratio = kwargs["eval_ratio"] + test_ratio = kwargs["test_ratio"] + train_index = kwargs["train_index"] + eval_index = kwargs["eval_index"] + dataset_name = kwargs["dataset_name"] + task_name = kwargs["task_name"] + num_examples = kwargs["num_examples"] + root_data_folder_path = kwargs["root_data_folder_path"] + optimizer_llm_temperature = kwargs["optimizer_llm_temperature"] + optimizer_llm_temperature_schedule = ( + kwargs["optimizer_llm_temperature_schedule"] + if "optimizer_llm_temperature_schedule" in kwargs + else "constant" + ) + optimizer_llm_temperature_end = ( + kwargs["optimizer_llm_temperature_end"] + if "optimizer_llm_temperature_end" in kwargs + else None + ) + initial_instructions = kwargs["initial_instructions"] + multiple_choice_tasks = kwargs["multiple_choice_tasks"] + raw_data = kwargs["raw_data"] + call_scorer_server_func = kwargs["call_scorer_server_func"] + call_optimizer_server_func = kwargs["call_optimizer_server_func"] + instruction_pos = kwargs["instruction_pos"] + prediction_treat_as_number = kwargs["prediction_treat_as_number"] + prediction_treat_as_bool = kwargs["prediction_treat_as_bool"] + result_by_instruction_folder = kwargs["result_by_instruction_folder"] + few_shot_qa_pairs = kwargs["few_shot_qa_pairs"] + num_score_buckets = kwargs["num_score_buckets"] + max_num_instructions = kwargs["max_num_instructions"] + meta_prompt_type = kwargs["meta_prompt_type"] + meta_prompt_instructions_before_exemplars = 
kwargs[ + "meta_prompt_instructions_before_exemplars" + ] + few_shot_selection_criteria = kwargs["few_shot_selection_criteria"] + optimizer_llm_name = kwargs["optimizer_llm_name"] + num_generated_instructions_in_each_step = kwargs[ + "num_generated_instructions_in_each_step" + ] + evaluate_generated_ins_on_few_shot = kwargs[ + "evaluate_generated_ins_on_few_shot" + ] + num_few_shot_questions_for_instruction_refinement = kwargs[ + "num_few_shot_questions_for_instruction_refinement" + ] + evaluate_old_ins_on_few_shot = kwargs["evaluate_old_ins_on_few_shot"] + eval_interval = kwargs["eval_interval"] + save_folder = kwargs["save_folder"] + verbose = kwargs["verbose"] if "verbose" in kwargs else False + + # =================== assertions ===================== + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, gsm8k." + assert optimizer_llm_temperature_schedule in { + "constant", + "linear_increase", + }, "The temperature schedule should be constant or linear_increase." + + # =================== save configurations to json file ==================== + configs_dict = dict() + configs_dict["scorer_llm_dict"] = scorer_llm_dict + configs_dict["optimizer_llm_dict"] = optimizer_llm_dict + configs_dict["instruction_pos"] = instruction_pos + configs_dict["optimizer_llm_temperature"] = optimizer_llm_temperature + configs_dict["optimizer_llm_temperature_schedule"] = ( + optimizer_llm_temperature_schedule + ) + configs_dict["optimizer_llm_temperature_end"] = optimizer_llm_temperature_end + with open(os.path.join(save_folder, "configs_dict.json"), "w") as f: + json.dump(configs_dict, f, indent=4) + + num_servers = scorer_llm_dict["num_servers"] + batch_size = scorer_llm_dict["batch_size"] + generated_ins_on_few_shot_results_dict = dict() + old_ins_on_few_shot_results_dict = dict() + # evaluation results every a few steps + # format: [(i_step, instruction, detailed_results_df)] + eval_results = [] + # all generated instructions, format: [(instruction, score, step_index)] + # the instructions that were skipped have score NaN + old_instructions_and_scores_raw = [] + # the new instructions, format: [(instruction, score, step_index)] + old_instructions_and_scores = [] + meta_prompts = [] # format: [(meta_prompt, step_index)] + instruction_score_dict = dict() # the dictionary of {instruction: score} + # the dictionary of the few-shot QA indices in meta-prompt + # key: step index; value: the list of few-shot indices in that step + few_shot_index_list_by_step_dict = dict() + detailed_results_df_by_instruction_dict = dict() + wrong_questions_from_start_counter = collections.Counter() + # EVAL results + eval_detailed_results_df_dict = dict() # {instruction: detailed_results_df} + instruction_eval_score_dict = dict() # {instruction: eval_score} + old_instruction_md5_hashstrings_set = set() + + print(f"tasks_all: {tasks_all}") + print( + f"train_ratio: {train_ratio}, number of training points:" + f" {int(num_examples * train_ratio)}" + ) + print( + f"eval_ratio: {eval_ratio}, number of eval points: " + f"{int(num_examples * eval_ratio)}" + ) + print( + f"test_ratio: {test_ratio}, number of test points: " + f"{int(num_examples * test_ratio)}" + ) + print( + f"optimizer llm temperature: {optimizer_llm_temperature}, schedule:" + f" {optimizer_llm_temperature_schedule}" + ) + print( + f"generating {num_generated_instructions_in_each_step} instructions in" + f" each step, run for {num_search_steps} steps" + ) + print( + "discarding generated instructions with 
score less than:" + f" {old_instruction_score_threshold} (old_instruction_score_threshold)" + ) + print(f"num_score_buckets: {num_score_buckets}") + + if dataset_name == "mmlu": + is_multiple_choice = True + is_multiple_choice_eval = True + elif dataset_name in {"gsm8k"}: + is_multiple_choice = False + is_multiple_choice_eval = False + else: + assert dataset_name == "bbh" + is_multiple_choice = [] + is_multiple_choice_eval = [] + train_index_by_task_dict = dict() + eval_index_by_task_dict = dict() + start_index = 0 + for task_name in tasks_all: + single_task_list = eval_utils.load_bbh_task_data( + task_name, base_dir=root_data_folder_path + ) + end_index = start_index + len(single_task_list) + train_index_by_task_dict[task_name] = ( + train_index[(train_index >= start_index) & (train_index < end_index)] + # if " - start_index" is added here, then the dict would contain + # indices in the original task + ) + eval_index_by_task_dict[task_name] = ( + eval_index[(eval_index >= start_index) & (eval_index < end_index)] + # if " - start_index" is added here, then the dict would contain + # indices in the original task + ) + start_index = end_index + is_multiple_choice_single_task_train = [ + task_name in multiple_choice_tasks + ] * len(train_index_by_task_dict[task_name]) + is_multiple_choice_single_task_eval = [ + task_name in multiple_choice_tasks + ] * len(eval_index_by_task_dict[task_name]) + is_multiple_choice += is_multiple_choice_single_task_train + is_multiple_choice_eval += is_multiple_choice_single_task_eval + + prev_saved_instructions = set() + + # evaluate initial instructions + print("\n============== evaluating initial instructions ===============") + for instruction in initial_instructions: + print(f"""computing the score of "{instruction}" by prompting""") + + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=train_index, + batch_size=batch_size, + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=num_servers, + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=120, + sleep_time=60, + verbose=verbose, + ) + + detailed_results_df_by_instruction_dict[instruction] = detailed_results_df + scores = detailed_results_df["accuracy"] + average_score = np.average(scores) + print(f"instruction: {instruction}, score: {average_score}") + filename = eval_utils.instruction_to_filename(instruction) + file_path = os.path.join(result_by_instruction_folder, f"{filename}.csv") + detailed_results_df.to_csv(file_path, index=True, header=True) + print(f"""saving results of "{instruction}" to {file_path}""") + old_instructions_and_scores.append((instruction, average_score, -1)) + old_instructions_and_scores_raw.append((instruction, average_score, -1)) + instruction_score_dict[instruction] = average_score + + # increment the counter on wrong questions + wrong_question_indices_set = set( + list( + detailed_results_df.iloc[ + np.where(detailed_results_df.accuracy == 0.0)[0], : + ].index + ) + ) + for idx in wrong_question_indices_set: + wrong_questions_from_start_counter[idx] += 1 + + # evolution + for i_step in range(num_search_steps): + print(f"\n================== Step 
{i_step} =====================") + if not i_step % 10: + print(f"old_instructions_and_scores: {old_instructions_and_scores}") + + if optimizer_llm_temperature_schedule == "linear_increase": + optimizer_llm_temperature_curr = ( + optimizer_llm_temperature + + i_step + / num_search_steps + * (optimizer_llm_temperature_end - optimizer_llm_temperature) + ) + else: + optimizer_llm_temperature_curr = optimizer_llm_temperature + print( + f"current optimizer_llm_temperature: {optimizer_llm_temperature_curr}" + ) + + # generate new instructions + if few_shot_qa_pairs: + if few_shot_selection_criteria == "accumulative_most_frequent": + # select QA pairs that were done wrong the most number of times + most_frequent_wrong_question_indices = [ + k + for k, _ in sorted( + wrong_questions_from_start_counter.items(), key=lambda x: -x[1] + ) + ] + print( + "len(most_frequent_wrong_question_indices):" + f" {len(most_frequent_wrong_question_indices)}" + ) + if ( + len(most_frequent_wrong_question_indices) + <= num_few_shot_questions_for_instruction_refinement + ): + few_shot_index_list = most_frequent_wrong_question_indices.copy() + else: + np.random.seed(i_step) + few_shot_index_list = np.sort( + np.random.choice( + most_frequent_wrong_question_indices, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ) + + elif few_shot_selection_criteria == "current_most_frequent": + # show exemplars done wrong most often by currently shown instructions + old_instruction_score_threshold_single_step = ( + old_instruction_score_threshold if i_step > 0 else 0 + ) + _, old_instructions_and_scores_in_meta_prompt = ( + gen_ins_and_score_pairs_substr( + old_instructions_and_scores=old_instructions_and_scores, + old_instruction_score_threshold=old_instruction_score_threshold_single_step, + max_num_instructions=max_num_instructions, + return_str_only=False, + num_score_buckets=num_score_buckets, + ) + ) + wrong_questions_counter_single_step = collections.Counter() + for ins, _, _ in old_instructions_and_scores_in_meta_prompt: + filename = eval_utils.instruction_to_filename(ins) + file_path = os.path.join( + result_by_instruction_folder, f"{filename}.csv" + ) + single_ins_df = pd.read_csv(file_path, index_col=0, header=0) + wrong_question_indices_set_single_old_ins = set( + list( + single_ins_df.iloc[ + np.where(single_ins_df.accuracy == 0.0)[0], : + ].index + ) + ) + for idx in wrong_question_indices_set_single_old_ins: + wrong_questions_counter_single_step[idx] += 1 + most_occurred_wrong_questions = [ + k + for k, v in wrong_questions_counter_single_step.items() + if v == max(wrong_questions_counter_single_step.values()) + ] + if ( + len(most_occurred_wrong_questions) + < num_few_shot_questions_for_instruction_refinement + ): + # pylint: disable=cell-var-from-loop + idx_most_to_least = sorted( + wrong_questions_counter_single_step, + key=lambda x: -wrong_questions_counter_single_step[x], + ) + few_shot_index_list = idx_most_to_least[ + :num_few_shot_questions_for_instruction_refinement + ] + else: + few_shot_index_list = np.sort( + np.random.choice( + most_occurred_wrong_questions, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ) + elif few_shot_selection_criteria == "constant": + np.random.seed(0) + few_shot_index_list = np.sort( + np.random.choice( + train_index, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ) + else: + assert few_shot_selection_criteria == "random" + np.random.seed(i_step) + few_shot_index_list = np.sort( + np.random.choice( 
+ train_index, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ).tolist() + + few_shot_index_list_by_step_dict[i_step] = few_shot_index_list + + meta_prompt = gen_meta_prompt( + old_instructions_and_scores=old_instructions_and_scores, + instruction_pos=instruction_pos, + optimizer_llm_name=optimizer_llm_name, + old_instruction_score_threshold=old_instruction_score_threshold, + max_num_instructions=max_num_instructions, + meta_prompt_type=meta_prompt_type, + few_shot_qa_pairs=few_shot_qa_pairs, + include_qa=include_qa, + data=raw_data, + few_shot_index_list=few_shot_index_list, + instructions_before_exemplars=meta_prompt_instructions_before_exemplars, + num_score_buckets=num_score_buckets, + dataset_name=dataset_name, + task_name=task_name, + ) + + else: # no few-shot exemplars in meta-prompt + few_shot_index_list = [] + meta_prompt = gen_meta_prompt( + old_instructions_and_scores=old_instructions_and_scores, + instruction_pos=instruction_pos, + optimizer_llm_name=optimizer_llm_name, + old_instruction_score_threshold=old_instruction_score_threshold, + max_num_instructions=max_num_instructions, + meta_prompt_type=meta_prompt_type, + few_shot_qa_pairs=False, + include_qa=include_qa, + instructions_before_exemplars=meta_prompt_instructions_before_exemplars, + num_score_buckets=num_score_buckets, + dataset_name=dataset_name, + task_name=task_name, + ) + print(f"\nmeta_prompt: \n\n{meta_prompt}\n") + meta_prompts.append((meta_prompt, i_step)) + remaining_num_instructions_to_generate = ( + num_generated_instructions_in_each_step + ) + generated_instructions_raw = [] + while remaining_num_instructions_to_generate > 0: + optimizer_llm_input_text = meta_prompt + # generate instructions + print(f"current temperature: {optimizer_llm_temperature_curr}") + raw_outputs = call_optimizer_server_func( + optimizer_llm_input_text, + temperature=optimizer_llm_temperature_curr, + ) + + # Extract the generated instructions from the optimizer LLM output. Only + # keep some samples if the desired number of remaining instructions + # is smaller than the total number of decodes in this step. 
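+      # The parsing below assumes the optimizer model wraps each proposed
+      # instruction in delimiter tags: <Start>...</Start> when instruction_pos
+      # is "A_begin", and <INS>...</INS> otherwise. Illustrative example only:
+      #   raw_output = "A better instruction: <INS>Let's think step by step.</INS>"
+      # would yield the candidate instruction "Let's think step by step.".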
+      if meta_prompt_type == "both_instructions_and_exemplars":
+        raw_outputs = raw_outputs[:remaining_num_instructions_to_generate]
+        if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+          if instruction_pos == "A_begin":
+            start_string = "<Start>"
+            end_string = "</Start>"
+          else:
+            start_string = "<INS>"
+            end_string = "</INS>"
+          for raw_output in raw_outputs:
+            if start_string not in raw_output:
+              start_index = 0
+            else:
+              start_index = raw_output.index(start_string) + len(start_string)
+            if end_string not in raw_output:
+              end_index = len(raw_output)
+            else:
+              end_index = raw_output.index(end_string)
+            new_inst = raw_output[start_index:end_index].strip()
+            generated_instructions_raw.append(new_inst)
+        else:
+          assert optimizer_llm_name.lower() == "text-bison"
+          generated_instructions_raw += [
+              extract_string_in_square_brackets(string)
+              for string in raw_outputs
+          ]
+
+        remaining_num_instructions_to_generate -= optimizer_llm_dict[
+            "batch_size"
+        ]
+      else:
+        assert meta_prompt_type == "instructions_only"
+        max_num_instructions_to_keep_in_each_output = 1
+        for string in raw_outputs:
+          generated_instructions_raw += parse_tag_content(string)[
+              :max_num_instructions_to_keep_in_each_output
+          ]
+        remaining_num_instructions_to_generate -= (
+            optimizer_llm_dict["batch_size"]
+            * max_num_instructions_to_keep_in_each_output
+        )
+
+    generated_instructions_raw = list(
+        map(eval_utils.polish_sentence, generated_instructions_raw)
+    )
+    print(f"\ninitially generated instructions: {generated_instructions_raw}\n")
+
+    # do not evaluate old instructions again
+    generated_instructions = []  # the new instructions generated in this step
+    for ins in generated_instructions_raw:
+      ins_md5_hashstring = eval_utils.instruction_to_filename(
+          ins, md5_hashing=True
+      )
+      if ins_md5_hashstring not in old_instruction_md5_hashstrings_set:
+        generated_instructions.append(ins)
+        old_instruction_md5_hashstrings_set.add(ins_md5_hashstring)
+      else:
+        print(f"already evaluated '{ins}' previously")
+    generated_instructions = list(set(generated_instructions))
+
+    to_evaluate_instructions = []
+    for instruction in generated_instructions:
+      if len(instruction) > 500:
+        print(f"Step {i_step}, instruction: {instruction}, too long, skipped")
+        continue
+      if dataset_name == "gsm8k" and any(
+          char.isdigit() for char in instruction
+      ):
+        print(
+            f"Step {i_step}, instruction: {instruction}, contains numbers,"
+            " skipped"
+        )
+        continue
+      if "INS" in instruction:
+        print(
+            f"Step {i_step}, instruction: {instruction}, contains 'INS',"
+            " skipped"
+        )
+        continue
+      to_evaluate_instructions.append(instruction)
+    print(f"\nto-evaluate generated instructions: {to_evaluate_instructions}\n")
+
+    # evaluate new instructions on the few-shot exemplars in meta-prompt
+    if few_shot_qa_pairs and evaluate_generated_ins_on_few_shot:
+      print("evaluating GENERATED instructions on few-shot exemplars")
+      single_step_eval_on_few_shot = dict()
+      for instruction in to_evaluate_instructions:
+        if instruction not in prev_saved_instructions:
+          print(
+              f"evaluating Step {i_step}, instruction: {instruction} on"
+              " few-shot exemplars"
+          )
+          detailed_results_df = eval_utils.evaluate_single_instruction(
+              data=raw_data,
+              instruction=instruction,
+              eval_index_all=few_shot_index_list,
+              batch_size=batch_size,
+              call_server_func=call_scorer_server_func,
+              dataset_name=dataset_name,
+              num_servers=num_servers,
+              extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again,
+              include_qa=include_qa,
+              evaluate_in_parallel=evaluate_in_parallel,
instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + single_step_eval_on_few_shot[instruction] = detailed_results_df + + print( + f"Step {i_step}, single_step_eval_on_few_shot:" + f" {single_step_eval_on_few_shot}\n" + ) + generated_ins_on_few_shot_results_dict[i_step] = ( + single_step_eval_on_few_shot + ) + + # evaluate OLD instructions on the few-shot exemplars in meta-prompt + if few_shot_qa_pairs and evaluate_old_ins_on_few_shot: + print("evaluating OLD instructions on few-shot exemplars") + single_step_eval_on_few_shot = dict() + for instruction, _, _ in old_instructions_and_scores: + print( + f"evaluating Step {i_step}, instruction: {instruction} on few-shot" + " exemplars" + ) + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=few_shot_index_list, + batch_size=scorer_llm_dict["batch_size"], + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=scorer_llm_dict["num_servers"], + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + single_step_eval_on_few_shot[instruction] = detailed_results_df + + print( + f"Step {i_step}, single_step_eval_on_few_shot:" + f" {single_step_eval_on_few_shot}\n" + ) + old_ins_on_few_shot_results_dict[i_step] = single_step_eval_on_few_shot + + # evaluate newly generated instructions on the training set + for instruction in to_evaluate_instructions: + if instruction not in prev_saved_instructions: + print(f"""computing the score of "{instruction}" by prompting""") + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=train_index, + batch_size=batch_size, + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=num_servers, + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + prev_saved_instructions.add(instruction) + else: + # do not re-evaluate instructions that had been evaluated previously + detailed_results_df = pd.read_csv( + os.path.join(result_by_instruction_folder, f"{instruction}.csv"), + index_col=0, + header=0, + ) + print(f"""reading previously saved "{instruction}" information""") + + scores = detailed_results_df["accuracy"] + average_score = np.average(scores) + print( + f"Step {i_step}, instruction: {instruction}, score: {average_score}" + ) + + # increment the counter on wrong questions + wrong_question_indices_set = set( + list( + detailed_results_df[detailed_results_df["accuracy"] == 0.0].index + ) + ) + for idx in wrong_question_indices_set: + wrong_questions_from_start_counter[idx] += 1 + 
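+      # wrong_questions_from_start_counter accumulates, across all search steps,
+      # how many evaluated instructions answered each training question
+      # incorrectly; the "accumulative_most_frequent" few-shot selection above
+      # draws its exemplars from the questions with the highest counts.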
+ filename = eval_utils.instruction_to_filename(instruction) + file_path = os.path.join( + result_by_instruction_folder, f"""{filename}.csv""" + ) + detailed_results_df.to_csv(file_path, index=True, header=True) + print(f"saving results to {file_path}") + + detailed_results_df_by_instruction_dict[instruction] = detailed_results_df + old_instructions_and_scores.append((instruction, average_score, i_step)) + instruction_score_dict[instruction] = average_score + + # record all generated instructions + for instruction in generated_instructions_raw: + if instruction in instruction_score_dict: + average_score = instruction_score_dict[instruction] + else: + average_score = np.nan + old_instructions_and_scores_raw.append( + (instruction, average_score, i_step) + ) + + # =============================== eval ==================================== + # every eval_interval steps, evaluate the instructions that were generated + # in the current step and were not skipped + if not i_step % eval_interval: + for instruction in generated_instructions_raw: + # if the instruction wasn't skipped in any step + if instruction in instruction_score_dict: + if instruction not in instruction_eval_score_dict: + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=eval_index, + batch_size=batch_size, + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=num_servers, + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice_eval, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + eval_score = np.average(detailed_results_df["accuracy"]) + eval_detailed_results_df_dict[instruction] = detailed_results_df + instruction_eval_score_dict[instruction] = eval_score + else: + eval_score = instruction_eval_score_dict[instruction] + print( + f"EVAL: \nStep {i_step}, instruction: {instruction}, eval score:" + f" {eval_score:.2f}" + ) + eval_results.append((i_step, instruction, eval_score)) + + # ===================== save up-to-date results =========================== + results_dict = dict() + results_dict["meta_prompts"] = meta_prompts + results_dict["old_instructions_and_scores"] = list( + old_instructions_and_scores + ) + results_dict["old_instructions_and_scores_raw"] = list( + old_instructions_and_scores_raw + ) + results_dict["generated_ins_on_few_shot_results_dict"] = ( + generated_ins_on_few_shot_results_dict + ) + results_dict["old_ins_on_few_shot_results_dict"] = ( + old_ins_on_few_shot_results_dict + ) + results_dict["few_shot_index_list_by_step_dict"] = ( + few_shot_index_list_by_step_dict + ) + results_dict["eval_results"] = eval_results + results_dict["eval_detailed_results_df_dict"] = ( + eval_detailed_results_df_dict + ) + with open(os.path.join(save_folder, "results_dict.pkl"), "wb") as fp: + pickle.dump(results_dict, fp) + print(f"\nsaved all results to\n{save_folder}") diff --git a/optimization/optimize_instructions.py b/optimization/optimize_instructions.py new file mode 100644 index 0000000..e619ffb --- /dev/null +++ b/optimization/optimize_instructions.py @@ -0,0 +1,736 @@ + +import datetime +import functools +import os +import sys + +OPRO_ROOT_PATH = os.path.dirname( + 
os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai +from opro import prompt_utils +from opro.optimization import opt_utils +import pandas as pd +FLAGS = flags.FLAGS +ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data") + +flags.DEFINE_string("local_model_path", "", "Path to local vLLM model.") + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_SCORER = flags.DEFINE_string( + "scorer", "text-bison", "The name of the scorer LLM." +) + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + +_DATASET = flags.DEFINE_string( + "dataset", "gsm8k", "The name of dataset to search for instructions on." +) + +_TASK = flags.DEFINE_string( + "task", + "train", + "The name of task within the above dataset to search for instructions on.", +) + +_INSTRUCTION_POS = flags.DEFINE_string( + "instruction_pos", + "A_begin", + "The position of the instruction to search for.", +) + +_META_PROMPT_TYPE = flags.DEFINE_string( + "meta_prompt_type", + "both_instructions_and_exemplars", + "The type of meta-prompt: whether to have both previous instructions and" + " dataset exemplars (often for fine-tuned optimizers), or to have only" + " previous instructions (often for pre-trained optimizers).", +) + + +def main(_): + local_model_path = FLAGS.local_model_path + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + scorer_llm_name = _SCORER.value + optimizer_llm_name = _OPTIMIZER.value + dataset_name = _DATASET.value.lower() + task_name = _TASK.value + meta_prompt_type = _META_PROMPT_TYPE.value + + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k." 
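+  # Example invocation (paths and key values are placeholders; "local" uses the
+  # local vLLM wrapper configured via --local_model_path for both roles):
+  #   python optimization/optimize_instructions.py \
+  #     --scorer=local --optimizer=local --local_model_path=/path/to/model \
+  #     --dataset=gsm8k --task=train --instruction_pos=Q_begin \
+  #     --meta_prompt_type=both_instructions_and_exemplars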
+ if dataset_name == "mmlu": + assert task_name in { + "STEM", + "humanities", + "social sciences", + "other (business, health, misc.)", + } # for now only support searching on one MMLU category + elif dataset_name == "bbh": + assert task_name in { + "boolean_expressions", + "causal_judgement", + "date_understanding", + "disambiguation_qa", + "dyck_languages", + "formal_fallacies", + "geometric_shapes", + "hyperbaton", + "logical_deduction_five_objects", + "logical_deduction_seven_objects", + "logical_deduction_three_objects", + "movie_recommendation", + "multistep_arithmetic_two", + "navigate", + "object_counting", + "penguins_in_a_table", + "reasoning_about_colored_objects", + "ruin_names", + "salient_translation_error_detection", + "snarks", + "sports_understanding", + "temporal_sequences", + "tracking_shuffled_objects_five_objects", + "tracking_shuffled_objects_seven_objects", + "tracking_shuffled_objects_three_objects", + "web_of_lies", + "word_sorting", + } + else: + assert dataset_name == "gsm8k" + assert task_name in {"train", "test"} + + assert scorer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + "local", + } + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + "local", + } + assert meta_prompt_type in { + "both_instructions_and_exemplars", + "instructions_only", + } + + instruction_pos = _INSTRUCTION_POS.value + assert instruction_pos in { + "before_Q", + "Q_begin", + "Q_end", + "A_begin", + }, ( + "The instruction position should be either before the question, or at the" + " beginning of the question, at the end of the question, or at the" + " beginning of the answer." + ) + print( + f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:" + f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}" + ) + + + if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + elif scorer_llm_name == "text-bison": + assert palm_api_key, "A PaLM API key is needed when prompting the text-bison model." + palm.configure(api_key=palm_api_key) + elif scorer_llm_name == "local": + # 本地模型,无需 API key + pass + else: + raise ValueError(f"Unknown scorer model: {scorer_llm_name}") + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + elif optimizer_llm_name == "text-bison": + assert palm_api_key, "A PaLM API key is needed when prompting the text-bison model." 
+ palm.configure(api_key=palm_api_key) + elif optimizer_llm_name == "local": + # 本地模型,无需 API key + pass + else: + raise ValueError(f"Unknown scorer model: {optimizer_llm_name}") + + + if dataset_name == "mmlu": + root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data") + elif dataset_name == "bbh": + root_data_folder_path = os.path.join( + ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/" + ) + else: + assert dataset_name == "gsm8k" + root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data") + + # =================== create the result directory ========================== + datetime_str = ( + str(datetime.datetime.now().replace(microsecond=0)) + .replace(" ", "-") + .replace(":", "-") + ) + + + save_folder = os.path.join( + OPRO_ROOT_PATH, + "outputs", + "optimization-results", + f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/", +) + + + result_by_instruction_folder = os.path.join( + save_folder, "result_by_instruction" + ) + print(f"Results will be saved to: {os.path.abspath(result_by_instruction_folder)}") + os.makedirs(result_by_instruction_folder,exist_ok=True) + print(f"result directory:\n{save_folder}") + + # ====================== scorer model configs ============================== + + + if scorer_llm_name == "text-bison": + # when prompting text-bison with Cloud API + scorer_finetuned_palm_temperature = 0.0 + scorer_finetuned_palm_max_decode_steps = 1024 + scorer_finetuned_palm_batch_size = 1 + scorer_finetuned_palm_num_servers = 1 + scorer_finetuned_palm_dict = dict() + scorer_finetuned_palm_dict["temperature"] = ( + scorer_finetuned_palm_temperature + ) + scorer_finetuned_palm_dict["num_servers"] = ( + scorer_finetuned_palm_num_servers + ) + scorer_finetuned_palm_dict["batch_size"] = scorer_finetuned_palm_batch_size + scorer_finetuned_palm_dict["max_decode_steps"] = ( + scorer_finetuned_palm_max_decode_steps + ) + + call_scorer_finetuned_palm_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + model="text-bison-001", + temperature=scorer_finetuned_palm_dict["temperature"], + max_decode_steps=scorer_finetuned_palm_dict["max_decode_steps"], + ) + + scorer_llm_dict = { + "model_type": scorer_llm_name.lower(), + } + scorer_llm_dict.update(scorer_finetuned_palm_dict) + call_scorer_server_func = call_scorer_finetuned_palm_server_func + + + elif scorer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + # 改成调用本地vLLM版本的函数 + scorer_gpt_max_decode_steps = 1024 + # scorer_gpt_max_decode_steps = 512 + scorer_gpt_temperature = 0.0 + + scorer_llm_dict = { + "model_type": scorer_llm_name.lower(), + "max_decode_steps": scorer_gpt_max_decode_steps, + "temperature": scorer_gpt_temperature, + "num_decodes": 1, + "batch_size": 1, + "num_servers": 1, + } + + call_scorer_server_func = functools.partial( + prompt_utils.call_openai_server_func, # 你本地实现的vLLM调用函数 + max_decode_steps=scorer_gpt_max_decode_steps, + temperature=scorer_gpt_temperature, + local_model_path=FLAGS.local_model_path, # 传入你本地模型路径 + ) + + else: + raise ValueError(f"Unsupported scorer_llm_name: {scorer_llm_name}") + + + # ====================== optimizer model configs ============================ + if optimizer_llm_name.lower() == "text-bison": + # when prompting text-bison with Cloud API + optimizer_finetuned_palm_temperature = 1.0 + optimizer_finetuned_palm_num_decodes = 8 + optimizer_finetuned_palm_max_decode_steps = 1024 + optimizer_finetuned_palm_batch_size = 1 + optimizer_finetuned_palm_num_servers = 1 + 
optimizer_finetuned_palm_dict = dict() + optimizer_finetuned_palm_dict["temperature"] = ( + optimizer_finetuned_palm_temperature + ) + optimizer_finetuned_palm_dict["num_decodes"] = ( + optimizer_finetuned_palm_num_decodes + ) + optimizer_finetuned_palm_dict["batch_size"] = ( + optimizer_finetuned_palm_batch_size + ) + optimizer_finetuned_palm_dict["num_servers"] = ( + optimizer_finetuned_palm_num_servers + ) + optimizer_finetuned_palm_dict["max_decode_steps"] = ( + optimizer_finetuned_palm_max_decode_steps + ) + + call_optimizer_finetuned_palm_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + model="text-bison-001", + temperature=optimizer_finetuned_palm_dict["temperature"], + max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], + ) + + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + } + optimizer_llm_dict.update(optimizer_finetuned_palm_dict) + call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + + elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + # 用本地 vLLM 版本替代调用 + optimizer_gpt_max_decode_steps = 512 + + optimizer_gpt_temperature = 1.0 + + optimizer_llm_dict = { + "max_decode_steps": optimizer_gpt_max_decode_steps, + "temperature": optimizer_gpt_temperature, + "batch_size": 1, + "num_decodes": 1, + } + + call_optimizer_server_func = functools.partial( + prompt_utils.call_openai_server_func, # 你写的本地vLLM调用接口 + max_decode_steps=optimizer_gpt_max_decode_steps, + temperature=optimizer_gpt_temperature, + local_model_path=FLAGS.local_model_path, + ) + + else: + raise ValueError(f"Unsupported optimizer_llm_name: {optimizer_llm_name}") + + + + # ====================== try calling the servers ============================ + print("\n======== testing the scorer and optimizer servers ===========") + scorer_test_output = call_scorer_server_func( + "Does the sun rise from the north? Just answer yes or no." + ) + print(f"number of scorer output decodes: {len(scorer_test_output)}") + print(f"scorer test output: {scorer_test_output}") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? 
Just answer yes or no.", + temperature=1.0, + ) + print(f"number of optimizer output decodes: {len(optimizer_test_output)}") + print(f"optimizer test output: {optimizer_test_output}") + print("Finished testing the servers.") + + # ====================== read data ============================ + print("\n================ prompt optimization settings ==============") + # from https://github.com/hendrycks/test/blob/master/categories.py + subcategories = { + "abstract_algebra": ["math"], + "anatomy": ["health"], + "astronomy": ["physics"], + "business_ethics": ["business"], + "clinical_knowledge": ["health"], + "college_biology": ["biology"], + "college_chemistry": ["chemistry"], + "college_computer_science": ["computer science"], + "college_mathematics": ["math"], + "college_medicine": ["health"], + "college_physics": ["physics"], + "computer_security": ["computer science"], + "conceptual_physics": ["physics"], + "econometrics": ["economics"], + "electrical_engineering": ["engineering"], + "elementary_mathematics": ["math"], + "formal_logic": ["philosophy"], + "global_facts": ["other"], + "high_school_biology": ["biology"], + "high_school_chemistry": ["chemistry"], + "high_school_computer_science": ["computer science"], + "high_school_european_history": ["history"], + "high_school_geography": ["geography"], + "high_school_government_and_politics": ["politics"], + "high_school_macroeconomics": ["economics"], + "high_school_mathematics": ["math"], + "high_school_microeconomics": ["economics"], + "high_school_physics": ["physics"], + "high_school_psychology": ["psychology"], + "high_school_statistics": ["math"], + "high_school_us_history": ["history"], + "high_school_world_history": ["history"], + "human_aging": ["health"], + "human_sexuality": ["culture"], + "international_law": ["law"], + "jurisprudence": ["law"], + "logical_fallacies": ["philosophy"], + "machine_learning": ["computer science"], + "management": ["business"], + "marketing": ["business"], + "medical_genetics": ["health"], + "miscellaneous": ["other"], + "moral_disputes": ["philosophy"], + "moral_scenarios": ["philosophy"], + "nutrition": ["health"], + "philosophy": ["philosophy"], + "prehistory": ["history"], + "professional_accounting": ["other"], + "professional_law": ["law"], + "professional_medicine": ["health"], + "professional_psychology": ["psychology"], + "public_relations": ["politics"], + "security_studies": ["politics"], + "sociology": ["culture"], + "us_foreign_policy": ["politics"], + "virology": ["health"], + "world_religions": ["philosophy"], + } + + categories = { + "STEM": [ + "physics", + "chemistry", + "biology", + "computer science", + "math", + "engineering", + ], + "humanities": ["history", "philosophy", "law"], + "social sciences": [ + "politics", + "culture", + "economics", + "geography", + "psychology", + ], + "other (business, health, misc.)": ["other", "business", "health"], + } + + if dataset_name == "mmlu": + + category_names = [task_name] + folder_name = "test" # one of {'auxiliary_train', 'dev', 'val', 'test'} + task_names = [] + for task_csv_name in os.listdir( + os.path.join(root_data_folder_path, folder_name) + ): + task_names.append(task_csv_name.split(".")[0]) + + tasks_in_category = [] + for category_name in category_names: + for task_name in task_names: + for subname in subcategories: + if subname in task_name: + if subcategories[subname][0] in categories[category_name]: + tasks_in_category.append(task_name) + break + + tasks_all = [(folder_name, task_name) for task_name in tasks_in_category] 
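+    # For example, with task_name == "STEM" a subject file such as
+    # "college_physics.csv" is kept: subcategory "college_physics" maps to
+    # "physics", which is listed under the "STEM" category above.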
+ multiple_choice_tasks = set([item[1] for item in tasks_all]) + boolean_tasks = set() + numerical_output_tasks = set() + + + elif dataset_name == "bbh": + tasks_all = [task_name] + assert ( + len(tasks_all) == 1 + ), "for now only support prompt optimization on one BBH task" + + + numerical_output_tasks = { + "object_counting", + "multistep_arithmetic_two", + } + + multiple_choice_tasks = { + "date_understanding", + "disambiguation_qa", + "geometric_shapes", + "hyperbaton", + "logical_deduction_five_objects", + "logical_deduction_seven_objects", + "logical_deduction_three_objects", + "movie_recommendation", + "penguins_in_a_table", + "reasoning_about_colored_objects", + "ruin_names", + "salient_translation_error_detection", + "snarks", + "temporal_sequences", + "tracking_shuffled_objects_five_objects", + "tracking_shuffled_objects_seven_objects", + "tracking_shuffled_objects_three_objects", + } + + boolean_tasks = { + "boolean_expressions", # True or False + "causal_judgement", # yes or no + "formal_fallacies", # valid or invalid + "navigate", # yes or no + "sports_understanding", # yes or no + "web_of_lies", # yes or no + } + + else: + assert dataset_name in {"gsm8k"} + tasks_all = [task_name] + multiple_choice_tasks = set() + boolean_tasks = set() + numerical_output_tasks = set(tasks_all) + + if dataset_name == "mmlu": + raw_data = pd.DataFrame() + prediction_treat_as_number = False + prediction_treat_as_bool = False + elif dataset_name == "bbh": + raw_data = [] + prediction_treat_as_number = bool( + tasks_all[0] in numerical_output_tasks + ) # for now only check the first task + prediction_treat_as_bool = bool( + tasks_all[0] in boolean_tasks + ) # for now only check the first task + print( + f"prediction_treat_as_number: {prediction_treat_as_number}," + f" prediction_treat_as_bool: {prediction_treat_as_bool}" + ) + else: + assert dataset_name == "gsm8k" + raw_data = pd.DataFrame() + prediction_treat_as_number = True + prediction_treat_as_bool = False + + for t in tasks_all: + if dataset_name == "mmlu": + folder_name = t[0] + task_name = t[1] + single_task_df = pd.read_csv( + os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"), + index_col=None, + header=None, + ) + raw_data = pd.concat([raw_data, single_task_df]) + elif dataset_name == "bbh": + task_name = t + single_task_list = opt_utils.load_bbh_task_data( + task_name, base_dir=root_data_folder_path + ) + raw_data += single_task_list + else: + assert dataset_name == "gsm8k" + task_name = t + f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv") + single_task_df = pd.read_csv(f_gsm, sep="\t", header=None) + raw_data = pd.concat([raw_data, single_task_df]) + + if dataset_name == "mmlu": + num_examples = raw_data.shape[0] + elif dataset_name == "bbh": + num_examples = len(raw_data) + else: + assert dataset_name in {"gsm8k"} + num_examples = raw_data.shape[0] + print(f"number of examples in the current task: {num_examples}") + + # ================ split data into train/val/test ========================== + if dataset_name == "mmlu": + train_ratio = 0.8 + + eval_ratio = 0.2 + elif dataset_name == "gsm8k": +# train_ratio = 0.035 + train_ratio = 0.01 # 原来是 0.035,改成 0.01,约 74 条 + eval_ratio = 0 + else: + assert dataset_name == "bbh" + train_ratio = 0.2 + eval_ratio = 0 + + + assert train_ratio + eval_ratio <= 1 + test_ratio = 1 - train_ratio - eval_ratio + print( + f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, " + f"test_ratio: {test_ratio}" + ) + np.random.seed(0) + train_index = np.sort( + 
np.array( + np.random.choice( + num_examples, size=int(train_ratio * num_examples), replace=False + ) + ) + ) + eval_and_test_index = np.sort( + np.array(list(set(np.arange(num_examples)) - set(train_index))) + ) + eval_index = np.sort( + np.array( + np.random.choice( + eval_and_test_index, + size=int(eval_ratio * num_examples), + replace=False, + ) + ) + ) + + # ========== set other optimization experiment hyperparameters ============== + if scorer_llm_name == "text-bison": + old_instruction_score_threshold = 0.0 + # old_instruction_score_threshold = 0.15 # for GSM8K + else: + assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"} + old_instruction_score_threshold = 0.3 + + if scorer_llm_name == "text-bison": + extract_final_answer_by_prompting_again = False + include_qa = False + evaluate_in_parallel = False + else: + assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"} + extract_final_answer_by_prompting_again = False + include_qa = False + evaluate_in_parallel = False + + optimizer_llm_temperature = optimizer_llm_dict["temperature"] + + +# num_few_shot_questions_for_instruction_refinement = 3 + num_few_shot_questions_for_instruction_refinement = 1 # 减少 few-shot 例子数 + +# num_generated_instructions_in_each_step = 8 + num_generated_instructions_in_each_step = 2 # 每步只生成 2 条指令 + +# num_search_steps = 200 + num_search_steps = 3 # 原来是 200,改成 3 步即可 + + + initial_instructions = [ + "Let's solve the problem.", + # "", + # "The answer is", + ] + few_shot_qa_pairs = True + # one of {'accumulative_most_frequent', 'current_most_frequent', 'random', + # 'constant'} + few_shot_selection_criteria = "random" + # whether to evaluate generated instructions on the exemplars in meta-prompt + evaluate_generated_ins_on_few_shot = False + # whether to evaluate old instructions on the exemplars in the meta-prompt + evaluate_old_ins_on_few_shot = False + # every this number of steps, compute the accuracies of current-step + # instructions on the validation set +# eval_interval = 3 + eval_interval = 1 # 每步就 eval 一次,及时看到结果 + # eval_interval = 10 + max_num_instructions = ( + 20 # the maximum number of instructions and scores in the meta-prompt + ) + # The number of buckets when converting scores to integers in the meta-prompt. 
+ num_score_buckets = 100 + # whether to put old instructions and scores to before exemplars in + # the meta-prompt + meta_prompt_instructions_before_exemplars = True + + # ===================== run prompt optimization ====================== + + assert few_shot_selection_criteria in { + "accumulative_most_frequent", + "current_most_frequent", + "random", + "constant", + } + evolution_kwargs = { + "num_search_steps": num_search_steps, + "old_instruction_score_threshold": old_instruction_score_threshold, + "scorer_llm_dict": scorer_llm_dict, + "optimizer_llm_dict": optimizer_llm_dict, + "extract_final_answer_by_prompting_again": ( + extract_final_answer_by_prompting_again + ), + "include_qa": include_qa, + "evaluate_in_parallel": evaluate_in_parallel, + "tasks_all": tasks_all, + "train_ratio": train_ratio, + "eval_ratio": eval_ratio, + "test_ratio": test_ratio, + "train_index": train_index, + "eval_index": eval_index, + "dataset_name": dataset_name, + "task_name": task_name, + "num_examples": num_examples, + "root_data_folder_path": root_data_folder_path, + "optimizer_llm_temperature": optimizer_llm_temperature, + # "optimizer_llm_temperature_schedule": ( + # optimizer_llm_temperature_schedule + # ), + # "optimizer_llm_temperature_end": optimizer_llm_temperature_end, + "initial_instructions": initial_instructions, + "multiple_choice_tasks": multiple_choice_tasks, + "raw_data": raw_data, + "call_scorer_server_func": call_scorer_server_func, + "call_optimizer_server_func": call_optimizer_server_func, + "instruction_pos": instruction_pos, + "prediction_treat_as_number": prediction_treat_as_number, + "prediction_treat_as_bool": prediction_treat_as_bool, + "result_by_instruction_folder": result_by_instruction_folder, + "few_shot_qa_pairs": few_shot_qa_pairs, + "num_score_buckets": num_score_buckets, + "max_num_instructions": max_num_instructions, + "meta_prompt_type": meta_prompt_type, + "meta_prompt_instructions_before_exemplars": ( + meta_prompt_instructions_before_exemplars + ), + "few_shot_selection_criteria": few_shot_selection_criteria, + "optimizer_llm_name": optimizer_llm_name, + "num_generated_instructions_in_each_step": ( + num_generated_instructions_in_each_step + ), + "evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot, + "num_few_shot_questions_for_instruction_refinement": ( + num_few_shot_questions_for_instruction_refinement + ), + "evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot, + "eval_interval": eval_interval, + "save_folder": save_folder, + } + print("=== 开始优化过程 ===") + try: + opt_utils.run_evolution(**evolution_kwargs) + print("=== 优化完成 ===") + except Exception as e: + import traceback + print(f"!!! 优化失败: {e} !!!", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + app.run(main) + diff --git a/optimization/optimize_linear_regression.py b/optimization/optimize_linear_regression.py new file mode 100644 index 0000000..f856434 --- /dev/null +++ b/optimization/optimize_linear_regression.py @@ -0,0 +1,424 @@ +# Copyright 2023 The OPRO Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +r"""Optimize over the objective function of a linear regression problem. + +Usage: + +``` +python optimize_linear_regression.py --optimizer="text-bison" +``` + +Note: +- When using a Google-Cloud-served model (like text-bison at +https://developers.generativeai.google/tutorials/text_quickstart), add +`--palm_api_key=""` +- When using an OpenAI model, add `--openai_api_key=""` +""" + +import datetime +import functools +import json +import os +import re +import sys + +OPRO_ROOT_PATH = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai + +from opro import prompt_utils + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + + +def main(_): + # ============== set optimization experiment configurations ================ + num_points = 50 # number of points in linear regression + w_true = 15 # the true w + b_true = 14 # the true b + max_num_steps = 500 # the number of optimization steps + num_reps = 5 # the number of repeated runs + max_num_pairs = 20 # the maximum number of input-output pairs in meta-prompt + num_input_decimals = 0 # num of decimals for input values in meta-prompt + num_output_decimals = 0 # num of decimals for output values in meta-prompt + num_generated_points_in_each_step = 8 + + # ================ load LLM settings =================== + optimizer_llm_name = _OPTIMIZER.value + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + } + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + else: + assert optimizer_llm_name == "text-bison" + assert ( + palm_api_key + ), "A PaLM API key is needed when prompting the text-bison model." 
+ palm.configure(api_key=palm_api_key) + + # =================== create the result directory ========================== + datetime_str = ( + str(datetime.datetime.now().replace(microsecond=0)) + .replace(" ", "-") + .replace(":", "-") + ) + + save_folder = os.path.join( + OPRO_ROOT_PATH, + "outputs", + "optimization-results", + f"linear_regression-o-{optimizer_llm_name}-{datetime_str}/", + ) + os.makedirs(save_folder) + print(f"result directory:\n{save_folder}") + + # ====================== optimizer model configs ============================ + if optimizer_llm_name.lower() == "text-bison": + # when prompting text-bison with Cloud API + optimizer_finetuned_palm_temperature = 1.0 + optimizer_finetuned_palm_max_decode_steps = 1024 + optimizer_finetuned_palm_batch_size = 1 + optimizer_finetuned_palm_num_servers = 1 + optimizer_finetuned_palm_dict = dict() + optimizer_finetuned_palm_dict["temperature"] = ( + optimizer_finetuned_palm_temperature + ) + optimizer_finetuned_palm_dict["batch_size"] = ( + optimizer_finetuned_palm_batch_size + ) + optimizer_finetuned_palm_dict["num_servers"] = ( + optimizer_finetuned_palm_num_servers + ) + optimizer_finetuned_palm_dict["max_decode_steps"] = ( + optimizer_finetuned_palm_max_decode_steps + ) + + call_optimizer_finetuned_palm_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + # prompt_utils.call_vllm, + model="text-bison-001", + temperature=optimizer_finetuned_palm_dict["temperature"], + max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], + ) + + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + } + optimizer_llm_dict.update(optimizer_finetuned_palm_dict) + call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + + else: + assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"} + optimizer_gpt_max_decode_steps = 1024 + optimizer_gpt_temperature = 1.0 + + optimizer_llm_dict = dict() + optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps + optimizer_llm_dict["temperature"] = optimizer_gpt_temperature + optimizer_llm_dict["batch_size"] = 1 + call_optimizer_server_func = functools.partial( + prompt_utils.call_openai_server_func, + model=optimizer_llm_name, + max_decode_steps=optimizer_gpt_max_decode_steps, + temperature=optimizer_gpt_temperature, + ) + + # ====================== try calling the servers ============================ + print("\n======== testing the optimizer server ===========") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? Just answer yes or no.", + temperature=1.0, + ) + print(f"optimizer test output: {optimizer_test_output}") + print("Finished testing the optimizer server.") + print("\n=================================================") + + # ====================== utility functions ============================ + def evaluate_loss(X, y, w, b): # pylint: disable=invalid-name + residual = y - (X * w + b) + return np.linalg.norm(residual) ** 2 + + def gen_meta_prompt( + old_value_pairs_set, + X, # pylint: disable=invalid-name, unused-argument + y, # pylint: disable=unused-argument + num_input_decimals=5, + num_output_decimals=5, + max_num_pairs=100, + ): + """Generate the meta-prompt for optimization. + + Args: + old_value_pairs_set (set): the set of old (w, b, z) pairs. + X (np.array): the 1D array of x values. + y (np.array): the 1D array of y values. + num_input_decimals (int): the number of decimals for (w, b) in the + meta-prompt. 
+ num_output_decimals (int): the number of decimals for z in the meta-prompt. + max_num_pairs (int): the maximum number of exemplars in the meta-prompt. + + Returns: + meta_prompt (str): the generated meta-prompt. + """ + old_value_pairs_set = set( + [ # pylint: disable=g-complex-comprehension + ( + np.round(w, num_input_decimals) + if num_input_decimals > 0 + else int(w), + np.round(b, num_input_decimals) + if num_input_decimals > 0 + else int(b), + np.round(z, num_output_decimals) + if num_output_decimals > 0 + else int(z), + ) + for w, b, z in old_value_pairs_set + ] + ) + old_value_pairs = list(old_value_pairs_set) + old_value_pairs = sorted(old_value_pairs, key=lambda x: -x[2])[ + -max_num_pairs: + ] + old_value_pairs_substr = "" + for w, b, z in old_value_pairs: + old_value_pairs_substr += f"\ninput:\nw={w}, b={b}\nvalue:\n{z}\n" + meta_prompt = """ + Now you will help me minimize a function with two input variables w, b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where lower values are better. + """.strip() + meta_prompt += "\n\n" + meta_prompt += old_value_pairs_substr.strip() + meta_prompt += "\n\n" + # function_analytic_form = "" + # for xi, yi in zip(X, y): + # function_analytic_form += f"({yi:.4f} - ({xi:.4f} * w + b)) ** 2 + " + # function_analytic_form = function_analytic_form[:-3] + # meta_prompt += ( + # "The function has the analytic form f(w, b) =" + # f" {function_analytic_form}. When evaluating the value of a (w, b)" + # " pair, you should replace the w and b in the analytic form with your" + # " values and do the computation." + # ) + # meta_prompt += "\n\n" + meta_prompt += """Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values. + """.strip() + return meta_prompt + + def extract_string_in_square_brackets(input_string): + raw_result = re.findall(r"\[.*?\]", input_string) + if raw_result: + for pair in raw_result[::-1]: + if "=" not in pair and ("w" in pair or "b" in pair): + continue + return pair[1:-1] + return "" + else: + return "" + + def parse_output(extracted_output): + """Parse the extracted output 'w, b' string to np.array([w, b]). + + Args: + extracted_output (str): the extracted output string, like '1.5, 2.5'. + + Returns: + parsed_output (np.array): the parsed output in a numpy array, like [1.5, + 2.5]. 
+ """ + if not extracted_output: + return + extracted_values = [] + for item in extracted_output.split(","): + if "=" in item: + item = item[item.index("=") + 1 :] + extracted_values.append(item.strip()) + parsed_output = np.array(extracted_values).astype(float) + return parsed_output + + configs_dict = dict() + results_dict = dict() + num_convergence_steps = [] + for i_rep in range(num_reps): + found_optimal = False + print(f"\nRep {i_rep}:") + + # ================= generate the ground truth X, y ===================== + X = np.arange(num_points).astype(float) + 1 # pylint: disable=invalid-name + np.random.seed(i_rep + 1) + y = X * w_true + b_true + np.random.randn(num_points) + loss_at_true_values = evaluate_loss(X, y, w_true, b_true) + print(f"value at (w_true, b_true): {loss_at_true_values}") + + # ================= generate the starting points ===================== + num_starting_points = 5 # the number of initial points for optimization + np.random.seed((i_rep + 1) * 10) + init_w = np.random.uniform(low=10, high=20, size=num_starting_points) + np.random.seed((i_rep + 1) * 100) + init_b = np.random.uniform(low=10, high=20, size=num_starting_points) + + # ====================== run optimization ============================ + configs_dict_single_rep = { + "optimizer_llm_configs": optimizer_llm_dict, + "data": { + "num_points": num_points, + "w_true": w_true, + "b_true": b_true, + "loss_at_true_values": loss_at_true_values, + "X": list(X), + "y": list(y), + }, + "init_w": list(init_w), + "init_b": list(init_b), + "max_num_steps": max_num_steps, + "max_num_pairs": max_num_pairs, + "num_input_decimals": num_input_decimals, + "num_output_decimals": num_output_decimals, + "num_generated_points_in_each_step": num_generated_points_in_each_step, + } + configs_dict[i_rep] = configs_dict_single_rep + configs_json_path = os.path.join(save_folder, "configs.json") + print(f"saving configs to\n{configs_json_path}") + with open(configs_json_path, "w") as f: + json.dump(configs_dict, f, indent=4) + + old_value_pairs_set = set() + old_value_pairs_with_i_step = [] # format: [(w, b, z = f(w, b), i_step)] + meta_prompts_dict = dict() # format: {i_step: meta_prompt} + raw_outputs_dict = dict() # format: {i_step: raw_outputs} + + rounded_inits = [ + (np.round(w, num_input_decimals), np.round(b, num_input_decimals)) + for w, b in zip(init_w, init_b) + ] + rounded_inits = [ + tuple(item) for item in list(np.unique(rounded_inits, axis=0)) + ] + for w, b in rounded_inits: + z = evaluate_loss(X, y, w, b) + old_value_pairs_set.add((w, b, z)) + old_value_pairs_with_i_step.append((w, b, z, -1)) + + print("\n================ run optimization ==============") + print( + f"initial points: {[tuple(item[:2]) for item in old_value_pairs_set]}" + ) + print(f"initial values: {[item[-1] for item in old_value_pairs_set]}") + results_json_path = os.path.join(save_folder, "results.json") + print(f"saving results to\n{results_json_path}") + + for i_step in range(max_num_steps): + print(f"\nStep {i_step}:") + meta_prompt = gen_meta_prompt( + old_value_pairs_set, + X, + y, + num_input_decimals=num_input_decimals, + num_output_decimals=num_output_decimals, + max_num_pairs=max_num_pairs, + ) + if not i_step % 5: + print("\n=================================================") + print(f"meta_prompt:\n{meta_prompt}") + meta_prompts_dict[i_step] = meta_prompt + + # generate a maximum of the given number of points in each step + remaining_num_points_to_generate = num_generated_points_in_each_step + raw_outputs = [] + while 
remaining_num_points_to_generate > 0: + raw_outputs += call_optimizer_server_func(meta_prompt) + remaining_num_points_to_generate -= optimizer_llm_dict["batch_size"] + raw_outputs = raw_outputs[:num_generated_points_in_each_step] + + raw_outputs_dict[i_step] = raw_outputs + parsed_outputs = [] + for string in raw_outputs: + if not i_step % 5: + print("\n=================================================") + print("raw output:\n", string) + print("\n=================================================") + try: + parsed_output = parse_output( + extract_string_in_square_brackets(string) + ) + if parsed_output is not None and len(parsed_output) == 2: + parsed_outputs.append(parsed_output) + except ValueError: + pass + parsed_outputs = [tuple(item) for item in parsed_outputs] + print(f"proposed points before rounding: {parsed_outputs}") + + # round the proposed points to the number of decimals in meta-prompt + rounded_outputs = [ + (np.round(w, num_input_decimals), np.round(b, num_input_decimals)) + for w, b in parsed_outputs + ] + rounded_outputs = [ + tuple(item) for item in list(np.unique(rounded_outputs, axis=0)) + ] + print(f"proposed points after rounding: {rounded_outputs}") + + # evaluate the values of proposed and rounded outputs + single_step_values = [] + for w, b in rounded_outputs: + if w == w_true and b == b_true: + found_optimal = True + z = evaluate_loss(X, y, w, b) + single_step_values.append(z) + old_value_pairs_set.add((w, b, z)) + old_value_pairs_with_i_step.append((w, b, z, i_step)) + print(f"single_step_values: {single_step_values}") + + # ====================== save results ============================ + results_dict_single_rep = { + "meta_prompts": meta_prompts_dict, + "raw_outputs": raw_outputs_dict, + "old_value_pairs_with_i_step": old_value_pairs_with_i_step, + } + results_dict[i_rep] = results_dict_single_rep + with open(results_json_path, "w") as f: + json.dump(results_dict, f, indent=4) + if found_optimal: + print( + f"Repetition {i_rep+1}, optimal found at Step {i_step+1}, saving" + f" final results to\n{save_folder}" + ) + num_convergence_steps.append(i_step + 1) + break + print(f"num_convergence_steps: {num_convergence_steps}") + + +if __name__ == "__main__": + app.run(main) diff --git a/optimization/optimize_tsp.py b/optimization/optimize_tsp.py new file mode 100644 index 0000000..a8f99d9 --- /dev/null +++ b/optimization/optimize_tsp.py @@ -0,0 +1,430 @@ +# Copyright 2024 The OPRO Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Optimize over the objective function of a traveling salesman problem. 
+ +Usage: + +``` +python optimize_tsp.py --optimizer="text-bison" +``` + +Note: +- When using a Google-Cloud-served model (like text-bison at +https://developers.generativeai.google/tutorials/text_quickstart), add +`--palm_api_key=""` +- When using an OpenAI model, add `--openai_api_key=""` +""" + +import datetime +import functools +import getpass +import json +import os +import re +import sys +import itertools + +OPRO_ROOT_PATH = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai + +from opro import prompt_utils + + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + +_START_ALGORITHM = flags.DEFINE_string( + "starting_algorithm", "farthest_insertion", "The name of the starting algorithm. Select from [dp, nearest_neighbor, farthest_insertion]" +) + +def main(_): + # ============== set optimization experiment configurations ================ + num_points = 100 # number of points in TSP + num_steps = 500 # the number of optimization steps + max_num_pairs = 10 # the maximum number of input-output pairs in meta-prompt + num_decimals = 0 # num of decimals for distances in meta-prompt + num_starting_points = 5 # the number of initial points for optimization + num_decode_per_step = 8 # the number of decoded solutions per step + + # ================ load LLM settings =================== + optimizer_llm_name = _OPTIMIZER.value + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + } + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + else: + assert optimizer_llm_name == "text-bison" + assert ( + palm_api_key + ), "A PaLM API key is needed when prompting the text-bison model." 
+ palm.configure(api_key=palm_api_key) + + # =================== create the result directory ========================== + datetime_str = ( + str(datetime.datetime.now().replace(microsecond=0)) + .replace(" ", "-") + .replace(":", "-") + ) + + save_folder = os.path.join( + OPRO_ROOT_PATH, + "outputs", + "optimization-results", + f"tsp-o-{optimizer_llm_name}-{datetime_str}/", + ) + os.makedirs(save_folder) + print(f"result directory:\n{save_folder}") + + # ====================== optimizer model configs ============================ + if optimizer_llm_name.lower() == "text-bison": + # when prompting text-bison with Cloud API + optimizer_finetuned_palm_temperature = 1.0 + optimizer_finetuned_palm_max_decode_steps = 1024 + optimizer_finetuned_palm_batch_size = 1 + optimizer_finetuned_palm_num_servers = 1 + optimizer_finetuned_palm_dict = dict() + optimizer_finetuned_palm_dict["temperature"] = ( + optimizer_finetuned_palm_temperature + ) + optimizer_finetuned_palm_dict["batch_size"] = ( + optimizer_finetuned_palm_batch_size + ) + optimizer_finetuned_palm_dict["num_servers"] = ( + optimizer_finetuned_palm_num_servers + ) + optimizer_finetuned_palm_dict["max_decode_steps"] = ( + optimizer_finetuned_palm_max_decode_steps + ) + + call_optimizer_finetuned_palm_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + # prompt_utils.call_vllm, + model="text-bison-001", + temperature=optimizer_finetuned_palm_dict["temperature"], + max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], + ) + + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + } + optimizer_llm_dict.update(optimizer_finetuned_palm_dict) + call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + + else: + assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"} + optimizer_gpt_max_decode_steps = 1024 + optimizer_gpt_temperature = 1.0 + + optimizer_llm_dict = dict() + optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps + optimizer_llm_dict["temperature"] = optimizer_gpt_temperature + optimizer_llm_dict["batch_size"] = 1 + call_optimizer_server_func = functools.partial( + prompt_utils.call_openai_server_func, + model=optimizer_llm_name, + max_decode_steps=optimizer_gpt_max_decode_steps, + temperature=optimizer_gpt_temperature, + ) + + # ====================== try calling the servers ============================ + print("\n======== testing the optimizer server ===========") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? 
Just answer yes or no.", + temperature=1.0, + ) + print(f"optimizer test output: {optimizer_test_output}") + print("Finished testing the optimizer server.") + print("\n=================================================") + + # ====================== utility functions ============================ + def evaluate_distance(x, y, trace, num_decimals): # pylint: disable=invalid-name + dis = 0 + try: + for i in range(len(trace) - 1): + id0 = trace[i] + id1 = trace[i + 1] + dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2) + except: + return -1 + id0 = trace[-1] + id1 = trace[0] + dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2) + dis = np.round(dis, num_decimals) if num_decimals > 0 else int(dis) + return dis + + def solve_tsp(x, y, num_points, num_decimals, starting_algorithm): + if starting_algorithm == "nearest_neighbor": + min_dis = 0 + gt_sol = [0] + remaining_points = list(range(1, num_points)) + while len(remaining_points) > 0: + min_p = -1 + min_cur_dis = -1 + for p in remaining_points: + cur_dis = np.sqrt((x[p] - x[gt_sol[-1]]) ** 2 + (y[p] - y[gt_sol[-1]]) ** 2) + if min_p == -1 or cur_dis < min_cur_dis: + min_p = p + min_cur_dis = cur_dis + gt_sol.append(min_p) + min_dis += min_cur_dis + remaining_points.remove(min_p) + min_dis += np.sqrt((x[0] - x[gt_sol[-1]]) ** 2 + (y[0] - y[gt_sol[-1]]) ** 2) + min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis) + return gt_sol, min_dis + elif starting_algorithm == "farthest_insertion": + gt_sol = [0] + remaining_points = list(range(1, num_points)) + while len(remaining_points) > 0: + max_p = -1 + max_cur_dis = -1 + max_cur_index = -1 + for p in remaining_points: + min_cur_dis = -1 + min_cur_index = -1 + for index in range(1, len(gt_sol) + 1): + new_sol = gt_sol[:index] + [p] + gt_sol[index:] + cur_dis = evaluate_distance(x, y, new_sol, num_decimals) + if min_cur_dis == -1 or cur_dis < min_cur_dis: + min_cur_dis = cur_dis + min_cur_index = index + if max_cur_dis == -1 or min_cur_dis > max_cur_dis: + max_p = p + max_cur_dis = min_cur_dis + max_cur_index = min_cur_index + gt_sol = gt_sol[:max_cur_index] + [max_p] + gt_sol[max_cur_index:] + remaining_points.remove(max_p) + min_dis = evaluate_distance(x, y, gt_sol, num_decimals) + return gt_sol, min_dis + + f = {(0, 1): (0, [0])} + q = [(0, 1)] + min_dis = -1 + gt_sol = list(range(num_points)) + while len(q) > 0: + p, status = q[0] + q = q[1:] + for i in range(num_points): + if 2 << i >> 1 & status == 0: + new_status = status + (2 << i >> 1) + new_dis = f[(p, status)][0] + np.sqrt((x[i] - x[p]) ** 2 + (y[i] - y[p]) ** 2) + if (i, new_status) not in f or new_dis < f[(i, new_status)][0]: + f[(i, new_status)] = (new_dis, f[(p, status)][1] + [i]) + if new_status == (2 << num_points >> 1) - 1: + new_dis += np.sqrt((x[i] - x[0]) ** 2 + (y[i] - y[0]) ** 2) + if min_dis == -1 or new_dis < min_dis: + min_dis = new_dis + gt_sol = f[(i, new_status)][1][:] + elif (i, new_status) not in q: + q.append((i, new_status)) + min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis) + return gt_sol, min_dis + + def gen_meta_prompt( + old_value_pairs_set, + x, # pylint: disable=invalid-name + y, + max_num_pairs=100, + ): + """Generate the meta-prompt for optimization. + + Args: + old_value_pairs_set (set): the set of old traces. + X (np.array): the 1D array of x values. + y (np.array): the 1D array of y values. + num_decimals (int): the number of decimals in the + meta-prompt. 
+      max_num_pairs (int): the maximum number of exemplars in the meta-prompt.
+
+    Returns:
+      meta_prompt (str): the generated meta-prompt.
+    """
+    old_value_pairs = list(old_value_pairs_set)
+    old_value_pairs = sorted(old_value_pairs, key=lambda x: -x[1])[
+        -max_num_pairs:
+    ]
+    old_value_pairs_substr = ""
+    for trace, dis in old_value_pairs:
+      old_value_pairs_substr += f"\n<trace> {trace} </trace>\nlength:\n{dis}\n"
+    meta_prompt = "You are given a list of points with coordinates below:\n"
+    for i, (xi, yi) in enumerate(zip(x, y)):
+      if i:
+        meta_prompt += ", "
+      meta_prompt += f"({i}): ({xi}, {yi})"
+    meta_prompt += ".\n\nBelow are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where lower values are better.".strip()
+    meta_prompt += "\n\n"
+    meta_prompt += old_value_pairs_substr.strip()
+    meta_prompt += "\n\n"
+    meta_prompt += """Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with <trace> and end with </trace>.
+    """.strip()
+    return meta_prompt
+
+  def extract_string(input_string):
+    start_string = "<trace>"
+    end_string = "</trace>"
+    if start_string not in input_string:
+      return ""
+    input_string = input_string[input_string.index(start_string) + len(start_string):]
+    if end_string not in input_string:
+      return ""
+    input_string = input_string[:input_string.index(end_string)]
+    parsed_list = []
+    for p in input_string.split(","):
+      p = p.strip()
+      try:
+        p = int(p)
+      except:
+        continue
+      parsed_list.append(p)
+    return parsed_list
+
+  # ================= generate the ground truth trace =====================
+
+  x = np.random.uniform(low=-100, high=100, size=num_points)
+  y = np.random.uniform(low=-100, high=100, size=num_points)
+  x = [np.round(xi, num_decimals) if num_decimals > 0 else int(xi) for xi in x]
+  y = [np.round(yi, num_decimals) if num_decimals > 0 else int(yi) for yi in y]
+
+  starting_algorithm = _START_ALGORITHM.value
+
+  gt_sol, min_dis = solve_tsp(x, y, num_points, num_decimals, starting_algorithm)
+  print("ground truth solution" + str(gt_sol))
+  print("min distance: ", min_dis)
+  gt_sol_str = ",".join([str(i) for i in gt_sol])
+  point_list = range(num_points)
+  init_sols = []
+  while len(init_sols) < num_starting_points:
+    sol = np.random.permutation(point_list)
+    if sol[0] != 0:
+      continue
+    sol_str = ",".join([str(i) for i in sol])
+    if sol_str == gt_sol_str:
+      continue
+    init_sols.append(list(sol))
+
+  # ====================== run optimization ============================
+  configs_dict = {
+      "num_starting_points": num_starting_points,
+      "num_decode_per_step": num_decode_per_step,
+      "optimizer_llm_configs": optimizer_llm_dict,
+      "data": {
+          "ground truth solution": [",".join([str(i) for i in gt_sol])],
+          "loss_at_true_values": min_dis,
+          "x": list(x),
+          "y": list(y),
+      },
+      "init_sols": [",".join([str(i) for i in sol]) for sol in init_sols],
+      "num_steps": num_steps,
+      "max_num_pairs": max_num_pairs,
+      "num_decimals": num_decimals,
+  }
+  configs_json_path = os.path.join(save_folder, "configs.json")
+  print(f"saving configs to\n{configs_json_path}")
+  with open(configs_json_path, "w") as f:
+    json.dump(configs_dict, f, indent=4)
+
+  old_value_pairs_set = set()
+  old_value_pairs_with_i_step = []  # format: [(trace, dis = f(trace), i_step)]
+  meta_prompts_dict = dict()  # format: {i_step: meta_prompt}
+  raw_outputs_dict = dict()  # format: {i_step: raw_outputs}
+
+  for sol in init_sols:
+    dis = evaluate_distance(x, y, sol, 
num_decimals) + sol_str = ",".join([str(i) for i in sol]) + old_value_pairs_set.add((sol_str, dis)) + old_value_pairs_with_i_step.append((sol_str, dis, -1)) + + print("\n================ run optimization ==============") + print(f"initial points: {[tuple(item[:-1]) for item in old_value_pairs_set]}") + print(f"initial values: {[item[-1] for item in old_value_pairs_set]}") + results_json_path = os.path.join(save_folder, "results.json") + print(f"saving results to\n{results_json_path}") + + for i_step in range(num_steps): + print(f"\nStep {i_step}:") + meta_prompt = gen_meta_prompt( + old_value_pairs_set, + x, + y, + max_num_pairs=max_num_pairs, + ) + print("\n=================================================") + print(f"meta_prompt:\n{meta_prompt}") + meta_prompts_dict[i_step] = meta_prompt + raw_outputs = [] + parsed_outputs = [] + while len(parsed_outputs) < num_decode_per_step: + raw_output = call_optimizer_server_func(meta_prompt) + for string in raw_output: + print("\n=================================================") + print("raw output:\n", string) + try: + parsed_output = extract_string(string) + if parsed_output is not None and len(set(parsed_output)) == num_points and len(parsed_output) == num_points and parsed_output[0] == 0: + dis = evaluate_distance(x, y, parsed_output, num_decimals) + if dis == -1: + continue + parsed_outputs.append(parsed_output) + raw_outputs.append(string) + except: + pass + print("\n=================================================") + print(f"proposed points: {parsed_outputs}") + raw_outputs_dict[i_step] = raw_outputs + + # evaluate the values of proposed and rounded outputs + single_step_values = [] + for trace in parsed_outputs: + dis = evaluate_distance(x, y, trace, num_decimals) + single_step_values.append(dis) + trace_str = ",".join([str(i) for i in trace]) + old_value_pairs_set.add((trace_str, dis)) + old_value_pairs_with_i_step.append((trace_str, dis, i_step)) + print(f"single_step_values: {single_step_values}") + print("ground truth solution" + str(gt_sol)) + print("min distance: ", min_dis) + + # ====================== save results ============================ + results_dict = { + "meta_prompts": meta_prompts_dict, + "raw_outputs": raw_outputs_dict, + "old_value_pairs_with_i_step": old_value_pairs_with_i_step, + } + with open(results_json_path, "w") as f: + json.dump(results_dict, f, indent=4) + + +if __name__ == "__main__": + app.run(main) diff --git a/optimization/test.py b/optimization/test.py new file mode 100644 index 0000000..a38abcc --- /dev/null +++ b/optimization/test.py @@ -0,0 +1,967 @@ +# Copyright 2023 The OPRO Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""The .py file for prompt optimization. 
+ +Usage: + +Step 1: edit the starting instructions by modifying `initial_instructions` + +Step 2: edit the training ratio by modifying `train_ratio` + +Step 3: check if the model configs (like batch size) are the same as the actual serving configs + +Step 4: run + +``` +python optimize_instructions.py \ + --optimizer="gpt-3.5-turbo" --scorer="text-bison" \ + --instruction_pos="A_begin" --dataset="gsm8k" --task="train" +``` + +The outputs will then be written to `outputs/optimization-results/` in the opro folder. + +Notes: + +1. One or more API keys may need to be provided: +- When using a Google-Cloud-served model (like text-bison at https://developers.generativeai.google/tutorials/text_quickstart), add `--palm_api_key=` +- When using an OpenAI model, add `--openai_api_key=””` + +2. The initial instructions should be provided in the "initial_instructions" +variable. +""" + +import datetime +import functools +import os +import sys + +OPRO_ROOT_PATH = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai +from opro import prompt_utils +from opro.optimization import opt_utils +import pandas as pd + +ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data") + +_LOCAL_MODEL_PATH = flags.DEFINE_string("local_model_path", None, "Path to local vLLM model.") + + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_SCORER = flags.DEFINE_string( + "scorer", "text-bison", "The name of the scorer LLM." +) + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + +_DATASET = flags.DEFINE_string( + "dataset", "gsm8k", "The name of dataset to search for instructions on." +) + +_TASK = flags.DEFINE_string( + "task", + "train", + "The name of task within the above dataset to search for instructions on.", +) + +_INSTRUCTION_POS = flags.DEFINE_string( + "instruction_pos", + "A_begin", + "The position of the instruction to search for.", +) + +_META_PROMPT_TYPE = flags.DEFINE_string( + "meta_prompt_type", + "both_instructions_and_exemplars", + "The type of meta-prompt: whether to have both previous instructions and" + " dataset exemplars (often for fine-tuned optimizers), or to have only" + " previous instructions (often for pre-trained optimizers).", +) + + +def main(_): + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + scorer_llm_name = _SCORER.value + optimizer_llm_name = _OPTIMIZER.value + dataset_name = _DATASET.value.lower() + task_name = _TASK.value + meta_prompt_type = _META_PROMPT_TYPE.value + local_model_path = _LOCAL_MODEL_PATH.value + + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k." 
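+  # `task` is interpreted per dataset: an MMLU category name for "mmlu", a
+  # single task name for "bbh", and the data split ("train"/"test") for "gsm8k".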
+ if dataset_name == "mmlu": + assert task_name in { + "STEM", + "humanities", + "social sciences", + "otheran (business, health, misc.)", + } # for now only support searching on one MMLU category + elif dataset_name == "bbh": + assert task_name in { + "boolean_expressions", + "causal_judgement", + "date_understanding", + "disambiguation_qa", + "dyck_languages", + "formal_fallacies", + "geometric_shapes", + "hyperbaton", + "logical_deduction_five_objects", + "logical_deduction_seven_objects", + "logical_deduction_three_objects", + "movie_recommendation", + "multistep_arithmetic_two", + "navigate", + "object_counting", + "penguins_in_a_table", + "reasoning_about_colored_objects", + "ruin_names", + "salient_translation_error_detection", + "snarks", + "sports_understanding", + "temporal_sequences", + "tracking_shuffled_objects_five_objects", + "tracking_shuffled_objects_seven_objects", + "tracking_shuffled_objects_three_objects", + "web_of_lies", + "word_sorting", + } + else: + assert dataset_name == "gsm8k" + assert task_name in {"train", "test"} + + assert scorer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + "local", + } + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + "local", + } + assert meta_prompt_type in { + "both_instructions_and_exemplars", + "instructions_only", + } + + instruction_pos = _INSTRUCTION_POS.value + assert instruction_pos in { + "before_Q", + "Q_begin", + "Q_end", + "A_begin", + }, ( + "The instruction position should be either before the question, or at the" + " beginning of the question, at the end of the question, or at the" + " beginning of the answer." + ) + print( + f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:" + f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}" + ) + + # make sure the scorer and optimizer models are callable + + if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + elif scorer_llm_name == "text-bison": + assert scorer_llm_name == "text-bison" + assert ( + palm_api_key + ), "A PaLM API key is needed when prompting the text-bison model." + palm.configure(api_key=palm_api_key) + + elif scorer_llm_name == "local": + assert local_model_path, "The local model path must be provided." + assert os.path.exists(local_model_path), ( + f"The local model path {local_model_path} does not exist." + ) + # set the local model path for vLLM + # prompt_utils.call_local_server_func(local_model_path) + else: + raise ValueError( + f"Unknown scorer_llm_name: {scorer_llm_name}. " + "It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local." + ) + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + elif optimizer_llm_name == "text-bison": + assert optimizer_llm_name == "text-bison" + assert ( + palm_api_key + ), "A PaLM API key is needed when prompting the text-bison model." + palm.configure(api_key=palm_api_key) + + elif optimizer_llm_name == "local": + assert local_model_path, "The local model path must be provided." + assert os.path.exists(local_model_path), ( + f"The local model path {local_model_path} does not exist." + ) + # set the local model path for vLLM + # prompt_utils.call_local_server_func(local_model_path) + else: + raise ValueError( + f"Unknown scorer_llm_name: {optimizer_llm_name}. " + "It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local." 
+ ) + + + if dataset_name == "mmlu": + root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data") + elif dataset_name == "bbh": + root_data_folder_path = os.path.join( + ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/" + ) + else: + assert dataset_name == "gsm8k" + root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data") + + # =================== create the result directory ========================== + datetime_str = ( + str(datetime.datetime.now().replace(microsecond=0)) + .replace(" ", "-") + .replace(":", "-") + ) + + save_folder = os.path.join( + OPRO_ROOT_PATH, + "outputs", + "optimization-results", + f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/", + ) + result_by_instruction_folder = os.path.join( + save_folder, "result_by_instruction" + ) + os.makedirs(result_by_instruction_folder) + print(f"result directory:\n{save_folder}") + + # ====================== scorer model configs ============================== + # difference between num_decodes and batch_size: + # - num_decodes: how many outputs we actually want for each input + # - batch_size: the batch size in model serving, should equal to that in + # model serving config + # 常量定义 + DEFAULT_MAX_TOKENS = 1024 + DEFAULT_TEMPERATURE = 0.0 + PALM_MODEL_NAME = "text-bison-001" + + if scorer_llm_name == "text-bison": + config = { + "temperature": DEFAULT_TEMPERATURE, + "max_decode_steps": DEFAULT_MAX_TOKENS, + "batch_size": 1, + "num_servers": 1, + } + call_scorer_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + model=PALM_MODEL_NAME, + **config + ) + scorer_llm_dict = {"model_type": "text-bison", **config} + + elif scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + config = { + "temperature": DEFAULT_TEMPERATURE, + "max_decode_steps": DEFAULT_MAX_TOKENS, + "batch_size": 1, + "num_servers": 1, + } + call_scorer_server_func = functools.partial( + prompt_utils.call_openai_server_func, + model=scorer_llm_name.lower(), + **config + ) + scorer_llm_dict = {"model_type": scorer_llm_name.lower(), **config} + + elif scorer_llm_name == "local": + print(f"[DEBUG] local_model_path: {local_model_path}") + assert local_model_path, "Local model path must be provided." 
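+    # Note: prompt_utils.call_local_server_func only consumes `temperature` and
+    # `max_decode_steps`; `batch_size` and `num_servers` are absorbed by its
+    # **kwargs and are kept in `config` mainly so they are recorded in
+    # scorer_llm_dict alongside the actual vLLM serving settings.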
+ config = { + "temperature": DEFAULT_TEMPERATURE, + "max_decode_steps": DEFAULT_MAX_TOKENS, + "batch_size": 8, + "num_servers": 8,# number of servers to use for local model + } + call_scorer_server_func = functools.partial( + prompt_utils.call_local_server_func, + local_model_path=local_model_path, + **config + ) + scorer_llm_dict = {"model_type": "local", **config} + + else: + raise ValueError(f"Unsupported model: {scorer_llm_name}") + +# if scorer_llm_name == "text-bison": +# # when prompting text-bison with Cloud API +# scorer_finetuned_palm_temperature = 0.0 +# scorer_finetuned_palm_max_decode_steps = 1024 +# scorer_finetuned_palm_batch_size = 1 +# scorer_finetuned_palm_num_servers = 1 +# scorer_finetuned_palm_dict = dict() +# scorer_finetuned_palm_dict["temperature"] = ( +# scorer_finetuned_palm_temperature +# ) +# scorer_finetuned_palm_dict["num_servers"] = ( +# scorer_finetuned_palm_num_servers +# ) +# scorer_finetuned_palm_dict["batch_size"] = scorer_finetuned_palm_batch_size +# scorer_finetuned_palm_dict["max_decode_steps"] = ( +# scorer_finetuned_palm_max_decode_steps +# ) + +# call_scorer_finetuned_palm_server_func = functools.partial( +# prompt_utils.call_palm_server_from_cloud, +# model="text-bison-001", +# temperature=scorer_finetuned_palm_dict["temperature"], +# max_decode_steps=scorer_finetuned_palm_dict["max_decode_steps"], +# ) + +# scorer_llm_dict = { +# "model_type": scorer_llm_name.lower(), +# } +# scorer_llm_dict.update(scorer_finetuned_palm_dict) +# call_scorer_server_func = call_scorer_finetuned_palm_server_func + +# elif scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: +# # assert scorer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4"} +# scorer_gpt_max_decode_steps = 1024 +# scorer_gpt_temperature = 0.0 + +# scorer_gpt_dict = dict() +# scorer_gpt_dict["max_decode_steps"] = scorer_gpt_max_decode_steps +# scorer_gpt_dict["temperature"] = scorer_gpt_temperature +# scorer_gpt_dict["num_decodes"] = 1 +# scorer_gpt_dict["batch_size"] = 1 +# scorer_gpt_dict["num_servers"] = 1 + +# scorer_llm_dict = { +# "model_type": scorer_llm_name.lower(), +# } +# scorer_llm_dict.update(scorer_gpt_dict) +# call_scorer_server_func = functools.partial( +# prompt_utils.call_openai_server_func, +# model=scorer_llm_name.lower(), +# max_decode_steps=scorer_gpt_max_decode_steps, +# temperature=scorer_gpt_temperature, +# ) +# elif scorer_llm_name == "local": +# # local vLLM model +# scorer_local_max_decode_steps = 1024 +# scorer_local_temperature = 0.0 +# call_scorer_server_func = functools.partial( +# prompt_utils.call_local_model_server_func, +# model_path=local_model_path, +# max_decode_steps=scorer_local_max_decode_steps, +# temperature=scorer_local_temperature, +# ) +# else: +# raise ValueError( +# f"Unknown scorer_llm_name: {scorer_llm_name}. " +# "It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local." 
+# ) + + # ====================== optimizer model configs ============================ + if optimizer_llm_name.lower() == "text-bison": + # PaLM text-bison 模型配置 + optimizer_llm_dict = { + "model_type": "text-bison", + "temperature": 1.0, # 更高的随机性以生成多样化解 + "max_decode_steps": 1024, # 最大生成长度 + "batch_size": 1, # 单样本处理 + "num_decodes": 8, # 生成8个候选结果 + "num_servers": 1 # 单服务器 + } + + call_optimizer_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + model="text-bison-001", + temperature=optimizer_llm_dict["temperature"], + max_decode_steps=optimizer_llm_dict["max_decode_steps"], + ) + + elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4"}: + # GPT 模型配置 + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + "temperature": 1.0, # 更高的随机性 + "max_decode_steps": 512, # 较短的最大长度 + "batch_size": 1, + "num_decodes": 1 , # 单次生成 + "num_servers": 1 # 单服务器 + } + + call_optimizer_server_func = functools.partial( + prompt_utils.call_openai_server_func, + model=optimizer_llm_name, + max_decode_steps=optimizer_llm_dict["max_decode_steps"], + temperature=optimizer_llm_dict["temperature"], + ) + elif optimizer_llm_name.lower() == "local": + assert local_model_path, "Local model path must be provided." + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + "temperature": 1.0, # 更高的随机性 + "max_decode_steps": 512, # 较短的最大长度 + "batch_size": 8, + "num_decodes": 1 , # 单次生成 + "num_servers": 8 # 单服务器 + } + call_optimizer_server_func = functools.partial( + prompt_utils.call_local_server_func, + local_model_path=local_model_path, + max_decode_steps=optimizer_llm_dict["max_decode_steps"], + temperature=optimizer_llm_dict["temperature"], + ) + + else: + raise ValueError( + f"Unsupported optimizer model: {optimizer_llm_name}. 
" + "Must be one of: text-bison, gpt-3.5-turbo, gpt-4" + ) + +# if optimizer_llm_name.lower() == "text-bison": +# # when prompting text-bison with Cloud API +# optimizer_finetuned_palm_temperature = 1.0 +# optimizer_finetuned_palm_num_decodes = 8 +# optimizer_finetuned_palm_max_decode_steps = 1024 +# optimizer_finetuned_palm_batch_size = 1 +# optimizer_finetuned_palm_num_servers = 1 +# optimizer_finetuned_palm_dict = dict() +# optimizer_finetuned_palm_dict["temperature"] = ( +# optimizer_finetuned_palm_temperature +# ) +# optimizer_finetuned_palm_dict["num_decodes"] = ( +# optimizer_finetuned_palm_num_decodes +# ) +# optimizer_finetuned_palm_dict["batch_size"] = ( +# optimizer_finetuned_palm_batch_size +# ) +# optimizer_finetuned_palm_dict["num_servers"] = ( +# optimizer_finetuned_palm_num_servers +# ) +# optimizer_finetuned_palm_dict["max_decode_steps"] = ( +# optimizer_finetuned_palm_max_decode_steps +# ) + +# call_optimizer_finetuned_palm_server_func = functools.partial( +# prompt_utils.call_palm_server_from_cloud, +# model="text-bison-001", +# temperature=optimizer_finetuned_palm_dict["temperature"], +# max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], +# ) + +# optimizer_llm_dict = { +# "model_type": optimizer_llm_name.lower(), +# } +# optimizer_llm_dict.update(optimizer_finetuned_palm_dict) +# call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + +# else: +# assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"} +# optimizer_gpt_max_decode_steps = 512 +# optimizer_gpt_temperature = 1.0 + +# optimizer_llm_dict = dict() +# optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps +# optimizer_llm_dict["temperature"] = optimizer_gpt_temperature +# optimizer_llm_dict["batch_size"] = 1 +# optimizer_llm_dict["num_decodes"] = 1 +# call_optimizer_server_func = functools.partial( +# prompt_utils.call_openai_server_func, +# model=optimizer_llm_name, +# max_decode_steps=optimizer_gpt_max_decode_steps, +# temperature=optimizer_gpt_temperature, +# ) + + # ====================== try calling the servers ============================ + print("\n======== testing the scorer and optimizer servers ===========") + scorer_test_output = call_scorer_server_func( + "Does the sun rise from the north? Just answer yes or no." + ) + print(f"number of scorer output decodes: {len(scorer_test_output)}") + print(f"scorer test output: {scorer_test_output}") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? 
Just answer yes or no.", + temperature=1.0, + ) + print(f"number of optimizer output decodes: {len(optimizer_test_output)}") + print(f"optimizer test output: {optimizer_test_output}") + print("Finished testing the servers.") + + # ====================== read data ============================ + print("\n================ prompt optimization settings ==============") + # from https://github.com/hendrycks/test/blob/master/categories.py + subcategories = { + "abstract_algebra": ["math"], + "anatomy": ["health"], + "astronomy": ["physics"], + "business_ethics": ["business"], + "clinical_knowledge": ["health"], + "college_biology": ["biology"], + "college_chemistry": ["chemistry"], + "college_computer_science": ["computer science"], + "college_mathematics": ["math"], + "college_medicine": ["health"], + "college_physics": ["physics"], + "computer_security": ["computer science"], + "conceptual_physics": ["physics"], + "econometrics": ["economics"], + "electrical_engineering": ["engineering"], + "elementary_mathematics": ["math"], + "formal_logic": ["philosophy"], + "global_facts": ["other"], + "high_school_biology": ["biology"], + "high_school_chemistry": ["chemistry"], + "high_school_computer_science": ["computer science"], + "high_school_european_history": ["history"], + "high_school_geography": ["geography"], + "high_school_government_and_politics": ["politics"], + "high_school_macroeconomics": ["economics"], + "high_school_mathematics": ["math"], + "high_school_microeconomics": ["economics"], + "high_school_physics": ["physics"], + "high_school_psychology": ["psychology"], + "high_school_statistics": ["math"], + "high_school_us_history": ["history"], + "high_school_world_history": ["history"], + "human_aging": ["health"], + "human_sexuality": ["culture"], + "international_law": ["law"], + "jurisprudence": ["law"], + "logical_fallacies": ["philosophy"], + "machine_learning": ["computer science"], + "management": ["business"], + "marketing": ["business"], + "medical_genetics": ["health"], + "miscellaneous": ["other"], + "moral_disputes": ["philosophy"], + "moral_scenarios": ["philosophy"], + "nutrition": ["health"], + "philosophy": ["philosophy"], + "prehistory": ["history"], + "professional_accounting": ["other"], + "professional_law": ["law"], + "professional_medicine": ["health"], + "professional_psychology": ["psychology"], + "public_relations": ["politics"], + "security_studies": ["politics"], + "sociology": ["culture"], + "us_foreign_policy": ["politics"], + "virology": ["health"], + "world_religions": ["philosophy"], + } + + categories = { + "STEM": [ + "physics", + "chemistry", + "biology", + "computer science", + "math", + "engineering", + ], + "humanities": ["history", "philosophy", "law"], + "social sciences": [ + "politics", + "culture", + "economics", + "geography", + "psychology", + ], + "other (business, health, misc.)": ["other", "business", "health"], + } + + if dataset_name == "mmlu": + # EITHER: filter by category + # category_names = [ + # "STEM", + # "humanities", + # "social sciences", + # "other (business, health, misc.)", + # ] + category_names = [task_name] + folder_name = "test" # one of {'auxiliary_train', 'dev', 'val', 'test'} + task_names = [] + for task_csv_name in os.listdir( + os.path.join(root_data_folder_path, folder_name) + ): + task_names.append(task_csv_name.split(".")[0]) + + tasks_in_category = [] + for category_name in category_names: + for task_name in task_names: + for subname in subcategories: + if subname in task_name: + if subcategories[subname][0] in 
categories[category_name]: + tasks_in_category.append(task_name) + break + + tasks_all = [(folder_name, task_name) for task_name in tasks_in_category] + multiple_choice_tasks = set([item[1] for item in tasks_all]) + boolean_tasks = set() + numerical_output_tasks = set() + + # OR: filter by task + # tasks_all = [ + # # ('test', 'abstract_algebra_test'), + # # ('test', 'college_computer_science_test'), + # # ('test', 'college_mathematics_test'), + # # ('test', 'college_physics_test'), + # # ('test', 'elementary_mathematics_test'), + # # ('test', 'global_facts_test'), + # # ('test', 'high_school_physics_test'), + # # ('test', 'machine_learning_test'), + # # ('test', 'management_test'), + # # ('test', 'medical_genetics_test'), + # # ('test', 'moral_scenarios_test'), + # # ('test', 'professional_psychology_test'), + # # ('test', 'public_relations_test'), + # # ('test', 'professional_law_test'), + # # ('test', 'high_school_psychology_test'), + # # ('test', 'high_school_world_history_test'), + # # ('test', 'human_aging_test'), + # # ('test', 'miscellaneous_test'), + # # ('test', 'moral_scenarios_test'), + # ('test', 'professional_psychology_test'), + # # ('test', 'security_studies_test'), + # ] + + elif dataset_name == "bbh": + tasks_all = [task_name] + assert ( + len(tasks_all) == 1 + ), "for now only support prompt optimization on one BBH task" + + # all BBH tasks are as below + # tasks_all = [ + # 'boolean_expressions', + # 'causal_judgement', + # 'date_understanding', + # 'disambiguation_qa', + # 'dyck_languages', + # 'formal_fallacies', + # 'geometric_shapes', + # 'hyperbaton', + # 'logical_deduction_five_objects', + # 'logical_deduction_seven_objects', + # 'logical_deduction_three_objects', + # 'movie_recommendation', + # 'multistep_arithmetic_two', + # 'navigate', + # 'object_counting', + # 'penguins_in_a_table', + # 'reasoning_about_colored_objects', + # 'ruin_names', + # 'salient_translation_error_detection', + # 'snarks', + # 'sports_understanding', + # 'temporal_sequences', + # 'tracking_shuffled_objects_five_objects', + # 'tracking_shuffled_objects_seven_objects', + # 'tracking_shuffled_objects_three_objects', + # 'web_of_lies', + # 'word_sorting' + # ] + numerical_output_tasks = { + "object_counting", + "multistep_arithmetic_two", + } + + multiple_choice_tasks = { + "date_understanding", + "disambiguation_qa", + "geometric_shapes", + "hyperbaton", + "logical_deduction_five_objects", + "logical_deduction_seven_objects", + "logical_deduction_three_objects", + "movie_recommendation", + "penguins_in_a_table", + "reasoning_about_colored_objects", + "ruin_names", + "salient_translation_error_detection", + "snarks", + "temporal_sequences", + "tracking_shuffled_objects_five_objects", + "tracking_shuffled_objects_seven_objects", + "tracking_shuffled_objects_three_objects", + } + + boolean_tasks = { + "boolean_expressions", # True or False + "causal_judgement", # yes or no + "formal_fallacies", # valid or invalid + "navigate", # yes or no + "sports_understanding", # yes or no + "web_of_lies", # yes or no + } + + else: + assert dataset_name in {"gsm8k"} + tasks_all = [task_name] + multiple_choice_tasks = set() + boolean_tasks = set() + numerical_output_tasks = set(tasks_all) + + if dataset_name == "mmlu": + raw_data = pd.DataFrame() + prediction_treat_as_number = False + prediction_treat_as_bool = False + elif dataset_name == "bbh": + raw_data = [] + prediction_treat_as_number = bool( + tasks_all[0] in numerical_output_tasks + ) # for now only check the first task + prediction_treat_as_bool = 
bool( + tasks_all[0] in boolean_tasks + ) # for now only check the first task + print( + f"prediction_treat_as_number: {prediction_treat_as_number}," + f" prediction_treat_as_bool: {prediction_treat_as_bool}" + ) + else: + assert dataset_name == "gsm8k" + raw_data = pd.DataFrame() + prediction_treat_as_number = True + prediction_treat_as_bool = False + + for t in tasks_all: + if dataset_name == "mmlu": + folder_name = t[0] + task_name = t[1] + single_task_df = pd.read_csv( + os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"), + index_col=None, + header=None, + ) + raw_data = pd.concat([raw_data, single_task_df]) + elif dataset_name == "bbh": + task_name = t + single_task_list = opt_utils.load_bbh_task_data( + task_name, base_dir=root_data_folder_path + ) + raw_data += single_task_list + else: + assert dataset_name == "gsm8k" + task_name = t + f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv") + single_task_df = pd.read_csv(f_gsm, sep="\t", header=None) + raw_data = pd.concat([raw_data, single_task_df]) + + if dataset_name == "mmlu": + num_examples = raw_data.shape[0] + elif dataset_name == "bbh": + num_examples = len(raw_data) + else: + assert dataset_name in {"gsm8k"} + num_examples = raw_data.shape[0] + print(f"number of examples in the current task: {num_examples}") + + # ================ split data into train/val/test ========================== + if dataset_name == "mmlu": + train_ratio = 0.8 + eval_ratio = 0.2 + elif dataset_name == "gsm8k": + train_ratio = 0.035 + eval_ratio = 0 + else: + assert dataset_name == "bbh" + train_ratio = 0.2 + eval_ratio = 0 + + # train-validation-test split + # It is important to sort the indices, as this ensures the is_multiple_choice + # Boolean variables match the data points. 
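+  # For example, with the GSM8K train split (7473 problems in the released
+  # dataset) and train_ratio = 0.035, roughly int(0.035 * 7473) = 261 examples
+  # are sampled for scoring candidate instructions, and with eval_ratio = 0 the
+  # remainder is treated as the held-out test portion.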
+ assert train_ratio + eval_ratio <= 1 + test_ratio = 1 - train_ratio - eval_ratio + print( + f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, " + f"test_ratio: {test_ratio}" + ) + np.random.seed(0) + train_index = np.sort( + np.array( + np.random.choice( + num_examples, size=int(train_ratio * num_examples), replace=False + ) + ) + ) + eval_and_test_index = np.sort( + np.array(list(set(np.arange(num_examples)) - set(train_index))) + ) + eval_index = np.sort( + np.array( + np.random.choice( + eval_and_test_index, + size=int(eval_ratio * num_examples), + replace=False, + ) + ) + ) + + # ========== set other optimization experiment hyperparameters ============== + if scorer_llm_name == "text-bison": + old_instruction_score_threshold = 0.0 # 完全保留旧指令 表示不过滤任何历史指令(即使质量很低的旧指令也会保留)。 + # old_instruction_score_threshold = 0.15 # for GSM8K + elif scorer_llm_name == "local": + old_instruction_score_threshold = 0.3 + else: + assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"} # 模型校验 + old_instruction_score_threshold = 0.3 # 过滤低质量旧指令 + + if scorer_llm_name == "text-bison": + extract_final_answer_by_prompting_again = False # 是否通过二次提示提取最终答案(例如从冗长响应中提取关键内容) + include_qa = False # 是否在元提示中包含问答对 + evaluate_in_parallel = False # 是否并行评估 + elif scorer_llm_name == "local": + extract_final_answer_by_prompting_again = True + include_qa = True + evaluate_in_parallel = True + else: + assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"} + extract_final_answer_by_prompting_again = False + include_qa = False + evaluate_in_parallel = False + + optimizer_llm_temperature = optimizer_llm_dict["temperature"] + + num_few_shot_questions_for_instruction_refinement = 3 # number of few-shot questions 每次优化指令时参考的少样本示例数量(Few-shot QA对)。 + + # To change the number of generated instructions in each step, one should + # edit the value of the variable below, instead of editing the number of + # decodes in model parameters, because those values are limited by model + # serving configs. 
+ num_generated_instructions_in_each_step = 3 # number of generated instructions in each step 每轮搜索生成的候选指令数量。 + num_search_steps = 50 # number of search steps 总优化迭代次数。 + + initial_instructions = [ + "Let's solve the problem.", + # "", + # "The answer is", + ] + few_shot_qa_pairs = True #是否使用少样本示例指导指令生成。 + # one of {'accumulative_most_frequent', 'current_most_frequent', 'random', + # 'constant'} + few_shot_selection_criteria = "random" #对多样性要求高时用 random,稳定性要求高时用 most_frequent。 + # whether to evaluate generated instructions on the exemplars in meta-prompt + evaluate_generated_ins_on_few_shot = False # 是否评估新指令 开发阶段设为 True调试指令质量。 + # whether to evaluate old instructions on the exemplars in the meta-prompt + evaluate_old_ins_on_few_shot = False # 是否评估旧指令 生产阶段设为 False加速运行。 + # every this number of steps, compute the accuracies of current-step + # instructions on the validation set + eval_interval = 3 # 每N步在验证集上测试当前指令的准确率。 + + max_num_instructions = ( + 20 # 元提示中保留的历史指令数量上限。 + ) + # 将连续分数离散化为N档(如0-100整数),简化模型理解。 + num_score_buckets = 100 + # whether to put old instructions and scores to before exemplars in + # 控制元提示中历史指令和少样本示例的顺序。 + meta_prompt_instructions_before_exemplars = True + + # ===================== run prompt optimization ====================== + + assert few_shot_selection_criteria in { + "accumulative_most_frequent", + "current_most_frequent", + "random", + "constant", + } + evolution_kwargs = { + "num_search_steps": num_search_steps, + "old_instruction_score_threshold": old_instruction_score_threshold, + "scorer_llm_dict": scorer_llm_dict, + "optimizer_llm_dict": optimizer_llm_dict, + "extract_final_answer_by_prompting_again": ( + extract_final_answer_by_prompting_again + ), + "include_qa": include_qa, + "evaluate_in_parallel": evaluate_in_parallel, + "tasks_all": tasks_all, + "train_ratio": train_ratio, + "eval_ratio": eval_ratio, + "test_ratio": test_ratio, + "train_index": train_index, + "eval_index": eval_index, + "dataset_name": dataset_name, + "task_name": task_name, + "num_examples": num_examples, + "root_data_folder_path": root_data_folder_path, + "optimizer_llm_temperature": optimizer_llm_temperature, + # "optimizer_llm_temperature_schedule": ( + # optimizer_llm_temperature_schedule + # ), + # "optimizer_llm_temperature_end": optimizer_llm_temperature_end, + "initial_instructions": initial_instructions, + "multiple_choice_tasks": multiple_choice_tasks, + "raw_data": raw_data, + "call_scorer_server_func": call_scorer_server_func, + "call_optimizer_server_func": call_optimizer_server_func, + "instruction_pos": instruction_pos, + "prediction_treat_as_number": prediction_treat_as_number, + "prediction_treat_as_bool": prediction_treat_as_bool, + "result_by_instruction_folder": result_by_instruction_folder, + "few_shot_qa_pairs": few_shot_qa_pairs, + "num_score_buckets": num_score_buckets, + "max_num_instructions": max_num_instructions, + "meta_prompt_type": meta_prompt_type, + "meta_prompt_instructions_before_exemplars": ( + meta_prompt_instructions_before_exemplars + ), + "few_shot_selection_criteria": few_shot_selection_criteria, + "optimizer_llm_name": optimizer_llm_name, + "num_generated_instructions_in_each_step": ( + num_generated_instructions_in_each_step + ), + "evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot, + "num_few_shot_questions_for_instruction_refinement": ( + num_few_shot_questions_for_instruction_refinement + ), + "evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot, + "eval_interval": eval_interval, + "save_folder": save_folder, + 
} + + opt_utils.run_evolution(**evolution_kwargs) + + +if __name__ == "__main__": + app.run(main) diff --git a/outputs/user_feedback.jsonl b/outputs/user_feedback.jsonl new file mode 100644 index 0000000..e504227 --- /dev/null +++ b/outputs/user_feedback.jsonl @@ -0,0 +1,2 @@ +{"session_id": "06e84e4b03c54df188961f4e033f9e77", "round": 1, "choice": "哪些奶茶含有较少的糖分?", "answer": "选择低糖奶茶时,可以通过品牌选择、糖分选项调整以及配料搭配来控制摄入的糖分。以下是一些具体建议和推荐:\n\n---\n\n### **一、低糖奶茶品牌及选项**\n1. **喜茶(HEYTEA)** \n - **推荐系列**:轻乳茶(如「轻乳茉莉奶绿」)、部分限定款(如「芝士奶盖茶」可选低糖)。 \n - **糖分控制**:默认糖度为50%,可选「三分糖」(约10%糖度)或「无糖」(使用代糖)。 \n - **注意**:奶盖本身含糖,建议选择「无奶盖」选项。\n\n2. **奈雪的茶** \n - **推荐系列**:鲜果茶(如「霸气橙子」)、「鲜果轻乳茶」系列。 \n - **糖分控制**:默认糖度为50%,可选「少糖」(约30%)或「无糖」(使用代糖)。 \n - **注意**:鲜果茶的天然果糖含量较高,需结合糖分选项综合判断。\n\n3. **CoCo都可** \n - **推荐系列**:「鲜果茶」(如「百香果鲜果茶」)、「轻乳茶」。 \n - **糖分控制**:默认全糖(约100%),可选「三分糖」(约30%)或「无糖」(代糖)。 \n - **注意**:部分经典奶茶(如「珍珠奶茶」)糖分较高,建议选择鲜果类饮品。\n\n4. **一点点** \n - **推荐系列**:「鲜果茶」(如「百香果鲜果茶」)、「轻乳茶」。 \n - **糖分控制**:默认全糖(约100%),可选「半糖」(50%)或「无糖」(代糖)。 \n - **注意**:部分奶盖茶含糖量较高,建议选择鲜果或轻乳系列。\n\n---\n\n### **二、如何选择低糖奶茶?**\n1. **糖分选项优先级** \n - **无糖/代糖**:使用赤藓糖醇、甜菊糖等代糖,几乎不含热量(但需注意部分人对代糖敏感)。 \n - **三分糖/半糖**:糖分约为全糖的30%-50%,适合控制摄入。 \n - **避免全糖**:全糖奶茶的糖分通常在"} +{"session_id": "af370656137a4b8f97d8ffac2c60b191", "round": 1, "choice": "瑞幸咖啡有哪些值得推荐的饮品?", "answer": "瑞幸咖啡近年来凭借高性价比和创新口味,推出了许多受欢迎的饮品。以下是一些值得推荐的饮品,涵盖经典款、特色新品和季节限定,适合不同口味和需求:\n\n---\n\n### **一、爆款必试系列**\n1. **生椰拿铁** \n - **推荐理由**:瑞幸的“现象级”产品,椰香浓郁与咖啡的苦香完美融合,口感顺滑,常年占据销量榜首。 \n - **小贴士**:可尝试“生椰拿铁+巧克力糖浆”组合,变成“生椰摩卡”,风味更丰富。\n\n2. **丝绒拿铁** \n - **推荐理由**:以红茶为基底,搭配丝滑牛奶,茶香与奶香平衡,适合喜欢茶咖融合的人群。 \n - **特色**:选用锡兰红茶,口感更醇厚,冷热皆宜。\n\n3. **厚乳拿铁** \n - **推荐理由**:使用厚乳(高乳脂含量的牛奶),奶香更浓郁,适合追求绵密口感的爱好者。\n\n---\n\n### **二、果味与创意系列**\n1. **冰椰拿铁**(夏季限定) \n - **推荐理由**:生椰拿铁的冰饮版本,加入冰块和椰香糖浆,清爽解暑,适合夏天。\n\n2. **蓝莓生椰拿铁** \n - **推荐理由**:在生椰拿铁基础上加入蓝莓糖浆,果香与椰香交织,甜而不腻。\n\n3. **蜜桃生椰拿铁** \n - **推荐理由**:蜜桃风味糖浆与生椰拿1:1搭配,清新果香与咖啡的苦香碰撞,适合喜欢果味的人。\n\n---\n\n### **三、季节限定款**\n1. **桂花拿铁**(秋季限定) \n - **推荐理由**:桂花糖浆与拿铁结合,香气扑鼻,甜度适中,是"} diff --git a/prompt_utils.py b/prompt_utils.py new file mode 100644 index 0000000..d60aaf2 --- /dev/null +++ b/prompt_utils.py @@ -0,0 +1,220 @@ +# Copyright 2023 The OPRO Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
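+# NOTE: the OpenAI helpers below use the pre-1.0 `openai` SDK interface
+# (`openai.ChatCompletion.create`, `openai.error.*`); newer versions of the
+# SDK removed these entry points.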
+"""The utility functions for prompting GPT and Google Cloud models.""" +import openai +import time +try: + import google.generativeai as palm +except Exception: + palm = None +try: + from vllm import LLM, SamplingParams +except Exception: + LLM = None + SamplingParams = None +from pathlib import Path + + +# 缓存 vLLM 实例,避免重复加载 +_llm_instance = None + +def get_llm(local_model_path): + if LLM is None: + raise RuntimeError("vLLM not available") + global _llm_instance + if _llm_instance is None: + assert local_model_path is not None, "model_path cannot be None" + local_model_path = str(Path(local_model_path).resolve()) + _llm_instance = LLM( + model=local_model_path, + dtype="bfloat16", + tensor_parallel_size=8, + max_num_batched_tokens=8192, + max_num_seqs=64, + gpu_memory_utilization=0.7, + enforce_eager=True, + block_size=16, + enable_chunked_prefill=True, + trust_remote_code=True, + ) + return _llm_instance + +def call_local_server_single_prompt(prompt, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs): + """ + 使用本地 vLLM 模型生成单个 prompt 的响应,替代原本 OpenAI API。 + """ + llm = get_llm(local_model_path) + sampling_params = SamplingParams( + temperature=temperature, + top_p=0.9, + max_tokens=max_decode_steps, + skip_special_tokens=True # 避免特殊字符触发协议错误 + ) + outputs = llm.generate([prompt], sampling_params) + return outputs[0].outputs[0].text + +def call_local_server_func(inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs): + """ + 批量处理多个输入 prompt。 + """ + + assert local_model_path is not None, "local_model_path must be provided" + # 强制类型检查 + if isinstance(inputs, bytes): + inputs = inputs.decode('utf-8') + outputs = [] + for input_str in inputs: + output = call_local_server_single_prompt( + input_str, + local_model_path=local_model_path, + temperature=temperature, + max_decode_steps=max_decode_steps + ) + outputs.append(output) + return outputs + + +def call_openai_server_single_prompt( + prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8 +): + """The function to call OpenAI server with an input string.""" + try: + completion = openai.ChatCompletion.create( + model=model, + temperature=temperature, + max_tokens=max_decode_steps, + messages=[ + {"role": "user", "content": prompt}, + ], + ) + return completion.choices[0].message.content +# 函数捕获了 ​​6 类常见异常​​,并在遇到错误时自动重试: + except openai.error.Timeout as e: # API 请求超时 + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"Timeout error occurred. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.RateLimitError as e: #请求频率超限(Rate Limit) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"Rate limit exceeded. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.APIError as e: # API 错误(如服务器错误) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"API error occurred. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.APIConnectionError as e: # API 连接错误(如网络问题) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"API connection error occurred. 
Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.ServiceUnavailableError as e: # 服务不可用(如服务器维护) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"Service unavailable. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except OSError as e: # 操作系统级连接错误(如网络中断) + retry_time = 5 # Adjust the retry time as needed + print( + f"Connection error occurred: {e}. Retrying in {retry_time} seconds..." + ) + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + +def call_openai_server_func( #​​批量处理多个输入提示(prompts),通过 OpenAI API 并行或顺序获取多个生成结果​​。 + inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8 +): + """The function to call OpenAI server with a list of input strings.""" + if isinstance(inputs, str): # 将单个字符串转为列表,统一处理 + inputs = [inputs] + outputs = [] + for input_str in inputs: + output = call_openai_server_single_prompt( + input_str, + model=model, + max_decode_steps=max_decode_steps, + temperature=temperature, + ) + outputs.append(output) + return outputs + +#通过 Google PaLM API(Cloud 版)调用 text-bison模型生成文本​​,并包含基本错误处理和自动重试机制。 +def call_palm_server_from_cloud( + input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8 +): + if palm is None: + raise RuntimeError("google.generativeai not available") + assert isinstance(input_text, str) + assert model == "text-bison-001" + all_model_names = [ + m + for m in palm.list_models() + if "generateText" in m.supported_generation_methods + ] + model_name = all_model_names[0].name + try: + completion = palm.generate_text( + model=model_name, + prompt=input_text, + temperature=temperature, + max_output_tokens=max_decode_steps, + ) + output_text = completion.result + return [output_text] + except Exception: + retry_time = 10 + time.sleep(retry_time) + return call_palm_server_from_cloud( + input_text, max_decode_steps=max_decode_steps, temperature=temperature + ) + +def refine_instruction(query: str) -> str: + return f""" +你是一个“问题澄清与重写助手”。 +请根据用户的原始问题: +【{query}】 +生成不少于20条多角度、可直接执行的问题改写,每行一条。 +""" + +def refine_instruction_with_history(query: str, rejected_list: list) -> str: + rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else "" + return f""" +你是一个“问题澄清与重写助手”。 +原始问题: +{query} + +以下改写已被否定: +{rejected_text} + +请从新的角度重新生成至少20条不同的改写问题,每条单独一行。 +""" diff --git a/session_state.py b/session_state.py new file mode 100644 index 0000000..08b78ad --- /dev/null +++ b/session_state.py @@ -0,0 +1,26 @@ +import uuid + +SESSIONS = {} + +def create_session(query: str) -> str: + sid = uuid.uuid4().hex + SESSIONS[sid] = { + "original_query": query, + "round": 0, + "history_candidates": [], + "user_feedback": [] + } + return sid + +def get_session(sid: str): + return SESSIONS.get(sid) + +def update_session_add_candidates(sid: str, candidates: list): + s = SESSIONS[sid] + s["round"] += 1 + s["history_candidates"].extend(candidates) + +def log_user_choice(sid: str, choice: str): + SESSIONS[sid]["user_feedback"].append( + {"round": SESSIONS[sid]["round"], "choice": choice} + ) diff --git a/user_prompt_optimizer.py b/user_prompt_optimizer.py new file mode 100644 index 0000000..933b929 --- /dev/null +++ b/user_prompt_optimizer.py @@ -0,0 
+1,52 @@ +import re +import numpy as np +from sklearn.cluster import AgglomerativeClustering +from sklearn.metrics.pairwise import cosine_similarity + +from opro.ollama_client import call_qwen +from opro.xinference_client import embed_texts +from opro.prompt_utils import refine_instruction, refine_instruction_with_history + +def parse_candidates(raw: str) -> list: + lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()] + cleaned = [] + for l in lines: + l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip() + if len(l) >= 6: + cleaned.append(l) + return list(dict.fromkeys(cleaned)) + +def cluster_and_select(candidates: list, top_k=5, distance_threshold=0.15): + if not candidates: + return [] + vecs = embed_texts(candidates) + X = np.array(vecs) + if len(candidates) <= top_k: + return candidates + + clustering = AgglomerativeClustering(n_clusters=None, + distance_threshold=distance_threshold, + metric="cosine", + linkage="average") + labels = clustering.fit_predict(X) + + selected_idx = [] + for label in sorted(set(labels)): + idxs = [i for i,l in enumerate(labels) if l == label] + sims = cosine_similarity(X[idxs]).mean(axis=1) + rep = idxs[int(np.argmax(sims))] + selected_idx.append(rep) + + selected = [candidates[i] for i in sorted(selected_idx)] + return selected[:top_k] + +def generate_candidates(query: str, rejected=None, top_k=5): + rejected = rejected or [] + if rejected: + prompt = refine_instruction_with_history(query, rejected) + else: + prompt = refine_instruction(query) + + raw = call_qwen(prompt, temperature=0.9, max_tokens=512) + all_candidates = parse_candidates(raw) + return cluster_and_select(all_candidates, top_k=top_k) diff --git a/xinference_client.py b/xinference_client.py new file mode 100644 index 0000000..dc9029f --- /dev/null +++ b/xinference_client.py @@ -0,0 +1,11 @@ +import requests +from typing import List + +XINFERENCE_EMBED_URL = "http://127.0.0.1:9997/models/bge-base-zh/embed" + +def embed_texts(texts: List[str]) -> List[List[float]]: + payload = {"inputs": texts} + resp = requests.post(XINFERENCE_EMBED_URL, json=payload, timeout=30) + resp.raise_for_status() + data = resp.json() + return data.get("embeddings", [])
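+
+
+# Minimal smoke test for the embedding endpoint. This is an illustrative
+# sketch only: it assumes an Xinference server is running locally and serving
+# the bge-base-zh embedding model at XINFERENCE_EMBED_URL.
+if __name__ == "__main__":
+    sample = ["哪些奶茶含有较少的糖分?", "瑞幸咖啡有哪些值得推荐的饮品?"]
+    vectors = embed_texts(sample)
+    print(f"got {len(vectors)} embeddings")
+    if vectors:
+        print(f"embedding dimension: {len(vectors[0])}")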