From dd5339de32631d864007d8ff14134135697f5602 Mon Sep 17 00:00:00 2001
From: xxm
Date: Fri, 5 Dec 2025 07:11:25 +0000
Subject: [PATCH] Original code (原始代码)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .trae/documents/修复预览访问与静态页面挂载.md      |   37 +
 API.md                                             |  360 ++
 __pycache__/config.cpython-310.pyc                 |  Bin 0 -> 721 bytes
 __pycache__/config.cpython-313.pyc                 |  Bin 0 -> 795 bytes
 __pycache__/ollama_client.cpython-310.pyc          |  Bin 0 -> 722 bytes
 __pycache__/prompt_utils.cpython-310.pyc           |  Bin 0 -> 5562 bytes
 __pycache__/session_state.cpython-310.pyc          |  Bin 0 -> 1003 bytes
 .../user_prompt_optimizer.cpython-310.pyc          |  Bin 0 -> 2268 bytes
 __pycache__/xinference_client.cpython-310.pyc      |  Bin 0 -> 589 bytes
 .../__pycache__/api.cpython-310.pyc                |  Bin 0 -> 10299 bytes
 .../__pycache__/api.cpython-313.pyc                |  Bin 0 -> 16998 bytes
 _qwen_xinference_demo/api.py                       |  284 +++++
 .../__pycache__/ollama_client.cpython-310.pyc      |  Bin 0 -> 1775 bytes
 .../__pycache__/ollama_client.cpython-313.pyc      |  Bin 0 -> 3115 bytes
 .../__pycache__/prompt_utils.cpython-310.pyc       |  Bin 0 -> 1107 bytes
 .../__pycache__/prompt_utils.cpython-313.pyc       |  Bin 0 -> 1307 bytes
 .../__pycache__/session_state.cpython-310.pyc      |  Bin 0 -> 1995 bytes
 .../__pycache__/session_state.cpython-313.pyc      |  Bin 0 -> 2839 bytes
 .../user_prompt_optimizer.cpython-310.pyc          |  Bin 0 -> 2374 bytes
 .../user_prompt_optimizer.cpython-313.pyc          |  Bin 0 -> 3362 bytes
 .../xinference_client.cpython-310.pyc              |  Bin 0 -> 1110 bytes
 .../xinference_client.cpython-313.pyc              |  Bin 0 -> 1741 bytes
 _qwen_xinference_demo/opro/ollama_client.py        |   52 +
 _qwen_xinference_demo/opro/prompt_utils.py         |   20 +
 _qwen_xinference_demo/opro/session_state.py        |   56 +
 .../opro/user_prompt_optimizer.py                  |   55 +
 .../opro/xinference_client.py                      |   29 +
 config.py                                          |   19 +
 examples/client_demo.py                            |   55 +
 frontend/index.html                                |  446 +++++++
 frontend/react-app.html                            |  164 +++
 frontend/react/index.html                          |  192 +++
 frontend/ui_offline.html                           |  157 +++
 ollama_client.py                                   |   19 +
 opro_qwen_xinference_demo_package.zip              |  Bin 0 -> 6336 bytes
 optimization/__pycache__/opt_utils.cpython-310.pyc |  Bin 0 -> 19855 bytes
 optimization/opt_utils.py                          | 1035 +++++++++++++++++
 optimization/optimize_instructions.py              |  736 ++++++++++++
 optimization/optimize_linear_regression.py         |  424 +++++++
 optimization/optimize_tsp.py                       |  430 +++++++
 optimization/test.py                               |  967 +++++++++++++++
 outputs/user_feedback.jsonl                        |    2 +
 prompt_utils.py                                    |  220 ++++
 session_state.py                                   |   26 +
 user_prompt_optimizer.py                           |   52 +
 xinference_client.py                               |   11 +
 46 files changed, 5848 insertions(+)
 create mode 100644 .trae/documents/修复预览访问与静态页面挂载.md
 create mode 100644 API.md
 create mode 100644 __pycache__/config.cpython-310.pyc
 create mode 100644 __pycache__/config.cpython-313.pyc
 create mode 100644 __pycache__/ollama_client.cpython-310.pyc
 create mode 100644 __pycache__/prompt_utils.cpython-310.pyc
 create mode 100644 __pycache__/session_state.cpython-310.pyc
 create mode 100644 __pycache__/user_prompt_optimizer.cpython-310.pyc
 create mode 100644 __pycache__/xinference_client.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/__pycache__/api.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/__pycache__/api.cpython-313.pyc
 create mode 100644 _qwen_xinference_demo/api.py
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc
 create mode 100644 _qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc
 create mode 100644 _qwen_xinference_demo/opro/ollama_client.py
 create mode 100644 _qwen_xinference_demo/opro/prompt_utils.py
 create mode 100644 _qwen_xinference_demo/opro/session_state.py
 create mode 100644 _qwen_xinference_demo/opro/user_prompt_optimizer.py
 create mode 100644 _qwen_xinference_demo/opro/xinference_client.py
 create mode 100644 config.py
 create mode 100644 examples/client_demo.py
 create mode 100644 frontend/index.html
 create mode 100644 frontend/react-app.html
 create mode 100644 frontend/react/index.html
 create mode 100644 frontend/ui_offline.html
 create mode 100644 ollama_client.py
 create mode 100644 opro_qwen_xinference_demo_package.zip
 create mode 100644 optimization/__pycache__/opt_utils.cpython-310.pyc
 create mode 100644 optimization/opt_utils.py
 create mode 100644 optimization/optimize_instructions.py
 create mode 100644 optimization/optimize_linear_regression.py
 create mode 100644 optimization/optimize_tsp.py
 create mode 100644 optimization/test.py
 create mode 100644 outputs/user_feedback.jsonl
 create mode 100644 prompt_utils.py
 create mode 100644 session_state.py
 create mode 100644 user_prompt_optimizer.py
 create mode 100644 xinference_client.py

diff --git a/.trae/documents/修复预览访问与静态页面挂载.md b/.trae/documents/修复预览访问与静态页面挂载.md
new file mode 100644
index 0000000..256d598
--- /dev/null
+++ b/.trae/documents/修复预览访问与静态页面挂载.md
@@ -0,0 +1,37 @@
+## Problem analysis
+- The browser reports `net::ERR_ABORTED http://0.0.0.0:8010/`, which typically means the preview environment aborts access to `0.0.0.0` or a redirect flow never completes.
+- The backend is already running with the static directory mounted at `/`, but the previewer may be sensitive to root-path loads, or host/port resolution is inconsistent.
+- API routes and the static mount currently coexist, and relative `fetch('/query')` calls work fine, so the problem is mainly the root page load.
+
+## Fix plan
+1. Host and port access
+- Prefer `http://127.0.0.1:8010/` or `http://localhost:8010/` over `0.0.0.0`.
+- Add a `/health` route for quickly checking whether the service is running.
+
+2. Static mount location (a minimal sketch follows this plan)
+- Move the static directory mount from `/` to `/ui` to reduce potential conflicts with the root path.
+- Add a `GET /` route that returns `frontend/index.html` or 302-redirects to `/ui/index.html`.
+
+3. Frontend requests and error handling
+- Keep the relative paths `fetch('/query')`, `/select`, `/reject` so requests stay same-origin; improve error reporting (show the response status and body).
+- If cross-origin access is needed (frontend deployed separately), add CORS so the frontend origin may call the backend API.
+
+4. Diagnostics and verification
+- Verify health with `curl http://127.0.0.1:8010/health`.
+- Exercise the API end to end with `curl`: `/query` (new session, and regeneration with a `session_id`), `/select` (generate an answer).
+- Open the `/ui/` page in a browser and run the full flow: start generating → reject and regenerate → select and get an answer.
+
+## Concrete change list
+- `_qwen_xinference_demo/api.py`
+  - Add a `GET /health` route returning `{status:"ok"}`.
+  - Move `StaticFiles(directory="frontend", html=True)` from `/` to `/ui`.
+  - Add a `GET /` route that returns `index.html` or redirects to `/ui/index.html`.
+- `frontend/index.html`
+  - Improve error display: show both the response status code and the response text (better diagnosability).
+
+## Follow-up enhancements (optional)
+- Add latency and origin logging to `/query` and `/select` to ease troubleshooting.
+- Show historical candidates and rejection reasons on the page for better observability.
+- Provide a config switch for embedding priority (Xinference/Ollama).
+
+Please confirm whether to proceed with the changes and verification above; I will implement them immediately and complete end-to-end testing.
\ No newline at end of file
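For reference, items 1 and 2 of the plan are small enough to sketch standalone. This is a minimal illustration of the `/ui` mount, root redirect, and health probe — not the project's `api.py`, which additionally wires in the OPRO routes; it only assumes a `frontend/` directory beside the script.

```python
# Minimal sketch of the /ui static mount, root redirect and health route
# proposed above (run with: uvicorn sketch:app --port 8010).
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles

app = FastAPI()
app.mount("/ui", StaticFiles(directory="frontend", html=True), name="static")

@app.get("/")
def root():
    # Redirect rather than serving index.html at "/" so the previewer
    # never has to load the bare root path.
    return RedirectResponse(url="/ui/")

@app.get("/health")
def health():
    return {"status": "ok"}
```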
diff --git a/API.md b/API.md
new file mode 100644
index 0000000..bdc1ad8
--- /dev/null
+++ b/API.md
@@ -0,0 +1,360 @@
+# Project API documentation
+
+This project exposes a REST API for OPRO-style prompt optimization and session interaction. All endpoints use `application/json`, with no authentication. Examples assume the default local address: `http://127.0.0.1:8010`.
+
+- Base path: `/`
+- Frontend pages: `/ui/` (three-pane UI), `/ui/react` (React sample page), `/ui/offline` (offline fallback page)
+- Content type: `Content-Type: application/json`
+
+### Unified response format
+
+All JSON endpoints return the following envelope:
+
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {}
+}
+```
+
+- Success: `code` is always `0`, `msg` is a short note (default `ok`), and the business payload lives in `data`.
+- Failure: the HTTP status code keeps its original value (e.g. 400/404/500), `code` mirrors that status, `msg` carries the error message, and `data` is `null`.
+
+Error handling lives at `_qwen_xinference_demo/api.py:23-31` (the exception handler); the success wrapper is `ok()` at `_qwen_xinference_demo/api.py:21-22`.
+
+---
+
+## Health check
+
+- Method and path: `GET /health`
+- Purpose: check service availability
+- Example response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": { "status": "ok" }
+}
+```
+
+---
+
+## Model management
+
+### List available models
+- Method and path: `GET /models`
+- Purpose: list the Ollama models usable for inference (embedding/reranker models are filtered out)
+- Example response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "models": ["qwen3:8b", "qwen3:14b", "qwen3:32b"]
+  }
+}
+```
+
+### Set the current session's model
+- Method and path: `POST /set_model`
+- Request body:
+```json
+{
+  "session_id": "",
+  "model_name": "qwen3:8b"
+}
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "model_name": "qwen3:8b"
+  }
+}
+```
+- Note: `model_name` must appear in the list returned by `/models`; otherwise a 400 error is returned.
+
+---
+
+## Sessions and candidate generation (prompt optimization)
+
+Prompt optimization follows this pipeline: build a "rewrite/mutate" instruction from the user's question or latest message → call Qwen to batch-generate candidates → embed them via Xinference (falling back to Ollama embeddings on failure) → cluster to deduplicate and keep the top K (default 5).
+
+### First round of candidates (creates a session)
+- Method and path: `POST /query`
+- Request body (new session):
+```json
+{ "query": "我想买苹果" }
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "round": 0,
+    "candidates": ["...", "..."]
+  }
+}
+```
+- Notes:
+  - Creates a session and records the user's original question plus the first round of candidates; `round` is incremented once candidates are stored.
+
+### Continue optimizing (regenerate from the latest message)
+- Method and path: `POST /query_from_message`
+- Request body:
+```json
+{ "session_id": "" }
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "round": 1,
+    "candidates": ["...", "..."]
+  }
+}
+```
+- Notes:
+  - Uses the session's most recent user message (or the original question) as the baseline for new candidates.
+
+### Select a candidate and answer
+- Method and path: `POST /select`
+- Request body:
+```json
+{
+  "session_id": "",
+  "choice": "选中的提示词"
+}
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "prompt": "选中的提示词",
+    "answer": "模型回答内容"
+  }
+}
+```
+- Notes:
+  - Records `choice` as the session's `selected_prompt` and generates the answer from it.
+  - Appends the user's choice and the answer to `outputs/user_feedback.jsonl`.
+
+### Reject a candidate and regenerate
+- Method and path: `POST /reject`
+- Request body:
+```json
+{
+  "session_id": "",
+  "candidate": "不合适的候选",
+  "reason": "可选的拒绝理由"
+}
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "round": 2,
+    "candidates": ["...", "..."]
+  }
+}
+```
+- Notes:
+  - Adds the rejected candidate to the session history and generates a new round of candidates that avoids it and diversifies.
+
+### Direct answer plus candidates (optional flow)
+- Method and path: `POST /answer`
+- Request body:
+```json
+{ "query": "我想买苹果" }
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "answer": "直接回答内容",
+    "candidates": ["...", "..."]
+  }
+}
+```
+- Notes:
+  - Answers the user's question directly first, then generates prompt-optimization candidates. This route always uses the backend's configured default model.
+
+### Regenerate (legacy endpoint, honors MAX_ROUNDS)
+- Method and path: `POST /next`
+- Request body:
+```json
+{ "session_id": "" }
+```
+- Success response (round limit reached):
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": { "final": true, "answer": "最终回答" }
+}
+```
+- Success response (below the round limit):
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": { "session_id": "", "round": 1, "candidates": ["...", "..."] }
+}
+```
+- Notes:
+  - `MAX_ROUNDS` is currently 3 and applies only to this route; the frontend does not use it by default.
+
+---
+
+## Session chat
+
+### Send a message and get an answer
+- Method and path: `POST /message`
+- Request body:
+```json
+{
+  "session_id": "",
+  "message": "继续提问或补充说明"
+}
+```
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "answer": "模型回答",
+    "history": [
+      {"role": "user", "content": "..."},
+      {"role": "assistant", "content": "..."}
+    ]
+  }
+}
+```
+- Notes:
+  - The answer is generated from the selected prompt (or the original question, if none was selected) concatenated with this message.
+
+---
+
+## Session management
+
+### List sessions
+- Method and path: `GET /sessions`
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "sessions": [
+      {
+        "session_id": "",
+        "round": 2,
+        "selected_prompt": "...",
+        "original_query": "我想买苹果"
+      }
+    ]
+  }
+}
+```
+
+### Session detail
+- Method and path: `GET /session/{sid}`
+- Success response:
+```json
+{
+  "code": 0,
+  "msg": "ok",
+  "data": {
+    "session_id": "",
+    "round": 2,
+    "original_query": "我想买苹果",
+    "selected_prompt": "...",
+    "candidates": ["...", "..."],
+    "user_feedback": [{"round": 1, "choice": "..."}],
+    "rejected": ["...", "..."],
+    "history": [
+      {"role": "user", "content": "..."},
+      {"role": "assistant", "content": "..."}
+    ]
+  }
+}
+```
+
+---
+
+## Static pages and redirects
+
+- `GET /` → redirects to `/ui/`
+- `GET /ui/` → the three-pane frontend (the backend mounts the static `frontend` directory)
+- `GET /ui/react` → React sample page
+- `GET /ui/offline` → offline page (no CDN dependencies)
+- `GET /react` → entry point equivalent to `/ui/react`
+
+---
+
+## Error codes and the common envelope
+
+- Error envelope:
+  - HTTP 404: `{"code": 404, "msg": "session not found", "data": null}`
+  - HTTP 400: `{"code": 400, "msg": "model not available: <model_name>" | "ollama error: <detail>", "data": null}`
+  - HTTP 500: `{"code": 500, "msg": "internal error", "data": null}`
+
+---
+
+## Examples (curl)
+
+```bash
+# Create a session and generate the first round of candidates
+curl -X POST http://127.0.0.1:8010/query \
+  -H 'Content-Type: application/json' \
+  -d '{"query": "我想买苹果"}'
+
+# Select a candidate and get an answer
+curl -X POST http://127.0.0.1:8010/select \
+  -H 'Content-Type: application/json' \
+  -d '{"session_id": "", "choice": "选中的提示词"}'
+
+# Reject a candidate and regenerate
+curl -X POST http://127.0.0.1:8010/reject \
+  -H 'Content-Type: application/json' \
+  -d '{"session_id": "", "candidate": "不合适的候选", "reason": "太笼统"}'
+
+# Continue optimizing from the latest message
+curl -X POST http://127.0.0.1:8010/query_from_message \
+  -H 'Content-Type: application/json' \
+  -d '{"session_id": ""}'
+
+# Plain chat
+curl -X POST http://127.0.0.1:8010/message \
+  -H 'Content-Type: application/json' \
+  -d '{"session_id": "", "message": "有无更甜的品种?"}'
+
+# Fetch session detail
+curl http://127.0.0.1:8010/session/<sid>
+```
+
+---
+
+## Notes
+- Candidate top-K defaults to 5; the clustering distance threshold defaults to `0.15`.
+- Embeddings prefer Xinference (`http://127.0.0.1:9997/...`), falling back automatically to Ollama embeddings (`qwen3-embedding:4b`).
+- Answers default to `qwen3:8b` on Ollama; `/set_model` sets a per-session model.
diff --git a/__pycache__/config.cpython-310.pyc b/__pycache__/config.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/config.cpython-310.pyc differ
diff --git a/__pycache__/config.cpython-313.pyc b/__pycache__/config.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/config.cpython-313.pyc differ
diff --git a/__pycache__/ollama_client.cpython-310.pyc b/__pycache__/ollama_client.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/ollama_client.cpython-310.pyc differ
diff --git a/__pycache__/prompt_utils.cpython-310.pyc b/__pycache__/prompt_utils.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/prompt_utils.cpython-310.pyc differ
diff --git a/__pycache__/session_state.cpython-310.pyc b/__pycache__/session_state.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/session_state.cpython-310.pyc differ
diff --git a/__pycache__/user_prompt_optimizer.cpython-310.pyc b/__pycache__/user_prompt_optimizer.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/user_prompt_optimizer.cpython-310.pyc differ
diff --git a/__pycache__/xinference_client.cpython-310.pyc b/__pycache__/xinference_client.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/__pycache__/xinference_client.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/__pycache__/api.cpython-310.pyc b/_qwen_xinference_demo/__pycache__/api.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/__pycache__/api.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/__pycache__/api.cpython-313.pyc b/_qwen_xinference_demo/__pycache__/api.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/__pycache__/api.cpython-313.pyc differ
diff --git a/_qwen_xinference_demo/api.py b/_qwen_xinference_demo/api.py
new file mode 100644
--- /dev/null
+++ b/_qwen_xinference_demo/api.py
@@ -0,0 +1,284 @@
[... the opening hunk (imports, app setup, the ok() wrapper, the AppException handler, and the start of the /next route) was lost in extraction; the recoverable remainder follows ...]
+    if s["round"] >= MAX_ROUNDS:
+        ans = call_qwen(s["original_query"], temperature=0.3, max_tokens=512)
+        return ok({"final": True, "answer": ans})
+
+    cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
+    update_session_add_candidates(req.session_id, cands)
+    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
+
+@app.post("/select", tags=["opro"])
+def select(req: SelectReq):
+    s = get_session(req.session_id)
+    if not s:
+        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+
+    log_user_choice(req.session_id, req.choice)
+    set_selected_prompt(req.session_id, req.choice)
+    log_chat_message(req.session_id, "system", req.choice)
+    try:
+        ans = call_qwen(req.choice, temperature=0.2, max_tokens=1024, model_name=s.get("model_name"))
+    except Exception as e:
+        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
+    log_chat_message(req.session_id, "assistant", ans)
+    try:
+        import os, json
+        os.makedirs("outputs", exist_ok=True)
+        with open("outputs/user_feedback.jsonl", "a", encoding="utf-8") as f:
+            f.write(json.dumps({
+                "session_id": req.session_id,
+                "round": s["round"],
+                "choice": req.choice,
+                "answer": ans
+            }, ensure_ascii=False) + "\n")
+    except Exception:
+        pass
+    return ok({"prompt": req.choice, "answer": ans})
+
+@app.post("/reject", tags=["opro"])
+def reject(req: RejectReq):
+    s = get_session(req.session_id)
+    if not s:
+        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+    log_user_reject(req.session_id, req.candidate, req.reason)
+    cands = generate_candidates(s["original_query"], s["history_candidates"] + [req.candidate], model_name=s.get("model_name"))
+    update_session_add_candidates(req.session_id, cands)
+    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
+
+class QueryReq(BaseModel):
+    query: str
+    session_id: str | None = None
+
+@app.post("/query", tags=["opro"])
+def query(req: QueryReq):
+    if req.session_id:
+        s = get_session(req.session_id)
+        if not s:
+            raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+        cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
+        update_session_add_candidates(req.session_id, cands)
+        return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
+    else:
+        sid = create_session(req.query)
+        log_chat_message(sid, "user", req.query)
+        cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
+        update_session_add_candidates(sid, cands)
+        return ok({"session_id": sid, "round": 0, "candidates": cands})
+
+app.mount("/ui", StaticFiles(directory="frontend", html=True), name="static")
+
+@app.get("/", tags=["ui"])
+def root():
+    return RedirectResponse(url="/ui/")
+
+@app.get("/health", tags=["health"])
+def health():
+    return ok({"status": "ok", "version": config.APP_VERSION})
+
+@app.get("/version", tags=["health"])
+def version():
+    return ok({"version": config.APP_VERSION})
+
+# @app.get("/ui/react", tags=["ui"])
+# def ui_react():
+#     return FileResponse("frontend/react/index.html")
+
+# @app.get("/ui/offline", tags=["ui"])
+# def ui_offline():
+#     return FileResponse("frontend/ui_offline.html")
+
+@app.get("/react", tags=["ui"])
+def react_root():
+    return FileResponse("frontend/react/index.html")
+
+@app.get("/sessions", tags=["sessions"])
+def sessions():
+    from .opro.session_state import SESSIONS
+    return ok({"sessions": [{
+        "session_id": sid,
+        "round": s.get("round", 0),
+        "selected_prompt": s.get("selected_prompt"),
+        "original_query": s.get("original_query")
+    } for sid, s in SESSIONS.items()]})
+
+@app.get("/session/{sid}", tags=["sessions"])
+def session_detail(sid: str):
+    s = get_session(sid)
+    if not s:
+        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+    return ok({
+        "session_id": sid,
+        "round": s["round"],
+        "original_query": s["original_query"],
+        "selected_prompt": s["selected_prompt"],
+        "candidates": s["history_candidates"],
+        "user_feedback": s["user_feedback"],
+        "rejected": s["rejected"],
+        "history": s["chat_history"],
+    })
+
+class MessageReq(BaseModel):
+    session_id: str
+    message: str
+
+@app.post("/message", tags=["chat"])
+def message(req: MessageReq):
+    s = get_session(req.session_id)
+    if not s:
+        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+    log_chat_message(req.session_id, "user", req.message)
+    base_prompt = s.get("selected_prompt") or s["original_query"]
+    full_prompt = base_prompt + "\n\n" + req.message
+    try:
+        ans = call_qwen(full_prompt, temperature=0.3, max_tokens=1024, model_name=s.get("model_name"))
+    except Exception as e:
+        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
+    log_chat_message(req.session_id, "assistant", ans)
+    return ok({"session_id": req.session_id, "answer": ans, "history": s["chat_history"]})
+
+class QueryFromMsgReq(BaseModel):
+    session_id: str
+
+@app.post("/query_from_message", tags=["opro"])
+def query_from_message(req: QueryFromMsgReq):
+    s = get_session(req.session_id)
+    if not s:
+        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+    last_user = None
+    for m in reversed(s.get("chat_history", [])):
+        if m.get("role") == "user" and m.get("content"):
+            last_user = m["content"]
+            break
+    base = last_user or s["original_query"]
+    cands = generate_candidates(base, s["history_candidates"], model_name=s.get("model_name"))
+    update_session_add_candidates(req.session_id, cands)
+    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
+
+class AnswerReq(BaseModel):
+    query: str
+
+@app.post("/answer", tags=["opro"])
+def answer(req: AnswerReq):
+    sid = create_session(req.query)
+    log_chat_message(sid, "user", req.query)
+    ans = call_qwen(req.query, temperature=0.2, max_tokens=1024)
+    log_chat_message(sid, "assistant", ans)
+    cands = generate_candidates(req.query, [])
+    update_session_add_candidates(sid, cands)
+    return ok({"session_id": sid, "answer": ans, "candidates": cands})
+
+@app.get("/models", tags=["models"])
+def models():
+    return ok({"models": list_models()})
+
+@app.post("/set_model", tags=["models"])
+def set_model(req: SetModelReq):
+    s = get_session(req.session_id)
+    if not s:
+        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
+    avail = set(list_models() or [])
+    if req.model_name not in avail:
+        raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE")
+    set_session_model(req.session_id, req.model_name)
+    return ok({"session_id": req.session_id, "model_name": req.model_name})
diff --git a/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/ollama_client.cpython-313.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/prompt_utils.cpython-313.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/session_state.cpython-313.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/user_prompt_optimizer.cpython-313.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-310.pyc differ
diff --git a/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc
new file mode 100644
Binary files /dev/null and b/_qwen_xinference_demo/opro/__pycache__/xinference_client.cpython-313.pyc differ
diff --git a/_qwen_xinference_demo/opro/ollama_client.py b/_qwen_xinference_demo/opro/ollama_client.py
new file mode 100644
index 0000000..e5b6ba1
--- /dev/null
+++ b/_qwen_xinference_demo/opro/ollama_client.py
@@ -0,0 +1,52 @@
+import requests
+import re
+import config
+
+OLLAMA_URL = config.OLLAMA_GENERATE_URL
+TAGS_URL = config.OLLAMA_TAGS_URL
+MODEL_NAME = config.DEFAULT_CHAT_MODEL
+
+def call_qwen(prompt: str, temperature: float = 0.8, max_tokens: int = 512, model_name: str | None = None) -> str:
+    def _payload(m: str):
+        return {
+            "model": m,
+            "prompt": prompt,
+            "stream": False,
+            "options": {
+                "temperature": temperature,
+                "num_predict": max_tokens
+            }
+        }
+    primary = model_name or MODEL_NAME
+    try:
+        resp = requests.post(OLLAMA_URL, json=_payload(primary), timeout=60)
+        resp.raise_for_status()
+        data = resp.json()
+        return data.get("response", "") or data.get("text", "")
+    except requests.HTTPError as e:
+        # Try fallback to default when user-selected model fails
+        if model_name and model_name != MODEL_NAME:
+            try:
+                resp = requests.post(OLLAMA_URL, json=_payload(MODEL_NAME), timeout=60)
+                resp.raise_for_status()
+                data = resp.json()
+                return data.get("response", "") or data.get("text", "")
+            except Exception:
+                pass
+        raise
+
+def list_models() -> list[str]:
+    try:
+        r = requests.get(TAGS_URL, timeout=10)
+        r.raise_for_status()
+        data = r.json() or {}
+        items = data.get("models") or []
+        names = []
+        for m in items:
+            name = m.get("name") or m.get("model")
+            if name:
+                names.append(name)
+        names = [n for n in names if not re.search(r"embedding|rerank|reranker|bge", n, re.I)]
+        return names
+    except Exception:
+        return [MODEL_NAME]
diff --git a/_qwen_xinference_demo/opro/prompt_utils.py b/_qwen_xinference_demo/opro/prompt_utils.py
new file mode 100644
index 0000000..4d44486
--- /dev/null
+++ b/_qwen_xinference_demo/opro/prompt_utils.py
@@ -0,0 +1,20 @@
+def refine_instruction(query: str) -> str:
+    return f"""
+你是一个“问题澄清与重写助手”。
+请根据用户的原始问题:
+【{query}】
+生成不少于20条多角度、可直接执行的问题改写,每行一条。
+"""
+
+def refine_instruction_with_history(query: str, rejected_list: list) -> str:
+    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
+    return f"""
+你是一个“问题澄清与重写助手”。
+原始问题:
+{query}
+
+以下改写已被否定:
+{rejected_text}
+
+请从新的角度重新生成至少20条不同的改写问题,每条单独一行。
+"""
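To see what is actually sent to Qwen, these templates can be printed directly — no model or services needed. A small sketch; it assumes the repo root is on `sys.path` so the package import resolves (the `opro` package ships no `__init__.py`, so it relies on Python 3 namespace packages):

```python
# Print the rewrite instructions built for a sample query and a rejected candidate.
from _qwen_xinference_demo.opro.prompt_utils import (
    refine_instruction,
    refine_instruction_with_history,
)

print(refine_instruction("我想买苹果"))
print(refine_instruction_with_history("我想买苹果", ["哪里能买到苹果?"]))
```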
diff --git a/_qwen_xinference_demo/opro/session_state.py b/_qwen_xinference_demo/opro/session_state.py
new file mode 100644
index 0000000..96c7c77
--- /dev/null
+++ b/_qwen_xinference_demo/opro/session_state.py
@@ -0,0 +1,56 @@
+import uuid
+
+SESSIONS = {}
+USER_FEEDBACK_LOG = []
+
+def create_session(query: str) -> str:
+    sid = uuid.uuid4().hex
+    SESSIONS[sid] = {
+        "original_query": query,
+        "round": 0,
+        "history_candidates": [],
+        "user_feedback": [],
+        "rejected": [],
+        "selected_prompt": None,
+        "chat_history": [],
+        "model_name": None
+    }
+    return sid
+
+def get_session(sid: str):
+    return SESSIONS.get(sid)
+
+def update_session_add_candidates(sid: str, candidates: list):
+    s = SESSIONS[sid]
+    s["round"] += 1
+    s["history_candidates"].extend(candidates)
+
+def log_user_choice(sid: str, choice: str):
+    SESSIONS[sid]["user_feedback"].append(
+        {"round": SESSIONS[sid]["round"], "choice": choice}
+    )
+    USER_FEEDBACK_LOG.append({
+        "session_id": sid,
+        "round": SESSIONS[sid]["round"],
+        "choice": choice
+    })
+
+def log_user_reject(sid: str, candidate: str, reason: str | None = None):
+    SESSIONS[sid]["rejected"].append(candidate)
+    USER_FEEDBACK_LOG.append({
+        "session_id": sid,
+        "round": SESSIONS[sid]["round"],
+        "reject": candidate,
+        "reason": reason or ""
+    })
+
+def set_selected_prompt(sid: str, prompt: str):
+    SESSIONS[sid]["selected_prompt"] = prompt
+
+def log_chat_message(sid: str, role: str, content: str):
+    SESSIONS[sid]["chat_history"].append({"role": role, "content": content})
+
+def set_session_model(sid: str, model_name: str | None):
+    s = SESSIONS.get(sid)
+    if s is not None:
+        s["model_name"] = model_name
diff --git a/_qwen_xinference_demo/opro/user_prompt_optimizer.py b/_qwen_xinference_demo/opro/user_prompt_optimizer.py
new file mode 100644
index 0000000..d742e8f
--- /dev/null
+++ b/_qwen_xinference_demo/opro/user_prompt_optimizer.py
@@ -0,0 +1,55 @@
+import re
+import numpy as np
+from sklearn.cluster import AgglomerativeClustering
+from sklearn.metrics.pairwise import cosine_similarity
+import config
+
+from .ollama_client import call_qwen
+from .xinference_client import embed_texts
+from .prompt_utils import refine_instruction, refine_instruction_with_history
+
+def parse_candidates(raw: str) -> list:
+    lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()]
+    cleaned = []
+    for l in lines:
+        l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip()
+        if len(l) >= 6:
+            cleaned.append(l)
+    return list(dict.fromkeys(cleaned))
+
+def cluster_and_select(candidates: list, top_k=config.TOP_K, distance_threshold=config.CLUSTER_DISTANCE_THRESHOLD):
+    if not candidates:
+        return []
+    if len(candidates) <= top_k:
+        return candidates
+    vecs = embed_texts(candidates)
+    if not vecs or len(vecs) != len(candidates):
+        return candidates[:top_k]
+    X = np.array(vecs)
+
+    clustering = AgglomerativeClustering(n_clusters=None,
+                                         distance_threshold=distance_threshold,
+                                         metric="cosine",
+                                         linkage="average")
+    labels = clustering.fit_predict(X)
+
+    selected_idx = []
+    for label in sorted(set(labels)):
+        idxs = [i for i, l in enumerate(labels) if l == label]
+        sims = cosine_similarity(X[idxs]).mean(axis=1)
+        rep = idxs[int(np.argmax(sims))]
+        selected_idx.append(rep)
+
+    selected = [candidates[i] for i in sorted(selected_idx)]
+    return selected[:top_k]
+
+def generate_candidates(query: str, rejected=None, top_k=config.TOP_K, model_name=None):
+    rejected = rejected or []
+    if rejected:
+        prompt = refine_instruction_with_history(query, rejected)
+    else:
+        prompt = refine_instruction(query)
+
+    raw = call_qwen(prompt, temperature=0.9, max_tokens=1024, model_name=model_name)
+    all_candidates = parse_candidates(raw)
+    return cluster_and_select(all_candidates, top_k=top_k)
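The dedup-and-pick step is easiest to understand on toy vectors. The sketch below exercises the same `AgglomerativeClustering` configuration with hand-rolled stub embeddings in place of `embed_texts`, so it runs without Xinference or Ollama (the stub vectors are assumptions chosen to form two obvious clusters; requires scikit-learn ≥ 1.2 for the `metric` argument):

```python
# Two tight clusters in 2-D: the selector keeps one representative per cluster.
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity

candidates = ["甜苹果推荐", "哪种苹果最甜", "苹果价格行情", "今日苹果售价"]
X = np.array([[1.0, 0.05], [1.0, 0.1], [0.05, 1.0], [0.1, 1.0]])  # stub embeddings

clustering = AgglomerativeClustering(n_clusters=None, distance_threshold=0.15,
                                     metric="cosine", linkage="average")
labels = clustering.fit_predict(X)

picked = []
for label in sorted(set(labels)):
    idxs = [i for i, l in enumerate(labels) if l == label]
    sims = cosine_similarity(X[idxs]).mean(axis=1)   # closeness to the cluster centre
    picked.append(idxs[int(np.argmax(sims))])

print([candidates[i] for i in sorted(picked)])  # one question per topic survives
```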
diff --git a/_qwen_xinference_demo/opro/xinference_client.py b/_qwen_xinference_demo/opro/xinference_client.py
new file mode 100644
index 0000000..c0d3a8b
--- /dev/null
+++ b/_qwen_xinference_demo/opro/xinference_client.py
@@ -0,0 +1,29 @@
+import requests
+from typing import List
+import config
+
+XINFERENCE_EMBED_URL = config.XINFERENCE_EMBED_URL
+OLLAMA_EMBED_URL = config.OLLAMA_HOST + "/api/embeddings"
+
+def embed_texts(texts: List[str]) -> List[List[float]]:
+    payload = {"inputs": texts}
+    try:
+        resp = requests.post(XINFERENCE_EMBED_URL, json=payload, timeout=10)
+        resp.raise_for_status()
+        data = resp.json()
+        embs = data.get("embeddings", [])
+        if embs:
+            return embs
+    except Exception:
+        pass
+
+    # Fall back to Ollama embeddings when Xinference is unavailable.
+    try:
+        payload2 = {"model": config.DEFAULT_EMBED_MODEL, "input": texts}
+        resp2 = requests.post(OLLAMA_EMBED_URL, json=payload2, timeout=15)
+        resp2.raise_for_status()
+        data2 = resp2.json()
+        if isinstance(data2, dict) and "data" in data2:
+            return [item.get("embedding", []) for item in data2["data"]]
+        return data2.get("embeddings", [])
+    except Exception:
+        return []
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..2662c83
--- /dev/null
+++ b/config.py
@@ -0,0 +1,19 @@
+APP_TITLE = "OPRO Prompt Optimizer API"
+APP_DESCRIPTION = "提供提示优化、候选生成、会话聊天与模型管理的接口"
+APP_VERSION = "0.1.0"
+APP_CONTACT = {"name": "OPRO Team", "url": "http://127.0.0.1:8010/ui/"}
+
+# Ollama endpoints
+OLLAMA_HOST = "http://127.0.0.1:11434"
+OLLAMA_GENERATE_URL = f"{OLLAMA_HOST}/api/generate"
+OLLAMA_TAGS_URL = f"{OLLAMA_HOST}/api/tags"
+DEFAULT_CHAT_MODEL = "qwen3:8b"
+DEFAULT_EMBED_MODEL = "qwen3-embedding:4b"
+
+# Xinference
+XINFERENCE_EMBED_URL = "http://127.0.0.1:9997/models/bge-base-zh/embed"
+
+# Clustering/selection
+TOP_K = 5
+CLUSTER_DISTANCE_THRESHOLD = 0.15
+
diff --git a/examples/client_demo.py b/examples/client_demo.py
new file mode 100644
index 0000000..5d3303e
--- /dev/null
+++ b/examples/client_demo.py
@@ -0,0 +1,55 @@
+import requests
+
+BASE = "http://127.0.0.1:8010"
+
+def _post(path, payload):
+    r = requests.post(BASE + path, json=payload, timeout=30)
+    r.raise_for_status()
+    j = r.json()
+    # Unwrap the unified envelope {"code": 0, "msg": "ok", "data": {...}}.
+    if "code" in j:
+        if j.get("code") != 0:
+            raise RuntimeError(f"api error: {j}")
+        return j.get("data")
+    return j
+
+def _get(path):
+    r = requests.get(BASE + path, timeout=15)
+    r.raise_for_status()
+    j = r.json()
+    if "code" in j:
+        if j.get("code") != 0:
+            raise RuntimeError(f"api error: {j}")
+        return j.get("data")
+    return j
+
+def main():
+    print("health:", _get("/health"))
+    try:
+        print("version:", _get("/version"))
+    except Exception:
+        pass
+
+    data = _post("/query", {"query": "我想买苹果"})
+    sid = data["session_id"]
+    print("created session:", sid)
+    print("candidates:", data["candidates"])
+
+    # choose first candidate
+    if data["candidates"]:
+        choice = data["candidates"][0]
+        ans = _post("/select", {"session_id": sid, "choice": choice})
+        print("answer:", ans["answer"][:200])
+
+    # continue optimization
+    more = _post("/query_from_message", {"session_id": sid})
+    print("next candidates:", more["candidates"])
+
+    # chat
+    chat = _post("/message", {"session_id": sid, "message": "还有更甜的苹果吗?"})
+    print("chat answer:", chat["answer"][:200])
+
+    # list sessions
+    print("sessions:", _get("/sessions"))
+
+if __name__ == "__main__":
+    main()
diff --git a/frontend/index.html b/frontend/index.html
new file mode 100644
index 0000000..d6de7ea
--- /dev/null
+++ b/frontend/index.html
@@ -0,0 +1,446 @@
[HTML markup lost in extraction. Recoverable text: page title "OPRO Prompt Optimizer"; a model-picker panel ("选择模型" / Select model); a session-info panel ("会话信息") showing the current round ("Round 0") and the selected prompt ("暂未选择" / none yet); and usage tips ("操作提示"): entering a question generates candidate prompts; "选择" (Select) uses that prompt; "拒绝" (Reject) generates new candidates; "继续优化" (Continue optimizing) fetches more options.]
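Note that `client_demo.py`'s unwrap helpers were originally checking a `success` key that the documented envelope never carries; the error path is easy to verify directly. A hypothetical probe, assuming the API from this patch is running locally:

```python
# Probing the error envelope: the HTTP status keeps its original value,
# `code` mirrors it, and `data` comes back as JSON null (None in Python).
import requests

r = requests.get("http://127.0.0.1:8010/session/does-not-exist", timeout=10)
print(r.status_code)   # 404
print(r.json())        # {'code': 404, 'msg': 'session not found', 'data': None}
# 400 ("model not available: ...", "ollama error: ...") and 500 share the same shape.
```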
diff --git a/frontend/react-app.html b/frontend/react-app.html
new file mode 100644
index 0000000..0fb1649
--- /dev/null
+++ b/frontend/react-app.html
@@ -0,0 +1,164 @@
[164-line React single-page UI titled "OPRO React 界面"; markup lost in extraction.]
diff --git a/frontend/react/index.html b/frontend/react/index.html
new file mode 100644
index 0000000..06bbb01
--- /dev/null
+++ b/frontend/react/index.html
@@ -0,0 +1,192 @@
[192-line React single-page UI titled "OPRO React 界面"; markup lost in extraction.]
diff --git a/frontend/ui_offline.html b/frontend/ui_offline.html
new file mode 100644
index 0000000..48cb789
--- /dev/null
+++ b/frontend/ui_offline.html
@@ -0,0 +1,157 @@
[157-line offline three-pane UI titled "OPRO 三栏界面(离线版)"; markup lost in extraction.]
diff --git a/ollama_client.py b/ollama_client.py
new file mode 100644
index 0000000..a654934
--- /dev/null
+++ b/ollama_client.py
@@ -0,0 +1,19 @@
+import requests
+
+OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
+MODEL_NAME = "qwen3:8b"
+
+def call_qwen(prompt: str, temperature: float = 0.8, max_tokens: int = 512) -> str:
+    # Non-streaming generation against the local Ollama server; num_predict
+    # is Ollama's name for the max-new-tokens budget.
+    payload = {
+        "model": MODEL_NAME,
+        "prompt": prompt,
+        "stream": False,
+        "options": {
+            "temperature": temperature,
+            "num_predict": max_tokens
+        }
+    }
+    resp = requests.post(OLLAMA_URL, json=payload, timeout=60)
+    resp.raise_for_status()
+    data = resp.json()
+    return data.get("response", "") or data.get("text", "")
diff --git a/opro_qwen_xinference_demo_package.zip b/opro_qwen_xinference_demo_package.zip
new file mode 100644
index 0000000000000000000000000000000000000000..187905ffd3402beba012bb78fcce513d64616f5d
GIT binary patch
literal 6336
[base85-encoded binary payload omitted]
literal 0
HcmV?d00001

diff --git a/optimization/__pycache__/opt_utils.cpython-310.pyc b/optimization/__pycache__/opt_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c7a4a9fad9d6850d422c8a2de03e51c63509add
GIT binary patch
literal 19855
[base85-encoded binary payload omitted]
literal 0
HcmV?d00001

diff --git a/optimization/opt_utils.py b/optimization/opt_utils.py
new file mode 100644
index 0000000..36ad3bf
--- /dev/null
+++ b/optimization/opt_utils.py
@@ -0,0 +1,1035 @@
+# Copyright 2023 The OPRO Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The utility functions for prompt optimization."""
+
+import collections
+import json
+import os
+import pickle
+import re
+import sys
+
+OPRO_ROOT_PATH = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+)
+sys.path.insert(0, OPRO_ROOT_PATH)
+
+import numpy as np
+from opro.evaluation import eval_utils
+import pandas as pd
+
+
+def extract_string_in_square_brackets(input_string):
+  # Return the contents of the first [...] span in input_string (brackets
+  # stripped), or "" when the string contains no square brackets.
+  raw_result = re.findall(r"\[.*?\]", input_string)
+  if raw_result:
+    return raw_result[0][1:-1]
+  else:
+    return ""
+
+
+def parse_tag_content(text, prefix="<TEXT>", suffix="</TEXT>"):
+  pattern = f"{prefix}(.*?){suffix}"
+  results = re.findall(pattern, text, re.DOTALL)
+  return results
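+
+# Example (illustrative):
+#   extract_string_in_square_brackets("score [Let's think step by step.]")
+#     -> "Let's think step by step."
+#   parse_tag_content("a <TEXT>one</TEXT> b <TEXT>two</TEXT>")
+#     -> ["one", "two"]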
+
+def _bucketize_float(num, n_buckets=20):
+  # Map a score in [0, 1] to an integer bucket in [0, n_buckets]; used to
+  # coarsen instruction accuracies (e.g. 0.85) before showing them in the
+  # meta-prompt.
+  assert num >= 0 and num <= 1, "The given number must be between 0 and 1."
+  return round(num * n_buckets)
+
+
+def gen_ins_and_score_pairs_substr(
+    old_instructions_and_scores,
+    old_instruction_score_threshold=0.1,
+    max_num_instructions=1000,
+    return_str_only=False,
+    num_score_buckets=np.inf,
+):
+  """Generate the string that includes instruction-score pairs."""
+  assert num_score_buckets == np.inf or isinstance(num_score_buckets, int)
+  old_instructions_and_scores_str = ""
+  # Keep only the max_num_instructions highest-scoring pairs, sorted ascending.
+  old_instructions_and_scores = sorted(
+      old_instructions_and_scores, key=lambda x: x[1]
+  )[-max_num_instructions:]
+  old_instructions_and_scores_in_meta_prompt = []
+  for instruction, score, i_step in old_instructions_and_scores:
+    if (
+        not old_instruction_score_threshold
+        or score >= old_instruction_score_threshold
+    ):
+      old_instructions_and_scores_in_meta_prompt.append(
+          (instruction, score, i_step)
+      )
+      if num_score_buckets == np.inf:
+        score_to_show = round(score, 3)
+      else:
+        score_to_show = _bucketize_float(score, num_score_buckets)
+      old_instructions_and_scores_str += (
+          f"\ntext:\n{instruction}\nscore:\n{score_to_show}\n"
+      )
+  if return_str_only:
+    return old_instructions_and_scores_str
+  else:
+    return (
+        old_instructions_and_scores_str,
+        old_instructions_and_scores_in_meta_prompt,
+    )
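+
+# (Illustrative) each kept pair is appended to the meta-prompt as:
+#
+#   text:
+#   Let's think step by step.
+#   score:
+#   0.618
+#
+# with the score rounded to 3 decimals, or bucketized to an integer in
+# [0, num_score_buckets] by _bucketize_float when num_score_buckets is finite.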
+ """ + assert instruction_pos in { + "before_Q", + "Q_begin", + "Q_end", + "A_begin", + }, ( + "The instruction position should be either before the question, or at the" + " beginning of the question, at the end of the question, or at the" + " beginning of the answer." + ) + assert meta_prompt_type in { + "both_instructions_and_exemplars", + "instructions_only", + } + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, gsm8k." + assert num_score_buckets == np.inf or isinstance(num_score_buckets, int) + + meta_prompt = "" + if meta_prompt_type == "both_instructions_and_exemplars": + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4","local"}: # GPT模型的提示模板 + if instruction_pos == "A_begin":# 针对答案开头部分的描述 + meta_prompt_old_instruction_part = ( + "Your task is to generate the answer starting sentence ." + " Below are some previous starting sentences with their scores." + " The score ranges from 0 to 100.\n" + )#要求模型生成答案的起始句(如“The answer is...”),并参考历史评分数据。 + else:# 针对普通指令的描述 + meta_prompt_old_instruction_part = ( + "Your task is to generate the instruction ." + " Below are some previous instructions with their scores." + " The score ranges from 0 to 100.\n" + ) + else: # text-bison 模型的专用提示模板 + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt_old_instruction_part = ( + "I have some texts along with their corresponding scores." + " The texts are arranged in ascending order based on their scores," + " where higher scores indicate better quality.\n\n" + ) + # add old instructions + old_instructions_and_scores_str = gen_ins_and_score_pairs_substr( + old_instructions_and_scores=old_instructions_and_scores, + old_instruction_score_threshold=old_instruction_score_threshold, + max_num_instructions=max_num_instructions, + return_str_only=True, + num_score_buckets=num_score_buckets, + ) + meta_prompt_old_instruction_part += old_instructions_and_scores_str + # add QA pairs if few_shot_qa_pairs == True + meta_prompt_exemplar_part = "" + if few_shot_qa_pairs: + if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}: + meta_prompt_exemplar_part += "Below are some problems.\n" + else: + assert optimizer_llm_name.lower() == "text-bison" + meta_prompt_exemplar_part += ( + "The following exemplars show how to apply your text: you replace" + " in each input with your text, then read the input and give" + " an output. We say your output is wrong if your output is" + " different from the given output, and we say your output is" + " correct if they are the same. 
+      for idx in few_shot_index_list:
+        if dataset_name == "mmlu":
+          question = eval_utils._format_mmlu_example(data, idx)  # pylint: disable=protected-access
+          true_answer = data.iloc[idx, -1]
+        elif dataset_name == "bbh":
+          question = data[idx]["input"]
+          true_answer = data[idx]["target"]
+        else:
+          assert dataset_name == "gsm8k"
+          question = data.iloc[idx, 0]
+          true_answer = data.iloc[idx, 1]
+
+        if include_qa:  # when "Q:" and "A:" are present in the prompt
+          if instruction_pos == "before_Q":
+            meta_prompt_exemplar_part += f"\ninput:\n<INS>\nQ: {question}\nA:"
+          elif instruction_pos == "Q_begin":
+            meta_prompt_exemplar_part += f"\ninput:\nQ: <INS>\n{question}\nA:"
+          elif instruction_pos == "Q_end":
+            meta_prompt_exemplar_part += f"\ninput:\nQ: {question}\n<INS>\nA:"
+          else:  # instruction_pos == "A_begin"
+            if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+              meta_prompt_exemplar_part += f"\nQ: {question}\nA: <Start>"
+            else:
+              assert optimizer_llm_name.lower() == "text-bison"
+              meta_prompt_exemplar_part += f"\ninput:\nQ: {question}\nA: <Start>"
+        else:  # when there're no "Q:" and "A:" in the prompt
+          assert instruction_pos in {"Q_begin", "Q_end"}
+          if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+            if instruction_pos == "Q_begin":
+              meta_prompt_exemplar_part += f"\nProblem:\n<INS>\n{question}\n"
+            elif instruction_pos == "Q_end":
+              meta_prompt_exemplar_part += f"\nProblem:\n{question}\n<INS>\n"
+          else:
+            assert optimizer_llm_name.lower() == "text-bison"
+            if instruction_pos == "Q_begin":
+              meta_prompt_exemplar_part += f"\ninput:\n<INS>\n{question}\n"
+            elif instruction_pos == "Q_end":
+              meta_prompt_exemplar_part += f"\ninput:\n{question}\n<INS>\n"
+
+        if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+          meta_prompt_exemplar_part += (
+              f"\nGround truth answer:\n{true_answer}\n"
+          )
+        else:
+          assert optimizer_llm_name.lower() == "text-bison"
+          meta_prompt_exemplar_part += f"\noutput:\n{true_answer}\n"
+
+    if few_shot_qa_pairs:
+      if instructions_before_exemplars:
+        meta_prompt += (
+            meta_prompt_old_instruction_part
+            + "\n\n"
+            + meta_prompt_exemplar_part
+        )
+      else:
+        meta_prompt += (
+            meta_prompt_exemplar_part
+            + "\n\n"
+            + meta_prompt_old_instruction_part
+        )
+    else:
+      meta_prompt += meta_prompt_old_instruction_part
+
+    if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+      if instruction_pos == "A_begin":
+        meta_prompt += (
+            "\n\nGenerate a starting sentence that is different from all the"
+            " <Start> sentences above, and has a higher score than all the"
+            " <Start> sentences above. The starting sentence should begin with"
+            " <Start> and end with </Start>. The starting sentence should be"
+            " concise, effective, and generally applicable to all QA pairs"
+            " above."
+        )
+      else:
+        meta_prompt += (
+            "\n\nGenerate an instruction that"
+            " is different from all the instructions <INS> above,"
+            " and has a higher score than all the instructions <INS> above."
+            " The instruction should begin with <INS> and end with </INS>."
+            " The instruction should be concise, effective,"
+            " and generally applicable to all problems above."
+        )
+    else:
+      assert optimizer_llm_name.lower() == "text-bison"
+      meta_prompt += (
+          "\n\nWrite your new text that is different from the old ones and"
+          " has a score as high as possible. Write the text in square brackets."
+ ) + else: + # when using a pre-trained model as optimizer + assert meta_prompt_type == "instructions_only" + + assert instruction_pos in {"Q_begin", "Q_end", "A_begin"} + if instruction_pos == "Q_begin": + instruction_pos_description = "at the beginning of the question" + elif instruction_pos == "Q_end": + instruction_pos_description = "at the end of the question" + else: + assert instruction_pos == "A_begin" + instruction_pos_description = "at the beginning of the answer" + + if dataset_name == "gsm8k": + instruction_task_description = "grade school math" + elif dataset_name == "mmlu": + instruction_task_description = task_name + else: + assert dataset_name == "bbh" + instruction_task_description = " ".join(task_name.split("_")) + + meta_instruction = ( + f"Create a piece of text {instruction_pos_description.strip()} to" + " enhance the precision in solving diverse" + f" {instruction_task_description.strip()} problems." + ) + old_instructions_and_scores = sorted( + old_instructions_and_scores, key=lambda x: x[1] + ) + old_instructions_and_scores_str = "" + for instruction, score, _ in old_instructions_and_scores: + if num_score_buckets == np.inf: + score_to_show = round(score, 2) + else: + score_to_show = _bucketize_float(score, num_score_buckets) + old_instructions_and_scores_str += ( + f"\n\nPrecision: {score_to_show} {instruction}" + ) + meta_prompt += meta_instruction + old_instructions_and_scores_str + return meta_prompt + + +def run_evolution(**kwargs): + """The function for evolution.""" + # ================= experiment configurations ============================= + num_search_steps = kwargs["num_search_steps"] + old_instruction_score_threshold = kwargs["old_instruction_score_threshold"] + scorer_llm_dict = kwargs["scorer_llm_dict"] + optimizer_llm_dict = kwargs["optimizer_llm_dict"] + extract_final_answer_by_prompting_again = kwargs[ + "extract_final_answer_by_prompting_again" + ] + include_qa = kwargs["include_qa"] + evaluate_in_parallel = kwargs["evaluate_in_parallel"] + tasks_all = kwargs["tasks_all"] + train_ratio = kwargs["train_ratio"] + eval_ratio = kwargs["eval_ratio"] + test_ratio = kwargs["test_ratio"] + train_index = kwargs["train_index"] + eval_index = kwargs["eval_index"] + dataset_name = kwargs["dataset_name"] + task_name = kwargs["task_name"] + num_examples = kwargs["num_examples"] + root_data_folder_path = kwargs["root_data_folder_path"] + optimizer_llm_temperature = kwargs["optimizer_llm_temperature"] + optimizer_llm_temperature_schedule = ( + kwargs["optimizer_llm_temperature_schedule"] + if "optimizer_llm_temperature_schedule" in kwargs + else "constant" + ) + optimizer_llm_temperature_end = ( + kwargs["optimizer_llm_temperature_end"] + if "optimizer_llm_temperature_end" in kwargs + else None + ) + initial_instructions = kwargs["initial_instructions"] + multiple_choice_tasks = kwargs["multiple_choice_tasks"] + raw_data = kwargs["raw_data"] + call_scorer_server_func = kwargs["call_scorer_server_func"] + call_optimizer_server_func = kwargs["call_optimizer_server_func"] + instruction_pos = kwargs["instruction_pos"] + prediction_treat_as_number = kwargs["prediction_treat_as_number"] + prediction_treat_as_bool = kwargs["prediction_treat_as_bool"] + result_by_instruction_folder = kwargs["result_by_instruction_folder"] + few_shot_qa_pairs = kwargs["few_shot_qa_pairs"] + num_score_buckets = kwargs["num_score_buckets"] + max_num_instructions = kwargs["max_num_instructions"] + meta_prompt_type = kwargs["meta_prompt_type"] + meta_prompt_instructions_before_exemplars = 
kwargs[ + "meta_prompt_instructions_before_exemplars" + ] + few_shot_selection_criteria = kwargs["few_shot_selection_criteria"] + optimizer_llm_name = kwargs["optimizer_llm_name"] + num_generated_instructions_in_each_step = kwargs[ + "num_generated_instructions_in_each_step" + ] + evaluate_generated_ins_on_few_shot = kwargs[ + "evaluate_generated_ins_on_few_shot" + ] + num_few_shot_questions_for_instruction_refinement = kwargs[ + "num_few_shot_questions_for_instruction_refinement" + ] + evaluate_old_ins_on_few_shot = kwargs["evaluate_old_ins_on_few_shot"] + eval_interval = kwargs["eval_interval"] + save_folder = kwargs["save_folder"] + verbose = kwargs["verbose"] if "verbose" in kwargs else False + + # =================== assertions ===================== + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, gsm8k." + assert optimizer_llm_temperature_schedule in { + "constant", + "linear_increase", + }, "The temperature schedule should be constant or linear_increase." + + # =================== save configurations to json file ==================== + configs_dict = dict() + configs_dict["scorer_llm_dict"] = scorer_llm_dict + configs_dict["optimizer_llm_dict"] = optimizer_llm_dict + configs_dict["instruction_pos"] = instruction_pos + configs_dict["optimizer_llm_temperature"] = optimizer_llm_temperature + configs_dict["optimizer_llm_temperature_schedule"] = ( + optimizer_llm_temperature_schedule + ) + configs_dict["optimizer_llm_temperature_end"] = optimizer_llm_temperature_end + with open(os.path.join(save_folder, "configs_dict.json"), "w") as f: + json.dump(configs_dict, f, indent=4) + + num_servers = scorer_llm_dict["num_servers"] + batch_size = scorer_llm_dict["batch_size"] + generated_ins_on_few_shot_results_dict = dict() + old_ins_on_few_shot_results_dict = dict() + # evaluation results every a few steps + # format: [(i_step, instruction, detailed_results_df)] + eval_results = [] + # all generated instructions, format: [(instruction, score, step_index)] + # the instructions that were skipped have score NaN + old_instructions_and_scores_raw = [] + # the new instructions, format: [(instruction, score, step_index)] + old_instructions_and_scores = [] + meta_prompts = [] # format: [(meta_prompt, step_index)] + instruction_score_dict = dict() # the dictionary of {instruction: score} + # the dictionary of the few-shot QA indices in meta-prompt + # key: step index; value: the list of few-shot indices in that step + few_shot_index_list_by_step_dict = dict() + detailed_results_df_by_instruction_dict = dict() + wrong_questions_from_start_counter = collections.Counter() + # EVAL results + eval_detailed_results_df_dict = dict() # {instruction: detailed_results_df} + instruction_eval_score_dict = dict() # {instruction: eval_score} + old_instruction_md5_hashstrings_set = set() + + print(f"tasks_all: {tasks_all}") + print( + f"train_ratio: {train_ratio}, number of training points:" + f" {int(num_examples * train_ratio)}" + ) + print( + f"eval_ratio: {eval_ratio}, number of eval points: " + f"{int(num_examples * eval_ratio)}" + ) + print( + f"test_ratio: {test_ratio}, number of test points: " + f"{int(num_examples * test_ratio)}" + ) + print( + f"optimizer llm temperature: {optimizer_llm_temperature}, schedule:" + f" {optimizer_llm_temperature_schedule}" + ) + print( + f"generating {num_generated_instructions_in_each_step} instructions in" + f" each step, run for {num_search_steps} steps" + ) + print( + "discarding generated instructions with 
score less than:" + f" {old_instruction_score_threshold} (old_instruction_score_threshold)" + ) + print(f"num_score_buckets: {num_score_buckets}") + + if dataset_name == "mmlu": + is_multiple_choice = True + is_multiple_choice_eval = True + elif dataset_name in {"gsm8k"}: + is_multiple_choice = False + is_multiple_choice_eval = False + else: + assert dataset_name == "bbh" + is_multiple_choice = [] + is_multiple_choice_eval = [] + train_index_by_task_dict = dict() + eval_index_by_task_dict = dict() + start_index = 0 + for task_name in tasks_all: + single_task_list = eval_utils.load_bbh_task_data( + task_name, base_dir=root_data_folder_path + ) + end_index = start_index + len(single_task_list) + train_index_by_task_dict[task_name] = ( + train_index[(train_index >= start_index) & (train_index < end_index)] + # if " - start_index" is added here, then the dict would contain + # indices in the original task + ) + eval_index_by_task_dict[task_name] = ( + eval_index[(eval_index >= start_index) & (eval_index < end_index)] + # if " - start_index" is added here, then the dict would contain + # indices in the original task + ) + start_index = end_index + is_multiple_choice_single_task_train = [ + task_name in multiple_choice_tasks + ] * len(train_index_by_task_dict[task_name]) + is_multiple_choice_single_task_eval = [ + task_name in multiple_choice_tasks + ] * len(eval_index_by_task_dict[task_name]) + is_multiple_choice += is_multiple_choice_single_task_train + is_multiple_choice_eval += is_multiple_choice_single_task_eval + + prev_saved_instructions = set() + + # evaluate initial instructions + print("\n============== evaluating initial instructions ===============") + for instruction in initial_instructions: + print(f"""computing the score of "{instruction}" by prompting""") + + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=train_index, + batch_size=batch_size, + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=num_servers, + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=120, + sleep_time=60, + verbose=verbose, + ) + + detailed_results_df_by_instruction_dict[instruction] = detailed_results_df + scores = detailed_results_df["accuracy"] + average_score = np.average(scores) + print(f"instruction: {instruction}, score: {average_score}") + filename = eval_utils.instruction_to_filename(instruction) + file_path = os.path.join(result_by_instruction_folder, f"{filename}.csv") + detailed_results_df.to_csv(file_path, index=True, header=True) + print(f"""saving results of "{instruction}" to {file_path}""") + old_instructions_and_scores.append((instruction, average_score, -1)) + old_instructions_and_scores_raw.append((instruction, average_score, -1)) + instruction_score_dict[instruction] = average_score + + # increment the counter on wrong questions + wrong_question_indices_set = set( + list( + detailed_results_df.iloc[ + np.where(detailed_results_df.accuracy == 0.0)[0], : + ].index + ) + ) + for idx in wrong_question_indices_set: + wrong_questions_from_start_counter[idx] += 1 + + # evolution + for i_step in range(num_search_steps): + print(f"\n================== Step 
{i_step} =====================") + if not i_step % 10: + print(f"old_instructions_and_scores: {old_instructions_and_scores}") + + if optimizer_llm_temperature_schedule == "linear_increase": + optimizer_llm_temperature_curr = ( + optimizer_llm_temperature + + i_step + / num_search_steps + * (optimizer_llm_temperature_end - optimizer_llm_temperature) + ) + else: + optimizer_llm_temperature_curr = optimizer_llm_temperature + print( + f"current optimizer_llm_temperature: {optimizer_llm_temperature_curr}" + ) + + # generate new instructions + if few_shot_qa_pairs: + if few_shot_selection_criteria == "accumulative_most_frequent": + # select QA pairs that were done wrong the most number of times + most_frequent_wrong_question_indices = [ + k + for k, _ in sorted( + wrong_questions_from_start_counter.items(), key=lambda x: -x[1] + ) + ] + print( + "len(most_frequent_wrong_question_indices):" + f" {len(most_frequent_wrong_question_indices)}" + ) + if ( + len(most_frequent_wrong_question_indices) + <= num_few_shot_questions_for_instruction_refinement + ): + few_shot_index_list = most_frequent_wrong_question_indices.copy() + else: + np.random.seed(i_step) + few_shot_index_list = np.sort( + np.random.choice( + most_frequent_wrong_question_indices, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ) + + elif few_shot_selection_criteria == "current_most_frequent": + # show exemplars done wrong most often by currently shown instructions + old_instruction_score_threshold_single_step = ( + old_instruction_score_threshold if i_step > 0 else 0 + ) + _, old_instructions_and_scores_in_meta_prompt = ( + gen_ins_and_score_pairs_substr( + old_instructions_and_scores=old_instructions_and_scores, + old_instruction_score_threshold=old_instruction_score_threshold_single_step, + max_num_instructions=max_num_instructions, + return_str_only=False, + num_score_buckets=num_score_buckets, + ) + ) + wrong_questions_counter_single_step = collections.Counter() + for ins, _, _ in old_instructions_and_scores_in_meta_prompt: + filename = eval_utils.instruction_to_filename(ins) + file_path = os.path.join( + result_by_instruction_folder, f"{filename}.csv" + ) + single_ins_df = pd.read_csv(file_path, index_col=0, header=0) + wrong_question_indices_set_single_old_ins = set( + list( + single_ins_df.iloc[ + np.where(single_ins_df.accuracy == 0.0)[0], : + ].index + ) + ) + for idx in wrong_question_indices_set_single_old_ins: + wrong_questions_counter_single_step[idx] += 1 + most_occurred_wrong_questions = [ + k + for k, v in wrong_questions_counter_single_step.items() + if v == max(wrong_questions_counter_single_step.values()) + ] + if ( + len(most_occurred_wrong_questions) + < num_few_shot_questions_for_instruction_refinement + ): + # pylint: disable=cell-var-from-loop + idx_most_to_least = sorted( + wrong_questions_counter_single_step, + key=lambda x: -wrong_questions_counter_single_step[x], + ) + few_shot_index_list = idx_most_to_least[ + :num_few_shot_questions_for_instruction_refinement + ] + else: + few_shot_index_list = np.sort( + np.random.choice( + most_occurred_wrong_questions, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ) + elif few_shot_selection_criteria == "constant": + np.random.seed(0) + few_shot_index_list = np.sort( + np.random.choice( + train_index, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ) + else: + assert few_shot_selection_criteria == "random" + np.random.seed(i_step) + few_shot_index_list = np.sort( + np.random.choice( 
+ train_index, + num_few_shot_questions_for_instruction_refinement, + replace=False, + ) + ).tolist() + + few_shot_index_list_by_step_dict[i_step] = few_shot_index_list + + meta_prompt = gen_meta_prompt( + old_instructions_and_scores=old_instructions_and_scores, + instruction_pos=instruction_pos, + optimizer_llm_name=optimizer_llm_name, + old_instruction_score_threshold=old_instruction_score_threshold, + max_num_instructions=max_num_instructions, + meta_prompt_type=meta_prompt_type, + few_shot_qa_pairs=few_shot_qa_pairs, + include_qa=include_qa, + data=raw_data, + few_shot_index_list=few_shot_index_list, + instructions_before_exemplars=meta_prompt_instructions_before_exemplars, + num_score_buckets=num_score_buckets, + dataset_name=dataset_name, + task_name=task_name, + ) + + else: # no few-shot exemplars in meta-prompt + few_shot_index_list = [] + meta_prompt = gen_meta_prompt( + old_instructions_and_scores=old_instructions_and_scores, + instruction_pos=instruction_pos, + optimizer_llm_name=optimizer_llm_name, + old_instruction_score_threshold=old_instruction_score_threshold, + max_num_instructions=max_num_instructions, + meta_prompt_type=meta_prompt_type, + few_shot_qa_pairs=False, + include_qa=include_qa, + instructions_before_exemplars=meta_prompt_instructions_before_exemplars, + num_score_buckets=num_score_buckets, + dataset_name=dataset_name, + task_name=task_name, + ) + print(f"\nmeta_prompt: \n\n{meta_prompt}\n") + meta_prompts.append((meta_prompt, i_step)) + remaining_num_instructions_to_generate = ( + num_generated_instructions_in_each_step + ) + generated_instructions_raw = [] + while remaining_num_instructions_to_generate > 0: + optimizer_llm_input_text = meta_prompt + # generate instructions + print(f"current temperature: {optimizer_llm_temperature_curr}") + raw_outputs = call_optimizer_server_func( + optimizer_llm_input_text, + temperature=optimizer_llm_temperature_curr, + ) + + # Extract the generated instructions from the optimizer LLM output. Only + # keep some samples if the desired number of remaining instructions + # is smaller than the total number of decodes in this step. 
+      if meta_prompt_type == "both_instructions_and_exemplars":
+        raw_outputs = raw_outputs[:remaining_num_instructions_to_generate]
+        if optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+          if instruction_pos == "A_begin":
+            start_string = "<Start>"
+            end_string = "</Start>"
+          else:
+            start_string = "<INS>"
+            end_string = "</INS>"
+          for raw_output in raw_outputs:
+            if start_string not in raw_output:
+              start_index = 0
+            else:
+              start_index = raw_output.index(start_string) + len(start_string)
+            if end_string not in raw_output:
+              end_index = len(raw_output)
+            else:
+              end_index = raw_output.index(end_string)
+            new_inst = raw_output[start_index:end_index].strip()
+            generated_instructions_raw.append(new_inst)
+        else:
+          assert optimizer_llm_name.lower() == "text-bison"
+          generated_instructions_raw += [
+              extract_string_in_square_brackets(string)
+              for string in raw_outputs
+          ]
+
+        remaining_num_instructions_to_generate -= optimizer_llm_dict[
+            "batch_size"
+        ]
+      else:
+        assert meta_prompt_type == "instructions_only"
+        max_num_instructions_to_keep_in_each_output = 1
+        for string in raw_outputs:
+          generated_instructions_raw += parse_tag_content(string)[
+              :max_num_instructions_to_keep_in_each_output
+          ]
+        remaining_num_instructions_to_generate -= (
+            optimizer_llm_dict["batch_size"]
+            * max_num_instructions_to_keep_in_each_output
+        )
+
+    generated_instructions_raw = list(
+        map(eval_utils.polish_sentence, generated_instructions_raw)
+    )
+    print(f"\ninitially generated instructions: {generated_instructions_raw}\n")
+
+    # do not evaluate old instructions again
+    generated_instructions = []  # the new instructions generated in this step
+    for ins in generated_instructions_raw:
+      ins_md5_hashstring = eval_utils.instruction_to_filename(
+          ins, md5_hashing=True
+      )
+      if ins_md5_hashstring not in old_instruction_md5_hashstrings_set:
+        generated_instructions.append(ins)
+        old_instruction_md5_hashstrings_set.add(ins_md5_hashstring)
+      else:
+        print(f"already evaluated '{ins}' previously")
+    generated_instructions = list(set(generated_instructions))
+
+    to_evaluate_instructions = []
+    for instruction in generated_instructions:
+      if len(instruction) > 500:
+        print(f"Step {i_step}, instruction: {instruction}, too long, skipped")
+        continue
+      if dataset_name == "gsm8k" and any(
+          char.isdigit() for char in instruction
+      ):
+        print(
+            f"Step {i_step}, instruction: {instruction}, contains numbers,"
+            " skipped"
+        )
+        continue
+      if "INS" in instruction:
+        print(
+            f"Step {i_step}, instruction: {instruction}, contains 'INS',"
+            " skipped"
+        )
+        continue
+      to_evaluate_instructions.append(instruction)
+    print(f"\nto-evaluate generated instructions: {to_evaluate_instructions}\n")
+
+    # evaluate new instructions on the few-shot exemplars in meta-prompt
+    if few_shot_qa_pairs and evaluate_generated_ins_on_few_shot:
+      print("evaluating GENERATED instructions on few-shot exemplars")
+      single_step_eval_on_few_shot = dict()
+      for instruction in to_evaluate_instructions:
+        if instruction not in prev_saved_instructions:
+          print(
+              f"evaluating Step {i_step}, instruction: {instruction} on"
+              " few-shot exemplars"
+          )
+          detailed_results_df = eval_utils.evaluate_single_instruction(
+              data=raw_data,
+              instruction=instruction,
+              eval_index_all=few_shot_index_list,
+              batch_size=batch_size,
+              call_server_func=call_scorer_server_func,
+              dataset_name=dataset_name,
+              num_servers=num_servers,
+              extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again,
+              include_qa=include_qa,
+              evaluate_in_parallel=evaluate_in_parallel,
+
instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + single_step_eval_on_few_shot[instruction] = detailed_results_df + + print( + f"Step {i_step}, single_step_eval_on_few_shot:" + f" {single_step_eval_on_few_shot}\n" + ) + generated_ins_on_few_shot_results_dict[i_step] = ( + single_step_eval_on_few_shot + ) + + # evaluate OLD instructions on the few-shot exemplars in meta-prompt + if few_shot_qa_pairs and evaluate_old_ins_on_few_shot: + print("evaluating OLD instructions on few-shot exemplars") + single_step_eval_on_few_shot = dict() + for instruction, _, _ in old_instructions_and_scores: + print( + f"evaluating Step {i_step}, instruction: {instruction} on few-shot" + " exemplars" + ) + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=few_shot_index_list, + batch_size=scorer_llm_dict["batch_size"], + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=scorer_llm_dict["num_servers"], + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + single_step_eval_on_few_shot[instruction] = detailed_results_df + + print( + f"Step {i_step}, single_step_eval_on_few_shot:" + f" {single_step_eval_on_few_shot}\n" + ) + old_ins_on_few_shot_results_dict[i_step] = single_step_eval_on_few_shot + + # evaluate newly generated instructions on the training set + for instruction in to_evaluate_instructions: + if instruction not in prev_saved_instructions: + print(f"""computing the score of "{instruction}" by prompting""") + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=train_index, + batch_size=batch_size, + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=num_servers, + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + prev_saved_instructions.add(instruction) + else: + # do not re-evaluate instructions that had been evaluated previously + detailed_results_df = pd.read_csv( + os.path.join(result_by_instruction_folder, f"{instruction}.csv"), + index_col=0, + header=0, + ) + print(f"""reading previously saved "{instruction}" information""") + + scores = detailed_results_df["accuracy"] + average_score = np.average(scores) + print( + f"Step {i_step}, instruction: {instruction}, score: {average_score}" + ) + + # increment the counter on wrong questions + wrong_question_indices_set = set( + list( + detailed_results_df[detailed_results_df["accuracy"] == 0.0].index + ) + ) + for idx in wrong_question_indices_set: + wrong_questions_from_start_counter[idx] += 1 + 
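+      # (Note) wrong_questions_from_start_counter tracks, across all steps,
+      # how often each training question has been answered incorrectly; the
+      # "accumulative_most_frequent" few-shot selection criterion draws its
+      # exemplars from this counter.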
+ filename = eval_utils.instruction_to_filename(instruction) + file_path = os.path.join( + result_by_instruction_folder, f"""{filename}.csv""" + ) + detailed_results_df.to_csv(file_path, index=True, header=True) + print(f"saving results to {file_path}") + + detailed_results_df_by_instruction_dict[instruction] = detailed_results_df + old_instructions_and_scores.append((instruction, average_score, i_step)) + instruction_score_dict[instruction] = average_score + + # record all generated instructions + for instruction in generated_instructions_raw: + if instruction in instruction_score_dict: + average_score = instruction_score_dict[instruction] + else: + average_score = np.nan + old_instructions_and_scores_raw.append( + (instruction, average_score, i_step) + ) + + # =============================== eval ==================================== + # every eval_interval steps, evaluate the instructions that were generated + # in the current step and were not skipped + if not i_step % eval_interval: + for instruction in generated_instructions_raw: + # if the instruction wasn't skipped in any step + if instruction in instruction_score_dict: + if instruction not in instruction_eval_score_dict: + detailed_results_df = eval_utils.evaluate_single_instruction( + data=raw_data, + instruction=instruction, + eval_index_all=eval_index, + batch_size=batch_size, + call_server_func=call_scorer_server_func, + dataset_name=dataset_name, + num_servers=num_servers, + extract_final_answer_by_prompting_again=extract_final_answer_by_prompting_again, + include_qa=include_qa, + evaluate_in_parallel=evaluate_in_parallel, + instruction_pos=instruction_pos, + is_multiple_choice=is_multiple_choice_eval, + prediction_treat_as_number=prediction_treat_as_number, + prediction_treat_as_bool=prediction_treat_as_bool, + prediction_num_decimals=0, + max_retry=5, + sleep_time=180, + verbose=verbose, + ) + eval_score = np.average(detailed_results_df["accuracy"]) + eval_detailed_results_df_dict[instruction] = detailed_results_df + instruction_eval_score_dict[instruction] = eval_score + else: + eval_score = instruction_eval_score_dict[instruction] + print( + f"EVAL: \nStep {i_step}, instruction: {instruction}, eval score:" + f" {eval_score:.2f}" + ) + eval_results.append((i_step, instruction, eval_score)) + + # ===================== save up-to-date results =========================== + results_dict = dict() + results_dict["meta_prompts"] = meta_prompts + results_dict["old_instructions_and_scores"] = list( + old_instructions_and_scores + ) + results_dict["old_instructions_and_scores_raw"] = list( + old_instructions_and_scores_raw + ) + results_dict["generated_ins_on_few_shot_results_dict"] = ( + generated_ins_on_few_shot_results_dict + ) + results_dict["old_ins_on_few_shot_results_dict"] = ( + old_ins_on_few_shot_results_dict + ) + results_dict["few_shot_index_list_by_step_dict"] = ( + few_shot_index_list_by_step_dict + ) + results_dict["eval_results"] = eval_results + results_dict["eval_detailed_results_df_dict"] = ( + eval_detailed_results_df_dict + ) + with open(os.path.join(save_folder, "results_dict.pkl"), "wb") as fp: + pickle.dump(results_dict, fp) + print(f"\nsaved all results to\n{save_folder}") diff --git a/optimization/optimize_instructions.py b/optimization/optimize_instructions.py new file mode 100644 index 0000000..e619ffb --- /dev/null +++ b/optimization/optimize_instructions.py @@ -0,0 +1,736 @@ + +import datetime +import functools +import os +import sys + +OPRO_ROOT_PATH = os.path.dirname( + 
os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai +from opro import prompt_utils +from opro.optimization import opt_utils +import pandas as pd +FLAGS = flags.FLAGS +ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data") + +flags.DEFINE_string("local_model_path", "", "Path to local vLLM model.") + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_SCORER = flags.DEFINE_string( + "scorer", "text-bison", "The name of the scorer LLM." +) + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + +_DATASET = flags.DEFINE_string( + "dataset", "gsm8k", "The name of dataset to search for instructions on." +) + +_TASK = flags.DEFINE_string( + "task", + "train", + "The name of task within the above dataset to search for instructions on.", +) + +_INSTRUCTION_POS = flags.DEFINE_string( + "instruction_pos", + "A_begin", + "The position of the instruction to search for.", +) + +_META_PROMPT_TYPE = flags.DEFINE_string( + "meta_prompt_type", + "both_instructions_and_exemplars", + "The type of meta-prompt: whether to have both previous instructions and" + " dataset exemplars (often for fine-tuned optimizers), or to have only" + " previous instructions (often for pre-trained optimizers).", +) + + +def main(_): + local_model_path = FLAGS.local_model_path + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + scorer_llm_name = _SCORER.value + optimizer_llm_name = _OPTIMIZER.value + dataset_name = _DATASET.value.lower() + task_name = _TASK.value + meta_prompt_type = _META_PROMPT_TYPE.value + + assert dataset_name in { + "mmlu", + "bbh", + "gsm8k", + }, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k." 
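+  # (Illustrative) a local-only run of this script might look like:
+  #   python optimize_instructions.py --scorer=local --optimizer=local \
+  #     --dataset=gsm8k --task=train --instruction_pos=Q_begin \
+  #     --local_model_path=/path/to/model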
+ if dataset_name == "mmlu": + assert task_name in { + "STEM", + "humanities", + "social sciences", + "other (business, health, misc.)", + } # for now only support searching on one MMLU category + elif dataset_name == "bbh": + assert task_name in { + "boolean_expressions", + "causal_judgement", + "date_understanding", + "disambiguation_qa", + "dyck_languages", + "formal_fallacies", + "geometric_shapes", + "hyperbaton", + "logical_deduction_five_objects", + "logical_deduction_seven_objects", + "logical_deduction_three_objects", + "movie_recommendation", + "multistep_arithmetic_two", + "navigate", + "object_counting", + "penguins_in_a_table", + "reasoning_about_colored_objects", + "ruin_names", + "salient_translation_error_detection", + "snarks", + "sports_understanding", + "temporal_sequences", + "tracking_shuffled_objects_five_objects", + "tracking_shuffled_objects_seven_objects", + "tracking_shuffled_objects_three_objects", + "web_of_lies", + "word_sorting", + } + else: + assert dataset_name == "gsm8k" + assert task_name in {"train", "test"} + + assert scorer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + "local", + } + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + "local", + } + assert meta_prompt_type in { + "both_instructions_and_exemplars", + "instructions_only", + } + + instruction_pos = _INSTRUCTION_POS.value + assert instruction_pos in { + "before_Q", + "Q_begin", + "Q_end", + "A_begin", + }, ( + "The instruction position should be either before the question, or at the" + " beginning of the question, at the end of the question, or at the" + " beginning of the answer." + ) + print( + f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:" + f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}" + ) + + + if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + elif scorer_llm_name == "text-bison": + assert palm_api_key, "A PaLM API key is needed when prompting the text-bison model." + palm.configure(api_key=palm_api_key) + elif scorer_llm_name == "local": + # 本地模型,无需 API key + pass + else: + raise ValueError(f"Unknown scorer model: {scorer_llm_name}") + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + elif optimizer_llm_name == "text-bison": + assert palm_api_key, "A PaLM API key is needed when prompting the text-bison model." 
+    palm.configure(api_key=palm_api_key)
+  elif optimizer_llm_name == "local":
+    # Local model: no API key required.
+    pass
+  else:
+    raise ValueError(f"Unknown optimizer model: {optimizer_llm_name}")
+
+  if dataset_name == "mmlu":
+    root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data")
+  elif dataset_name == "bbh":
+    root_data_folder_path = os.path.join(
+        ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/"
+    )
+  else:
+    assert dataset_name == "gsm8k"
+    root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data")
+
+  # =================== create the result directory ==========================
+  datetime_str = (
+      str(datetime.datetime.now().replace(microsecond=0))
+      .replace(" ", "-")
+      .replace(":", "-")
+  )
+
+  save_folder = os.path.join(
+      OPRO_ROOT_PATH,
+      "outputs",
+      "optimization-results",
+      f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/",
+  )
+
+  result_by_instruction_folder = os.path.join(
+      save_folder, "result_by_instruction"
+  )
+  print(f"Results will be saved to: {os.path.abspath(result_by_instruction_folder)}")
+  os.makedirs(result_by_instruction_folder, exist_ok=True)
+  print(f"result directory:\n{save_folder}")
+
+  # ====================== scorer model configs ==============================
+  if scorer_llm_name == "text-bison":
+    # when prompting text-bison with Cloud API
+    scorer_finetuned_palm_temperature = 0.0
+    scorer_finetuned_palm_max_decode_steps = 1024
+    scorer_finetuned_palm_batch_size = 1
+    scorer_finetuned_palm_num_servers = 1
+    scorer_finetuned_palm_dict = dict()
+    scorer_finetuned_palm_dict["temperature"] = (
+        scorer_finetuned_palm_temperature
+    )
+    scorer_finetuned_palm_dict["num_servers"] = (
+        scorer_finetuned_palm_num_servers
+    )
+    scorer_finetuned_palm_dict["batch_size"] = scorer_finetuned_palm_batch_size
+    scorer_finetuned_palm_dict["max_decode_steps"] = (
+        scorer_finetuned_palm_max_decode_steps
+    )
+
+    call_scorer_finetuned_palm_server_func = functools.partial(
+        prompt_utils.call_palm_server_from_cloud,
+        model="text-bison-001",
+        temperature=scorer_finetuned_palm_dict["temperature"],
+        max_decode_steps=scorer_finetuned_palm_dict["max_decode_steps"],
+    )
+
+    scorer_llm_dict = {
+        "model_type": scorer_llm_name.lower(),
+    }
+    scorer_llm_dict.update(scorer_finetuned_palm_dict)
+    call_scorer_server_func = call_scorer_finetuned_palm_server_func
+
+  elif scorer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+    # Reworked to call the locally served vLLM model for the "local" case.
+    scorer_gpt_max_decode_steps = 1024
+    # scorer_gpt_max_decode_steps = 512
+    scorer_gpt_temperature = 0.0
+
+    scorer_llm_dict = {
+        "model_type": scorer_llm_name.lower(),
+        "max_decode_steps": scorer_gpt_max_decode_steps,
+        "temperature": scorer_gpt_temperature,
+        "num_decodes": 1,
+        "batch_size": 1,
+        "num_servers": 1,
+    }
+
+    call_scorer_server_func = functools.partial(
+        prompt_utils.call_openai_server_func,  # locally implemented vLLM-backed caller
+        max_decode_steps=scorer_gpt_max_decode_steps,
+        temperature=scorer_gpt_temperature,
+        local_model_path=FLAGS.local_model_path,  # path to your local model
+    )
+
+  else:
+    raise ValueError(f"Unsupported scorer_llm_name: {scorer_llm_name}")
+
+  # ====================== optimizer model configs ============================
+  if optimizer_llm_name.lower() == "text-bison":
+    # when prompting text-bison with Cloud API
+    optimizer_finetuned_palm_temperature = 1.0
+    optimizer_finetuned_palm_num_decodes = 8
+    optimizer_finetuned_palm_max_decode_steps = 1024
+    optimizer_finetuned_palm_batch_size = 1
+    optimizer_finetuned_palm_num_servers = 1
+    optimizer_finetuned_palm_dict = dict()
+    optimizer_finetuned_palm_dict["temperature"] = (
+        optimizer_finetuned_palm_temperature
+    )
+    optimizer_finetuned_palm_dict["num_decodes"] = (
+        optimizer_finetuned_palm_num_decodes
+    )
+    optimizer_finetuned_palm_dict["batch_size"] = (
+        optimizer_finetuned_palm_batch_size
+    )
+    optimizer_finetuned_palm_dict["num_servers"] = (
+        optimizer_finetuned_palm_num_servers
+    )
+    optimizer_finetuned_palm_dict["max_decode_steps"] = (
+        optimizer_finetuned_palm_max_decode_steps
+    )
+
+    call_optimizer_finetuned_palm_server_func = functools.partial(
+        prompt_utils.call_palm_server_from_cloud,
+        model="text-bison-001",
+        temperature=optimizer_finetuned_palm_dict["temperature"],
+        max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"],
+    )
+
+    optimizer_llm_dict = {
+        "model_type": optimizer_llm_name.lower(),
+    }
+    optimizer_llm_dict.update(optimizer_finetuned_palm_dict)
+    call_optimizer_server_func = call_optimizer_finetuned_palm_server_func
+
+  elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4", "local"}:
+    # Use the local vLLM-backed implementation instead of the remote API.
+    optimizer_gpt_max_decode_steps = 512
+    optimizer_gpt_temperature = 1.0
+
+    optimizer_llm_dict = {
+        "max_decode_steps": optimizer_gpt_max_decode_steps,
+        "temperature": optimizer_gpt_temperature,
+        "batch_size": 1,
+        "num_decodes": 1,
+    }
+
+    call_optimizer_server_func = functools.partial(
+        prompt_utils.call_openai_server_func,  # the locally implemented vLLM call interface
+        max_decode_steps=optimizer_gpt_max_decode_steps,
+        temperature=optimizer_gpt_temperature,
+        local_model_path=FLAGS.local_model_path,
+    )
+
+  else:
+    raise ValueError(f"Unsupported optimizer_llm_name: {optimizer_llm_name}")
+
+  # ====================== try calling the servers ============================
+  print("\n======== testing the scorer and optimizer servers ===========")
+  scorer_test_output = call_scorer_server_func(
+      "Does the sun rise from the north? Just answer yes or no."
+  )
+  print(f"number of scorer output decodes: {len(scorer_test_output)}")
+  print(f"scorer test output: {scorer_test_output}")
+  optimizer_test_output = call_optimizer_server_func(
+      "Does the sun rise from the north? 
Just answer yes or no.", + temperature=1.0, + ) + print(f"number of optimizer output decodes: {len(optimizer_test_output)}") + print(f"optimizer test output: {optimizer_test_output}") + print("Finished testing the servers.") + + # ====================== read data ============================ + print("\n================ prompt optimization settings ==============") + # from https://github.com/hendrycks/test/blob/master/categories.py + subcategories = { + "abstract_algebra": ["math"], + "anatomy": ["health"], + "astronomy": ["physics"], + "business_ethics": ["business"], + "clinical_knowledge": ["health"], + "college_biology": ["biology"], + "college_chemistry": ["chemistry"], + "college_computer_science": ["computer science"], + "college_mathematics": ["math"], + "college_medicine": ["health"], + "college_physics": ["physics"], + "computer_security": ["computer science"], + "conceptual_physics": ["physics"], + "econometrics": ["economics"], + "electrical_engineering": ["engineering"], + "elementary_mathematics": ["math"], + "formal_logic": ["philosophy"], + "global_facts": ["other"], + "high_school_biology": ["biology"], + "high_school_chemistry": ["chemistry"], + "high_school_computer_science": ["computer science"], + "high_school_european_history": ["history"], + "high_school_geography": ["geography"], + "high_school_government_and_politics": ["politics"], + "high_school_macroeconomics": ["economics"], + "high_school_mathematics": ["math"], + "high_school_microeconomics": ["economics"], + "high_school_physics": ["physics"], + "high_school_psychology": ["psychology"], + "high_school_statistics": ["math"], + "high_school_us_history": ["history"], + "high_school_world_history": ["history"], + "human_aging": ["health"], + "human_sexuality": ["culture"], + "international_law": ["law"], + "jurisprudence": ["law"], + "logical_fallacies": ["philosophy"], + "machine_learning": ["computer science"], + "management": ["business"], + "marketing": ["business"], + "medical_genetics": ["health"], + "miscellaneous": ["other"], + "moral_disputes": ["philosophy"], + "moral_scenarios": ["philosophy"], + "nutrition": ["health"], + "philosophy": ["philosophy"], + "prehistory": ["history"], + "professional_accounting": ["other"], + "professional_law": ["law"], + "professional_medicine": ["health"], + "professional_psychology": ["psychology"], + "public_relations": ["politics"], + "security_studies": ["politics"], + "sociology": ["culture"], + "us_foreign_policy": ["politics"], + "virology": ["health"], + "world_religions": ["philosophy"], + } + + categories = { + "STEM": [ + "physics", + "chemistry", + "biology", + "computer science", + "math", + "engineering", + ], + "humanities": ["history", "philosophy", "law"], + "social sciences": [ + "politics", + "culture", + "economics", + "geography", + "psychology", + ], + "other (business, health, misc.)": ["other", "business", "health"], + } + + if dataset_name == "mmlu": + + category_names = [task_name] + folder_name = "test" # one of {'auxiliary_train', 'dev', 'val', 'test'} + task_names = [] + for task_csv_name in os.listdir( + os.path.join(root_data_folder_path, folder_name) + ): + task_names.append(task_csv_name.split(".")[0]) + + tasks_in_category = [] + for category_name in category_names: + for task_name in task_names: + for subname in subcategories: + if subname in task_name: + if subcategories[subname][0] in categories[category_name]: + tasks_in_category.append(task_name) + break + + tasks_all = [(folder_name, task_name) for task_name in tasks_in_category] 
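+    # For example, with task_name == "STEM", a file like
+    # "college_physics_test.csv" yields the task "college_physics_test", which
+    # matches subcategory "college_physics" -> ["physics"]; since "physics" is
+    # listed under categories["STEM"], the task is kept.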
+    multiple_choice_tasks = set([item[1] for item in tasks_all])
+    boolean_tasks = set()
+    numerical_output_tasks = set()
+
+  elif dataset_name == "bbh":
+    tasks_all = [task_name]
+    assert (
+        len(tasks_all) == 1
+    ), "for now only support prompt optimization on one BBH task"
+
+    numerical_output_tasks = {
+        "object_counting",
+        "multistep_arithmetic_two",
+    }
+
+    multiple_choice_tasks = {
+        "date_understanding",
+        "disambiguation_qa",
+        "geometric_shapes",
+        "hyperbaton",
+        "logical_deduction_five_objects",
+        "logical_deduction_seven_objects",
+        "logical_deduction_three_objects",
+        "movie_recommendation",
+        "penguins_in_a_table",
+        "reasoning_about_colored_objects",
+        "ruin_names",
+        "salient_translation_error_detection",
+        "snarks",
+        "temporal_sequences",
+        "tracking_shuffled_objects_five_objects",
+        "tracking_shuffled_objects_seven_objects",
+        "tracking_shuffled_objects_three_objects",
+    }
+
+    boolean_tasks = {
+        "boolean_expressions",  # True or False
+        "causal_judgement",  # yes or no
+        "formal_fallacies",  # valid or invalid
+        "navigate",  # yes or no
+        "sports_understanding",  # yes or no
+        "web_of_lies",  # yes or no
+    }
+
+  else:
+    assert dataset_name in {"gsm8k"}
+    tasks_all = [task_name]
+    multiple_choice_tasks = set()
+    boolean_tasks = set()
+    numerical_output_tasks = set(tasks_all)
+
+  if dataset_name == "mmlu":
+    raw_data = pd.DataFrame()
+    prediction_treat_as_number = False
+    prediction_treat_as_bool = False
+  elif dataset_name == "bbh":
+    raw_data = []
+    prediction_treat_as_number = bool(
+        tasks_all[0] in numerical_output_tasks
+    )  # for now only check the first task
+    prediction_treat_as_bool = bool(
+        tasks_all[0] in boolean_tasks
+    )  # for now only check the first task
+    print(
+        f"prediction_treat_as_number: {prediction_treat_as_number},"
+        f" prediction_treat_as_bool: {prediction_treat_as_bool}"
+    )
+  else:
+    assert dataset_name == "gsm8k"
+    raw_data = pd.DataFrame()
+    prediction_treat_as_number = True
+    prediction_treat_as_bool = False
+
+  for t in tasks_all:
+    if dataset_name == "mmlu":
+      folder_name = t[0]
+      task_name = t[1]
+      single_task_df = pd.read_csv(
+          os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"),
+          index_col=None,
+          header=None,
+      )
+      raw_data = pd.concat([raw_data, single_task_df])
+    elif dataset_name == "bbh":
+      task_name = t
+      single_task_list = opt_utils.load_bbh_task_data(
+          task_name, base_dir=root_data_folder_path
+      )
+      raw_data += single_task_list
+    else:
+      assert dataset_name == "gsm8k"
+      task_name = t
+      f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv")
+      single_task_df = pd.read_csv(f_gsm, sep="\t", header=None)
+      raw_data = pd.concat([raw_data, single_task_df])
+
+  if dataset_name == "mmlu":
+    num_examples = raw_data.shape[0]
+  elif dataset_name == "bbh":
+    num_examples = len(raw_data)
+  else:
+    assert dataset_name in {"gsm8k"}
+    num_examples = raw_data.shape[0]
+  print(f"number of examples in the current task: {num_examples}")
+
+  # ================ split data into train/val/test ==========================
+  if dataset_name == "mmlu":
+    train_ratio = 0.8
+    eval_ratio = 0.2
+  elif dataset_name == "gsm8k":
+    # train_ratio = 0.035
+    train_ratio = 0.01  # originally 0.035; lowered to 0.01, roughly 74 examples
+    eval_ratio = 0
+  else:
+    assert dataset_name == "bbh"
+    train_ratio = 0.2
+    eval_ratio = 0
+
+  assert train_ratio + eval_ratio <= 1
+  test_ratio = 1 - train_ratio - eval_ratio
+  print(
+      f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, "
+      f"test_ratio: {test_ratio}"
+  )
+  np.random.seed(0)
+  train_index = np.sort(
+      np.array(
+          np.random.choice(
+              num_examples, size=int(train_ratio * num_examples), replace=False
+          )
+      )
+  )
+  eval_and_test_index = np.sort(
+      np.array(list(set(np.arange(num_examples)) - set(train_index)))
+  )
+  eval_index = np.sort(
+      np.array(
+          np.random.choice(
+              eval_and_test_index,
+              size=int(eval_ratio * num_examples),
+              replace=False,
+          )
+      )
+  )
+
+  # ========== set other optimization experiment hyperparameters ==============
+  if scorer_llm_name == "text-bison":
+    old_instruction_score_threshold = 0.0
+    # old_instruction_score_threshold = 0.15  # for GSM8K
+  else:
+    assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"}
+    old_instruction_score_threshold = 0.3
+
+  if scorer_llm_name == "text-bison":
+    extract_final_answer_by_prompting_again = False
+    include_qa = False
+    evaluate_in_parallel = False
+  else:
+    assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4", "local"}
+    extract_final_answer_by_prompting_again = False
+    include_qa = False
+    evaluate_in_parallel = False
+
+  optimizer_llm_temperature = optimizer_llm_dict["temperature"]
+
+  # num_few_shot_questions_for_instruction_refinement = 3
+  num_few_shot_questions_for_instruction_refinement = 1  # fewer few-shot examples
+
+  # num_generated_instructions_in_each_step = 8
+  num_generated_instructions_in_each_step = 2  # generate only 2 instructions per step
+
+  # num_search_steps = 200
+  num_search_steps = 3  # originally 200; 3 steps are enough for a quick run
+
+  initial_instructions = [
+      "Let's solve the problem.",
+      # "",
+      # "The answer is",
+  ]
+  few_shot_qa_pairs = True
+  # one of {'accumulative_most_frequent', 'current_most_frequent', 'random',
+  # 'constant'}
+  few_shot_selection_criteria = "random"
+  # whether to evaluate generated instructions on the exemplars in meta-prompt
+  evaluate_generated_ins_on_few_shot = False
+  # whether to evaluate old instructions on the exemplars in the meta-prompt
+  evaluate_old_ins_on_few_shot = False
+  # every this number of steps, compute the accuracies of current-step
+  # instructions on the validation set
+  # eval_interval = 3
+  eval_interval = 1  # evaluate every step to see results promptly
+  # eval_interval = 10
+  max_num_instructions = (
+      20  # the maximum number of instructions and scores in the meta-prompt
+  )
+  # The number of buckets when converting scores to integers in the meta-prompt.
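+  # e.g. with 100 buckets, a raw accuracy of 0.73 is presented to the
+  # optimizer as an integer score of about 73.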
+  num_score_buckets = 100
+  # whether to put old instructions and scores to before exemplars in
+  # the meta-prompt
+  meta_prompt_instructions_before_exemplars = True
+
+  # ===================== run prompt optimization ======================
+  assert few_shot_selection_criteria in {
+      "accumulative_most_frequent",
+      "current_most_frequent",
+      "random",
+      "constant",
+  }
+  evolution_kwargs = {
+      "num_search_steps": num_search_steps,
+      "old_instruction_score_threshold": old_instruction_score_threshold,
+      "scorer_llm_dict": scorer_llm_dict,
+      "optimizer_llm_dict": optimizer_llm_dict,
+      "extract_final_answer_by_prompting_again": (
+          extract_final_answer_by_prompting_again
+      ),
+      "include_qa": include_qa,
+      "evaluate_in_parallel": evaluate_in_parallel,
+      "tasks_all": tasks_all,
+      "train_ratio": train_ratio,
+      "eval_ratio": eval_ratio,
+      "test_ratio": test_ratio,
+      "train_index": train_index,
+      "eval_index": eval_index,
+      "dataset_name": dataset_name,
+      "task_name": task_name,
+      "num_examples": num_examples,
+      "root_data_folder_path": root_data_folder_path,
+      "optimizer_llm_temperature": optimizer_llm_temperature,
+      # "optimizer_llm_temperature_schedule": (
+      #     optimizer_llm_temperature_schedule
+      # ),
+      # "optimizer_llm_temperature_end": optimizer_llm_temperature_end,
+      "initial_instructions": initial_instructions,
+      "multiple_choice_tasks": multiple_choice_tasks,
+      "raw_data": raw_data,
+      "call_scorer_server_func": call_scorer_server_func,
+      "call_optimizer_server_func": call_optimizer_server_func,
+      "instruction_pos": instruction_pos,
+      "prediction_treat_as_number": prediction_treat_as_number,
+      "prediction_treat_as_bool": prediction_treat_as_bool,
+      "result_by_instruction_folder": result_by_instruction_folder,
+      "few_shot_qa_pairs": few_shot_qa_pairs,
+      "num_score_buckets": num_score_buckets,
+      "max_num_instructions": max_num_instructions,
+      "meta_prompt_type": meta_prompt_type,
+      "meta_prompt_instructions_before_exemplars": (
+          meta_prompt_instructions_before_exemplars
+      ),
+      "few_shot_selection_criteria": few_shot_selection_criteria,
+      "optimizer_llm_name": optimizer_llm_name,
+      "num_generated_instructions_in_each_step": (
+          num_generated_instructions_in_each_step
+      ),
+      "evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot,
+      "num_few_shot_questions_for_instruction_refinement": (
+          num_few_shot_questions_for_instruction_refinement
+      ),
+      "evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot,
+      "eval_interval": eval_interval,
+      "save_folder": save_folder,
+  }
+  print("=== starting the optimization run ===")
+  try:
+    opt_utils.run_evolution(**evolution_kwargs)
+    print("=== optimization finished ===")
+  except Exception as e:
+    import traceback
+    print(f"!!! optimization failed: {e} !!!", file=sys.stderr)
+    traceback.print_exc()
+    sys.exit(1)
+
+
+if __name__ == "__main__":
+  app.run(main)
+
diff --git a/optimization/optimize_linear_regression.py b/optimization/optimize_linear_regression.py
new file mode 100644
index 0000000..f856434
--- /dev/null
+++ b/optimization/optimize_linear_regression.py
@@ -0,0 +1,424 @@
+# Copyright 2023 The OPRO Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +r"""Optimize over the objective function of a linear regression problem. + +Usage: + +``` +python optimize_linear_regression.py --optimizer="text-bison" +``` + +Note: +- When using a Google-Cloud-served model (like text-bison at +https://developers.generativeai.google/tutorials/text_quickstart), add +`--palm_api_key=""` +- When using an OpenAI model, add `--openai_api_key=""` +""" + +import datetime +import functools +import json +import os +import re +import sys + +OPRO_ROOT_PATH = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai + +from opro import prompt_utils + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + + +def main(_): + # ============== set optimization experiment configurations ================ + num_points = 50 # number of points in linear regression + w_true = 15 # the true w + b_true = 14 # the true b + max_num_steps = 500 # the number of optimization steps + num_reps = 5 # the number of repeated runs + max_num_pairs = 20 # the maximum number of input-output pairs in meta-prompt + num_input_decimals = 0 # num of decimals for input values in meta-prompt + num_output_decimals = 0 # num of decimals for output values in meta-prompt + num_generated_points_in_each_step = 8 + + # ================ load LLM settings =================== + optimizer_llm_name = _OPTIMIZER.value + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + } + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + else: + assert optimizer_llm_name == "text-bison" + assert ( + palm_api_key + ), "A PaLM API key is needed when prompting the text-bison model." 
+ palm.configure(api_key=palm_api_key) + + # =================== create the result directory ========================== + datetime_str = ( + str(datetime.datetime.now().replace(microsecond=0)) + .replace(" ", "-") + .replace(":", "-") + ) + + save_folder = os.path.join( + OPRO_ROOT_PATH, + "outputs", + "optimization-results", + f"linear_regression-o-{optimizer_llm_name}-{datetime_str}/", + ) + os.makedirs(save_folder) + print(f"result directory:\n{save_folder}") + + # ====================== optimizer model configs ============================ + if optimizer_llm_name.lower() == "text-bison": + # when prompting text-bison with Cloud API + optimizer_finetuned_palm_temperature = 1.0 + optimizer_finetuned_palm_max_decode_steps = 1024 + optimizer_finetuned_palm_batch_size = 1 + optimizer_finetuned_palm_num_servers = 1 + optimizer_finetuned_palm_dict = dict() + optimizer_finetuned_palm_dict["temperature"] = ( + optimizer_finetuned_palm_temperature + ) + optimizer_finetuned_palm_dict["batch_size"] = ( + optimizer_finetuned_palm_batch_size + ) + optimizer_finetuned_palm_dict["num_servers"] = ( + optimizer_finetuned_palm_num_servers + ) + optimizer_finetuned_palm_dict["max_decode_steps"] = ( + optimizer_finetuned_palm_max_decode_steps + ) + + call_optimizer_finetuned_palm_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + # prompt_utils.call_vllm, + model="text-bison-001", + temperature=optimizer_finetuned_palm_dict["temperature"], + max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], + ) + + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + } + optimizer_llm_dict.update(optimizer_finetuned_palm_dict) + call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + + else: + assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"} + optimizer_gpt_max_decode_steps = 1024 + optimizer_gpt_temperature = 1.0 + + optimizer_llm_dict = dict() + optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps + optimizer_llm_dict["temperature"] = optimizer_gpt_temperature + optimizer_llm_dict["batch_size"] = 1 + call_optimizer_server_func = functools.partial( + prompt_utils.call_openai_server_func, + model=optimizer_llm_name, + max_decode_steps=optimizer_gpt_max_decode_steps, + temperature=optimizer_gpt_temperature, + ) + + # ====================== try calling the servers ============================ + print("\n======== testing the optimizer server ===========") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? Just answer yes or no.", + temperature=1.0, + ) + print(f"optimizer test output: {optimizer_test_output}") + print("Finished testing the optimizer server.") + print("\n=================================================") + + # ====================== utility functions ============================ + def evaluate_loss(X, y, w, b): # pylint: disable=invalid-name + residual = y - (X * w + b) + return np.linalg.norm(residual) ** 2 + + def gen_meta_prompt( + old_value_pairs_set, + X, # pylint: disable=invalid-name, unused-argument + y, # pylint: disable=unused-argument + num_input_decimals=5, + num_output_decimals=5, + max_num_pairs=100, + ): + """Generate the meta-prompt for optimization. + + Args: + old_value_pairs_set (set): the set of old (w, b, z) pairs. + X (np.array): the 1D array of x values. + y (np.array): the 1D array of y values. + num_input_decimals (int): the number of decimals for (w, b) in the + meta-prompt. 
+ num_output_decimals (int): the number of decimals for z in the meta-prompt. + max_num_pairs (int): the maximum number of exemplars in the meta-prompt. + + Returns: + meta_prompt (str): the generated meta-prompt. + """ + old_value_pairs_set = set( + [ # pylint: disable=g-complex-comprehension + ( + np.round(w, num_input_decimals) + if num_input_decimals > 0 + else int(w), + np.round(b, num_input_decimals) + if num_input_decimals > 0 + else int(b), + np.round(z, num_output_decimals) + if num_output_decimals > 0 + else int(z), + ) + for w, b, z in old_value_pairs_set + ] + ) + old_value_pairs = list(old_value_pairs_set) + old_value_pairs = sorted(old_value_pairs, key=lambda x: -x[2])[ + -max_num_pairs: + ] + old_value_pairs_substr = "" + for w, b, z in old_value_pairs: + old_value_pairs_substr += f"\ninput:\nw={w}, b={b}\nvalue:\n{z}\n" + meta_prompt = """ + Now you will help me minimize a function with two input variables w, b. I have some (w, b) pairs and the function values at those points. The pairs are arranged in descending order based on their function values, where lower values are better. + """.strip() + meta_prompt += "\n\n" + meta_prompt += old_value_pairs_substr.strip() + meta_prompt += "\n\n" + # function_analytic_form = "" + # for xi, yi in zip(X, y): + # function_analytic_form += f"({yi:.4f} - ({xi:.4f} * w + b)) ** 2 + " + # function_analytic_form = function_analytic_form[:-3] + # meta_prompt += ( + # "The function has the analytic form f(w, b) =" + # f" {function_analytic_form}. When evaluating the value of a (w, b)" + # " pair, you should replace the w and b in the analytic form with your" + # " values and do the computation." + # ) + # meta_prompt += "\n\n" + meta_prompt += """Give me a new (w, b) pair that is different from all pairs above, and has a function value lower than any of the above. Do not write code. The output must end with a pair [w, b], where w and b are numerical values. + """.strip() + return meta_prompt + + def extract_string_in_square_brackets(input_string): + raw_result = re.findall(r"\[.*?\]", input_string) + if raw_result: + for pair in raw_result[::-1]: + if "=" not in pair and ("w" in pair or "b" in pair): + continue + return pair[1:-1] + return "" + else: + return "" + + def parse_output(extracted_output): + """Parse the extracted output 'w, b' string to np.array([w, b]). + + Args: + extracted_output (str): the extracted output string, like '1.5, 2.5'. + + Returns: + parsed_output (np.array): the parsed output in a numpy array, like [1.5, + 2.5]. 
+ """ + if not extracted_output: + return + extracted_values = [] + for item in extracted_output.split(","): + if "=" in item: + item = item[item.index("=") + 1 :] + extracted_values.append(item.strip()) + parsed_output = np.array(extracted_values).astype(float) + return parsed_output + + configs_dict = dict() + results_dict = dict() + num_convergence_steps = [] + for i_rep in range(num_reps): + found_optimal = False + print(f"\nRep {i_rep}:") + + # ================= generate the ground truth X, y ===================== + X = np.arange(num_points).astype(float) + 1 # pylint: disable=invalid-name + np.random.seed(i_rep + 1) + y = X * w_true + b_true + np.random.randn(num_points) + loss_at_true_values = evaluate_loss(X, y, w_true, b_true) + print(f"value at (w_true, b_true): {loss_at_true_values}") + + # ================= generate the starting points ===================== + num_starting_points = 5 # the number of initial points for optimization + np.random.seed((i_rep + 1) * 10) + init_w = np.random.uniform(low=10, high=20, size=num_starting_points) + np.random.seed((i_rep + 1) * 100) + init_b = np.random.uniform(low=10, high=20, size=num_starting_points) + + # ====================== run optimization ============================ + configs_dict_single_rep = { + "optimizer_llm_configs": optimizer_llm_dict, + "data": { + "num_points": num_points, + "w_true": w_true, + "b_true": b_true, + "loss_at_true_values": loss_at_true_values, + "X": list(X), + "y": list(y), + }, + "init_w": list(init_w), + "init_b": list(init_b), + "max_num_steps": max_num_steps, + "max_num_pairs": max_num_pairs, + "num_input_decimals": num_input_decimals, + "num_output_decimals": num_output_decimals, + "num_generated_points_in_each_step": num_generated_points_in_each_step, + } + configs_dict[i_rep] = configs_dict_single_rep + configs_json_path = os.path.join(save_folder, "configs.json") + print(f"saving configs to\n{configs_json_path}") + with open(configs_json_path, "w") as f: + json.dump(configs_dict, f, indent=4) + + old_value_pairs_set = set() + old_value_pairs_with_i_step = [] # format: [(w, b, z = f(w, b), i_step)] + meta_prompts_dict = dict() # format: {i_step: meta_prompt} + raw_outputs_dict = dict() # format: {i_step: raw_outputs} + + rounded_inits = [ + (np.round(w, num_input_decimals), np.round(b, num_input_decimals)) + for w, b in zip(init_w, init_b) + ] + rounded_inits = [ + tuple(item) for item in list(np.unique(rounded_inits, axis=0)) + ] + for w, b in rounded_inits: + z = evaluate_loss(X, y, w, b) + old_value_pairs_set.add((w, b, z)) + old_value_pairs_with_i_step.append((w, b, z, -1)) + + print("\n================ run optimization ==============") + print( + f"initial points: {[tuple(item[:2]) for item in old_value_pairs_set]}" + ) + print(f"initial values: {[item[-1] for item in old_value_pairs_set]}") + results_json_path = os.path.join(save_folder, "results.json") + print(f"saving results to\n{results_json_path}") + + for i_step in range(max_num_steps): + print(f"\nStep {i_step}:") + meta_prompt = gen_meta_prompt( + old_value_pairs_set, + X, + y, + num_input_decimals=num_input_decimals, + num_output_decimals=num_output_decimals, + max_num_pairs=max_num_pairs, + ) + if not i_step % 5: + print("\n=================================================") + print(f"meta_prompt:\n{meta_prompt}") + meta_prompts_dict[i_step] = meta_prompt + + # generate a maximum of the given number of points in each step + remaining_num_points_to_generate = num_generated_points_in_each_step + raw_outputs = [] + while 
remaining_num_points_to_generate > 0: + raw_outputs += call_optimizer_server_func(meta_prompt) + remaining_num_points_to_generate -= optimizer_llm_dict["batch_size"] + raw_outputs = raw_outputs[:num_generated_points_in_each_step] + + raw_outputs_dict[i_step] = raw_outputs + parsed_outputs = [] + for string in raw_outputs: + if not i_step % 5: + print("\n=================================================") + print("raw output:\n", string) + print("\n=================================================") + try: + parsed_output = parse_output( + extract_string_in_square_brackets(string) + ) + if parsed_output is not None and len(parsed_output) == 2: + parsed_outputs.append(parsed_output) + except ValueError: + pass + parsed_outputs = [tuple(item) for item in parsed_outputs] + print(f"proposed points before rounding: {parsed_outputs}") + + # round the proposed points to the number of decimals in meta-prompt + rounded_outputs = [ + (np.round(w, num_input_decimals), np.round(b, num_input_decimals)) + for w, b in parsed_outputs + ] + rounded_outputs = [ + tuple(item) for item in list(np.unique(rounded_outputs, axis=0)) + ] + print(f"proposed points after rounding: {rounded_outputs}") + + # evaluate the values of proposed and rounded outputs + single_step_values = [] + for w, b in rounded_outputs: + if w == w_true and b == b_true: + found_optimal = True + z = evaluate_loss(X, y, w, b) + single_step_values.append(z) + old_value_pairs_set.add((w, b, z)) + old_value_pairs_with_i_step.append((w, b, z, i_step)) + print(f"single_step_values: {single_step_values}") + + # ====================== save results ============================ + results_dict_single_rep = { + "meta_prompts": meta_prompts_dict, + "raw_outputs": raw_outputs_dict, + "old_value_pairs_with_i_step": old_value_pairs_with_i_step, + } + results_dict[i_rep] = results_dict_single_rep + with open(results_json_path, "w") as f: + json.dump(results_dict, f, indent=4) + if found_optimal: + print( + f"Repetition {i_rep+1}, optimal found at Step {i_step+1}, saving" + f" final results to\n{save_folder}" + ) + num_convergence_steps.append(i_step + 1) + break + print(f"num_convergence_steps: {num_convergence_steps}") + + +if __name__ == "__main__": + app.run(main) diff --git a/optimization/optimize_tsp.py b/optimization/optimize_tsp.py new file mode 100644 index 0000000..a8f99d9 --- /dev/null +++ b/optimization/optimize_tsp.py @@ -0,0 +1,430 @@ +# Copyright 2024 The OPRO Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r"""Optimize over the objective function of a traveling salesman problem. 
+ +Usage: + +``` +python optimize_tsp.py --optimizer="text-bison" +``` + +Note: +- When using a Google-Cloud-served model (like text-bison at +https://developers.generativeai.google/tutorials/text_quickstart), add +`--palm_api_key=""` +- When using an OpenAI model, add `--openai_api_key=""` +""" + +import datetime +import functools +import getpass +import json +import os +import re +import sys +import itertools + +OPRO_ROOT_PATH = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +) +sys.path.insert(0, OPRO_ROOT_PATH) + +from absl import app +from absl import flags +import google.generativeai as palm +import numpy as np +import openai + +from opro import prompt_utils + + +_OPENAI_API_KEY = flags.DEFINE_string( + "openai_api_key", "", "The OpenAI API key." +) + +_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.") + +_OPTIMIZER = flags.DEFINE_string( + "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM." +) + +_START_ALGORITHM = flags.DEFINE_string( + "starting_algorithm", "farthest_insertion", "The name of the starting algorithm. Select from [dp, nearest_neighbor, farthest_insertion]" +) + +def main(_): + # ============== set optimization experiment configurations ================ + num_points = 100 # number of points in TSP + num_steps = 500 # the number of optimization steps + max_num_pairs = 10 # the maximum number of input-output pairs in meta-prompt + num_decimals = 0 # num of decimals for distances in meta-prompt + num_starting_points = 5 # the number of initial points for optimization + num_decode_per_step = 8 # the number of decoded solutions per step + + # ================ load LLM settings =================== + optimizer_llm_name = _OPTIMIZER.value + assert optimizer_llm_name in { + "text-bison", + "gpt-3.5-turbo", + "gpt-4", + } + openai_api_key = _OPENAI_API_KEY.value + palm_api_key = _PALM_API_KEY.value + + if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}: + assert openai_api_key, "The OpenAI API key must be provided." + openai.api_key = openai_api_key + else: + assert optimizer_llm_name == "text-bison" + assert ( + palm_api_key + ), "A PaLM API key is needed when prompting the text-bison model." 
+ palm.configure(api_key=palm_api_key) + + # =================== create the result directory ========================== + datetime_str = ( + str(datetime.datetime.now().replace(microsecond=0)) + .replace(" ", "-") + .replace(":", "-") + ) + + save_folder = os.path.join( + OPRO_ROOT_PATH, + "outputs", + "optimization-results", + f"tsp-o-{optimizer_llm_name}-{datetime_str}/", + ) + os.makedirs(save_folder) + print(f"result directory:\n{save_folder}") + + # ====================== optimizer model configs ============================ + if optimizer_llm_name.lower() == "text-bison": + # when prompting text-bison with Cloud API + optimizer_finetuned_palm_temperature = 1.0 + optimizer_finetuned_palm_max_decode_steps = 1024 + optimizer_finetuned_palm_batch_size = 1 + optimizer_finetuned_palm_num_servers = 1 + optimizer_finetuned_palm_dict = dict() + optimizer_finetuned_palm_dict["temperature"] = ( + optimizer_finetuned_palm_temperature + ) + optimizer_finetuned_palm_dict["batch_size"] = ( + optimizer_finetuned_palm_batch_size + ) + optimizer_finetuned_palm_dict["num_servers"] = ( + optimizer_finetuned_palm_num_servers + ) + optimizer_finetuned_palm_dict["max_decode_steps"] = ( + optimizer_finetuned_palm_max_decode_steps + ) + + call_optimizer_finetuned_palm_server_func = functools.partial( + prompt_utils.call_palm_server_from_cloud, + # prompt_utils.call_vllm, + model="text-bison-001", + temperature=optimizer_finetuned_palm_dict["temperature"], + max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], + ) + + optimizer_llm_dict = { + "model_type": optimizer_llm_name.lower(), + } + optimizer_llm_dict.update(optimizer_finetuned_palm_dict) + call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + + else: + assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"} + optimizer_gpt_max_decode_steps = 1024 + optimizer_gpt_temperature = 1.0 + + optimizer_llm_dict = dict() + optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps + optimizer_llm_dict["temperature"] = optimizer_gpt_temperature + optimizer_llm_dict["batch_size"] = 1 + call_optimizer_server_func = functools.partial( + prompt_utils.call_openai_server_func, + model=optimizer_llm_name, + max_decode_steps=optimizer_gpt_max_decode_steps, + temperature=optimizer_gpt_temperature, + ) + + # ====================== try calling the servers ============================ + print("\n======== testing the optimizer server ===========") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? 
Just answer yes or no.",
+      temperature=1.0,
+  )
+  print(f"optimizer test output: {optimizer_test_output}")
+  print("Finished testing the optimizer server.")
+  print("\n=================================================")
+
+  # ====================== utility functions ============================
+  def evaluate_distance(x, y, trace, num_decimals):  # pylint: disable=invalid-name
+    dis = 0
+    try:
+      for i in range(len(trace) - 1):
+        id0 = trace[i]
+        id1 = trace[i + 1]
+        dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2)
+    except Exception:  # an invalid trace (e.g. an out-of-range index) scores -1
+      return -1
+    id0 = trace[-1]
+    id1 = trace[0]
+    dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2)
+    dis = np.round(dis, num_decimals) if num_decimals > 0 else int(dis)
+    return dis
+
+  def solve_tsp(x, y, num_points, num_decimals, starting_algorithm):
+    if starting_algorithm == "nearest_neighbor":
+      min_dis = 0
+      gt_sol = [0]
+      remaining_points = list(range(1, num_points))
+      while len(remaining_points) > 0:
+        min_p = -1
+        min_cur_dis = -1
+        for p in remaining_points:
+          cur_dis = np.sqrt((x[p] - x[gt_sol[-1]]) ** 2 + (y[p] - y[gt_sol[-1]]) ** 2)
+          if min_p == -1 or cur_dis < min_cur_dis:
+            min_p = p
+            min_cur_dis = cur_dis
+        gt_sol.append(min_p)
+        min_dis += min_cur_dis
+        remaining_points.remove(min_p)
+      min_dis += np.sqrt((x[0] - x[gt_sol[-1]]) ** 2 + (y[0] - y[gt_sol[-1]]) ** 2)
+      min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis)
+      return gt_sol, min_dis
+    elif starting_algorithm == "farthest_insertion":
+      gt_sol = [0]
+      remaining_points = list(range(1, num_points))
+      while len(remaining_points) > 0:
+        max_p = -1
+        max_cur_dis = -1
+        max_cur_index = -1
+        for p in remaining_points:
+          min_cur_dis = -1
+          min_cur_index = -1
+          for index in range(1, len(gt_sol) + 1):
+            new_sol = gt_sol[:index] + [p] + gt_sol[index:]
+            cur_dis = evaluate_distance(x, y, new_sol, num_decimals)
+            if min_cur_dis == -1 or cur_dis < min_cur_dis:
+              min_cur_dis = cur_dis
+              min_cur_index = index
+          if max_cur_dis == -1 or min_cur_dis > max_cur_dis:
+            max_p = p
+            max_cur_dis = min_cur_dis
+            max_cur_index = min_cur_index
+        gt_sol = gt_sol[:max_cur_index] + [max_p] + gt_sol[max_cur_index:]
+        remaining_points.remove(max_p)
+      min_dis = evaluate_distance(x, y, gt_sol, num_decimals)
+      return gt_sol, min_dis
+
+    # otherwise starting_algorithm == "dp": exact dynamic programming over
+    # point subsets encoded as bitmasks
+    f = {(0, 1): (0, [0])}
+    q = [(0, 1)]
+    min_dis = -1
+    gt_sol = list(range(num_points))
+    while len(q) > 0:
+      p, status = q[0]
+      q = q[1:]
+      for i in range(num_points):
+        if 2 << i >> 1 & status == 0:
+          new_status = status + (2 << i >> 1)
+          new_dis = f[(p, status)][0] + np.sqrt((x[i] - x[p]) ** 2 + (y[i] - y[p]) ** 2)
+          if (i, new_status) not in f or new_dis < f[(i, new_status)][0]:
+            f[(i, new_status)] = (new_dis, f[(p, status)][1] + [i])
+            if new_status == (2 << num_points >> 1) - 1:
+              new_dis += np.sqrt((x[i] - x[0]) ** 2 + (y[i] - y[0]) ** 2)
+              if min_dis == -1 or new_dis < min_dis:
+                min_dis = new_dis
+                gt_sol = f[(i, new_status)][1][:]
+            elif (i, new_status) not in q:
+              q.append((i, new_status))
+    min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis)
+    return gt_sol, min_dis
+
+  def gen_meta_prompt(
+      old_value_pairs_set,
+      x,  # pylint: disable=invalid-name
+      y,
+      max_num_pairs=100,
+  ):
+    """Generate the meta-prompt for optimization.
+
+    Args:
+      old_value_pairs_set (set): the set of old traces.
+      x (np.array): the 1D array of x values.
+      y (np.array): the 1D array of y values.
+      max_num_pairs (int): the maximum number of exemplars in the meta-prompt.
+
+    Returns:
+      meta_prompt (str): the generated meta-prompt.
+    """
+    old_value_pairs = list(old_value_pairs_set)
+    old_value_pairs = sorted(old_value_pairs, key=lambda x: -x[1])[
+        -max_num_pairs:
+    ]
+    old_value_pairs_substr = ""
+    for trace, dis in old_value_pairs:
+      old_value_pairs_substr += f"\n<trace> {trace} </trace>\nlength:\n{dis}\n"
+    meta_prompt = "You are given a list of points with coordinates below:\n"
+    for i, (xi, yi) in enumerate(zip(x, y)):
+      if i:
+        meta_prompt += ", "
+      meta_prompt += f"({i}): ({xi}, {yi})"
+    meta_prompt += ".\n\nBelow are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where lower values are better.".strip()
+    meta_prompt += "\n\n"
+    meta_prompt += old_value_pairs_substr.strip()
+    meta_prompt += "\n\n"
+    meta_prompt += """Give me a new trace that is different from all traces above, and has a length lower than any of the above. The trace should traverse all points exactly once. The trace should start with '<trace>' and end with </trace>.
+    """.strip()
+    return meta_prompt
+
+  def extract_string(input_string):
+    start_string = "<trace>"
+    end_string = "</trace>"
+    if start_string not in input_string:
+      return ""
+    input_string = input_string[input_string.index(start_string) + len(start_string):]
+    if end_string not in input_string:
+      return ""
+    input_string = input_string[:input_string.index(end_string)]
+    parsed_list = []
+    for p in input_string.split(","):
+      p = p.strip()
+      try:
+        p = int(p)
+      except ValueError:  # skip tokens that are not integers
+        continue
+      parsed_list.append(p)
+    return parsed_list
+
+  # ================= generate the ground truth trace =====================
+  x = np.random.uniform(low=-100, high=100, size=num_points)
+  y = np.random.uniform(low=-100, high=100, size=num_points)
+  x = [np.round(xi, num_decimals) if num_decimals > 0 else int(xi) for xi in x]
+  y = [np.round(yi, num_decimals) if num_decimals > 0 else int(yi) for yi in y]
+
+  starting_algorithm = _START_ALGORITHM.value
+
+  gt_sol, min_dis = solve_tsp(x, y, num_points, num_decimals, starting_algorithm)
+  print("ground truth solution: " + str(gt_sol))
+  print("min distance: ", min_dis)
+  gt_sol_str = ",".join([str(i) for i in gt_sol])
+  point_list = range(num_points)
+  init_sols = []
+  while len(init_sols) < num_starting_points:
+    sol = np.random.permutation(point_list)
+    if sol[0] != 0:
+      continue
+    sol_str = ",".join([str(i) for i in sol])
+    if sol_str == gt_sol_str:
+      continue
+    init_sols.append(list(sol))
+
+  # ====================== run optimization ============================
+  configs_dict = {
+      "num_starting_points": num_starting_points,
+      "num_decode_per_step": num_decode_per_step,
+      "optimizer_llm_configs": optimizer_llm_dict,
+      "data": {
+          "ground truth solution": [",".join([str(i) for i in gt_sol])],
+          "loss_at_true_values": min_dis,
+          "x": list(x),
+          "y": list(y),
+      },
+      "init_sols": [",".join([str(i) for i in sol]) for sol in init_sols],
+      "num_steps": num_steps,
+      "max_num_pairs": max_num_pairs,
+      "num_decimals": num_decimals,
+  }
+  configs_json_path = os.path.join(save_folder, "configs.json")
+  print(f"saving configs to\n{configs_json_path}")
+  with open(configs_json_path, "w") as f:
+    json.dump(configs_dict, f, indent=4)
+
+  old_value_pairs_set = set()
+  old_value_pairs_with_i_step = []  # format: [(trace, dis = f(trace), i_step)]
+  meta_prompts_dict = dict()  # format: {i_step: meta_prompt}
+  raw_outputs_dict = dict()  # format: {i_step: raw_outputs}
+
+  for sol in init_sols:
+    dis = evaluate_distance(x, y, sol, 
num_decimals)
+    sol_str = ",".join([str(i) for i in sol])
+    old_value_pairs_set.add((sol_str, dis))
+    old_value_pairs_with_i_step.append((sol_str, dis, -1))
+
+  print("\n================ run optimization ==============")
+  print(f"initial points: {[tuple(item[:-1]) for item in old_value_pairs_set]}")
+  print(f"initial values: {[item[-1] for item in old_value_pairs_set]}")
+  results_json_path = os.path.join(save_folder, "results.json")
+  print(f"saving results to\n{results_json_path}")
+
+  for i_step in range(num_steps):
+    print(f"\nStep {i_step}:")
+    meta_prompt = gen_meta_prompt(
+        old_value_pairs_set,
+        x,
+        y,
+        max_num_pairs=max_num_pairs,
+    )
+    print("\n=================================================")
+    print(f"meta_prompt:\n{meta_prompt}")
+    meta_prompts_dict[i_step] = meta_prompt
+    raw_outputs = []
+    parsed_outputs = []
+    while len(parsed_outputs) < num_decode_per_step:
+      raw_output = call_optimizer_server_func(meta_prompt)
+      for string in raw_output:
+        print("\n=================================================")
+        print("raw output:\n", string)
+        try:
+          parsed_output = extract_string(string)
+          if (
+              parsed_output is not None
+              and len(set(parsed_output)) == num_points
+              and len(parsed_output) == num_points
+              and parsed_output[0] == 0
+          ):
+            dis = evaluate_distance(x, y, parsed_output, num_decimals)
+            if dis == -1:
+              continue
+            parsed_outputs.append(parsed_output)
+            raw_outputs.append(string)
+        except Exception:  # skip outputs that cannot be parsed into a trace
+          pass
+    print("\n=================================================")
+    print(f"proposed points: {parsed_outputs}")
+    raw_outputs_dict[i_step] = raw_outputs
+
+    # evaluate the values of proposed and rounded outputs
+    single_step_values = []
+    for trace in parsed_outputs:
+      dis = evaluate_distance(x, y, trace, num_decimals)
+      single_step_values.append(dis)
+      trace_str = ",".join([str(i) for i in trace])
+      old_value_pairs_set.add((trace_str, dis))
+      old_value_pairs_with_i_step.append((trace_str, dis, i_step))
+    print(f"single_step_values: {single_step_values}")
+    print("ground truth solution: " + str(gt_sol))
+    print("min distance: ", min_dis)
+
+    # ====================== save results ============================
+    results_dict = {
+        "meta_prompts": meta_prompts_dict,
+        "raw_outputs": raw_outputs_dict,
+        "old_value_pairs_with_i_step": old_value_pairs_with_i_step,
+    }
+    with open(results_json_path, "w") as f:
+      json.dump(results_dict, f, indent=4)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/optimization/test.py b/optimization/test.py
new file mode 100644
index 0000000..a38abcc
--- /dev/null
+++ b/optimization/test.py
@@ -0,0 +1,967 @@
+# Copyright 2023 The OPRO Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""The .py file for prompt optimization.
+
+Usage:
+
+Step 1: edit the starting instructions by modifying `initial_instructions`
+
+Step 2: edit the training ratio by modifying `train_ratio`
+
+Step 3: check if the model configs (like batch size) are the same as the actual serving configs
+
+Step 4: run
+
+```
+python optimize_instructions.py \
+    --optimizer="gpt-3.5-turbo" --scorer="text-bison" \
+    --instruction_pos="A_begin" --dataset="gsm8k" --task="train"
+```
+
+The outputs will then be written to `outputs/optimization-results/` in the opro folder.
+
+Notes:
+
+1. One or more API keys may need to be provided:
+- When using a Google-Cloud-served model (like text-bison at https://developers.generativeai.google/tutorials/text_quickstart), add `--palm_api_key=""`
+- When using an OpenAI model, add `--openai_api_key=""`
+
+2. The initial instructions should be provided in the "initial_instructions"
+variable.
+"""
+
+import datetime
+import functools
+import os
+import sys
+
+OPRO_ROOT_PATH = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+)
+sys.path.insert(0, OPRO_ROOT_PATH)
+
+from absl import app
+from absl import flags
+import google.generativeai as palm
+import numpy as np
+import openai
+from opro import prompt_utils
+from opro.optimization import opt_utils
+import pandas as pd
+
+ROOT_DATA_FOLDER_PATH = os.path.join(OPRO_ROOT_PATH, "data")
+
+_LOCAL_MODEL_PATH = flags.DEFINE_string("local_model_path", None, "Path to local vLLM model.")
+
+
+_OPENAI_API_KEY = flags.DEFINE_string(
+    "openai_api_key", "", "The OpenAI API key."
+)
+
+_PALM_API_KEY = flags.DEFINE_string("palm_api_key", "", "The PaLM API key.")
+
+_SCORER = flags.DEFINE_string(
+    "scorer", "text-bison", "The name of the scorer LLM."
+)
+
+_OPTIMIZER = flags.DEFINE_string(
+    "optimizer", "gpt-3.5-turbo", "The name of the optimizer LLM."
+)
+
+_DATASET = flags.DEFINE_string(
+    "dataset", "gsm8k", "The name of dataset to search for instructions on."
+)
+
+_TASK = flags.DEFINE_string(
+    "task",
+    "train",
+    "The name of task within the above dataset to search for instructions on.",
+)
+
+_INSTRUCTION_POS = flags.DEFINE_string(
+    "instruction_pos",
+    "A_begin",
+    "The position of the instruction to search for.",
+)
+
+_META_PROMPT_TYPE = flags.DEFINE_string(
+    "meta_prompt_type",
+    "both_instructions_and_exemplars",
+    "The type of meta-prompt: whether to have both previous instructions and"
+    " dataset exemplars (often for fine-tuned optimizers), or to have only"
+    " previous instructions (often for pre-trained optimizers).",
+)
+
+
+def main(_):
+  openai_api_key = _OPENAI_API_KEY.value
+  palm_api_key = _PALM_API_KEY.value
+  scorer_llm_name = _SCORER.value
+  optimizer_llm_name = _OPTIMIZER.value
+  dataset_name = _DATASET.value.lower()
+  task_name = _TASK.value
+  meta_prompt_type = _META_PROMPT_TYPE.value
+  local_model_path = _LOCAL_MODEL_PATH.value
+
+  assert dataset_name in {
+      "mmlu",
+      "bbh",
+      "gsm8k",
+  }, "The lower-case dataset name must be one of mmlu, bbh, or gsm8k."
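+  # Each dataset exposes a different task granularity: MMLU is searched per
+  # category, BBH per individual task, and GSM8K per train/test split.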
+  if dataset_name == "mmlu":
+    assert task_name in {
+        "STEM",
+        "humanities",
+        "social sciences",
+        "other (business, health, misc.)",
+    }  # for now only support searching on one MMLU category
+  elif dataset_name == "bbh":
+    assert task_name in {
+        "boolean_expressions",
+        "causal_judgement",
+        "date_understanding",
+        "disambiguation_qa",
+        "dyck_languages",
+        "formal_fallacies",
+        "geometric_shapes",
+        "hyperbaton",
+        "logical_deduction_five_objects",
+        "logical_deduction_seven_objects",
+        "logical_deduction_three_objects",
+        "movie_recommendation",
+        "multistep_arithmetic_two",
+        "navigate",
+        "object_counting",
+        "penguins_in_a_table",
+        "reasoning_about_colored_objects",
+        "ruin_names",
+        "salient_translation_error_detection",
+        "snarks",
+        "sports_understanding",
+        "temporal_sequences",
+        "tracking_shuffled_objects_five_objects",
+        "tracking_shuffled_objects_seven_objects",
+        "tracking_shuffled_objects_three_objects",
+        "web_of_lies",
+        "word_sorting",
+    }
+  else:
+    assert dataset_name == "gsm8k"
+    assert task_name in {"train", "test"}
+
+  assert scorer_llm_name in {
+      "text-bison",
+      "gpt-3.5-turbo",
+      "gpt-4",
+      "local",
+  }
+  assert optimizer_llm_name in {
+      "text-bison",
+      "gpt-3.5-turbo",
+      "gpt-4",
+      "local",
+  }
+  assert meta_prompt_type in {
+      "both_instructions_and_exemplars",
+      "instructions_only",
+  }
+
+  instruction_pos = _INSTRUCTION_POS.value
+  assert instruction_pos in {
+      "before_Q",
+      "Q_begin",
+      "Q_end",
+      "A_begin",
+  }, (
+      "The instruction position should be either before the question, or at the"
+      " beginning of the question, at the end of the question, or at the"
+      " beginning of the answer."
+  )
+  print(
+      f"scorer: {scorer_llm_name}, optimizer: {optimizer_llm_name}, dataset:"
+      f" {dataset_name}, task: {task_name}, instruction_pos: {instruction_pos}"
+  )
+
+  # make sure the scorer and optimizer models are callable
+  if scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
+    assert openai_api_key, "The OpenAI API key must be provided."
+    openai.api_key = openai_api_key
+  elif scorer_llm_name == "text-bison":
+    assert (
+        palm_api_key
+    ), "A PaLM API key is needed when prompting the text-bison model."
+    palm.configure(api_key=palm_api_key)
+
+  elif scorer_llm_name == "local":
+    assert local_model_path, "The local model path must be provided."
+    assert os.path.exists(local_model_path), (
+        f"The local model path {local_model_path} does not exist."
+    )
+    # set the local model path for vLLM
+    # prompt_utils.call_local_server_func(local_model_path)
+  else:
+    raise ValueError(
+        f"Unknown scorer_llm_name: {scorer_llm_name}. "
+        "It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local."
+    )
+
+  if optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
+    assert openai_api_key, "The OpenAI API key must be provided."
+    openai.api_key = openai_api_key
+  elif optimizer_llm_name == "text-bison":
+    assert (
+        palm_api_key
+    ), "A PaLM API key is needed when prompting the text-bison model."
+    palm.configure(api_key=palm_api_key)
+
+  elif optimizer_llm_name == "local":
+    assert local_model_path, "The local model path must be provided."
+    assert os.path.exists(local_model_path), (
+        f"The local model path {local_model_path} does not exist."
+    )
+    # set the local model path for vLLM
+    # prompt_utils.call_local_server_func(local_model_path)
+  else:
+    raise ValueError(
+        f"Unknown optimizer_llm_name: {optimizer_llm_name}. "
+        "It should be one of text-bison, gpt-3.5-turbo, gpt-4, or local."
+    )
+
+  if dataset_name == "mmlu":
+    root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "MMLU-data")
+  elif dataset_name == "bbh":
+    root_data_folder_path = os.path.join(
+        ROOT_DATA_FOLDER_PATH, "BIG-Bench-Hard-data/"
+    )
+  else:
+    assert dataset_name == "gsm8k"
+    root_data_folder_path = os.path.join(ROOT_DATA_FOLDER_PATH, "gsm_data")
+
+  # =================== create the result directory ==========================
+  datetime_str = (
+      str(datetime.datetime.now().replace(microsecond=0))
+      .replace(" ", "-")
+      .replace(":", "-")
+  )
+
+  save_folder = os.path.join(
+      OPRO_ROOT_PATH,
+      "outputs",
+      "optimization-results",
+      f"{dataset_name.upper()}-{task_name}-s-{scorer_llm_name}-o-{optimizer_llm_name}-{datetime_str}/",
+  )
+  result_by_instruction_folder = os.path.join(
+      save_folder, "result_by_instruction"
+  )
+  os.makedirs(result_by_instruction_folder)
+  print(f"result directory:\n{save_folder}")
+
+  # ====================== scorer model configs ==============================
+  # difference between num_decodes and batch_size:
+  # - num_decodes: how many outputs we actually want for each input
+  # - batch_size: the batch size in model serving, should equal to that in
+  #   model serving config
+  # constant definitions
+  DEFAULT_MAX_TOKENS = 1024
+  DEFAULT_TEMPERATURE = 0.0
+  PALM_MODEL_NAME = "text-bison-001"
+
+  if scorer_llm_name == "text-bison":
+    config = {
+        "temperature": DEFAULT_TEMPERATURE,
+        "max_decode_steps": DEFAULT_MAX_TOKENS,
+        "batch_size": 1,
+        "num_servers": 1,
+    }
+    call_scorer_server_func = functools.partial(
+        prompt_utils.call_palm_server_from_cloud,
+        model=PALM_MODEL_NAME,
+        **config
+    )
+    scorer_llm_dict = {"model_type": "text-bison", **config}
+
+  elif scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}:
+    config = {
+        "temperature": DEFAULT_TEMPERATURE,
+        "max_decode_steps": DEFAULT_MAX_TOKENS,
+        "batch_size": 1,
+        "num_servers": 1,
+    }
+    call_scorer_server_func = functools.partial(
+        prompt_utils.call_openai_server_func,
+        model=scorer_llm_name.lower(),
+        **config
+    )
+    scorer_llm_dict = {"model_type": scorer_llm_name.lower(), **config}
+
+  elif scorer_llm_name == "local":
+    print(f"[DEBUG] local_model_path: {local_model_path}")
+    assert local_model_path, "Local model path must be provided."
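+    # The config dict below is splatted into functools.partial, so its keys
+    # must match the keyword arguments accepted by
+    # prompt_utils.call_local_server_func.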
+    config = {
+        "temperature": DEFAULT_TEMPERATURE,
+        "max_decode_steps": DEFAULT_MAX_TOKENS,
+        "batch_size": 8,
+        "num_servers": 8,  # number of servers to use for the local model
+    }
+    call_scorer_server_func = functools.partial(
+        prompt_utils.call_local_server_func,
+        local_model_path=local_model_path,
+        **config
+    )
+    scorer_llm_dict = {"model_type": "local", **config}
+
+  else:
+    raise ValueError(f"Unsupported model: {scorer_llm_name}")
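+  # The optimizer config below mirrors the scorer setup but uses temperature
+  # 1.0 so that the generated instruction candidates are diverse.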
+
+  # ====================== optimizer model configs ============================
+  if optimizer_llm_name.lower() == "text-bison":
+    # PaLM text-bison model configuration
+    optimizer_llm_dict = {
+        "model_type": "text-bison",
+        "temperature": 1.0,  # higher randomness to generate diverse candidates
+        "max_decode_steps": 1024,  # maximum generation length
+        "batch_size": 1,  # single-sample processing
+        "num_decodes": 8,  # generate 8 candidate outputs
+        "num_servers": 1,  # single server
+    }
+
+    call_optimizer_server_func = functools.partial(
+        prompt_utils.call_palm_server_from_cloud,
+        model="text-bison-001",
+        temperature=optimizer_llm_dict["temperature"],
+        max_decode_steps=optimizer_llm_dict["max_decode_steps"],
+    )
+
+  elif optimizer_llm_name.lower() in {"gpt-3.5-turbo", "gpt-4"}:
+    # GPT model configuration
+    optimizer_llm_dict = {
+        "model_type": optimizer_llm_name.lower(),
+        "temperature": 1.0,  # higher randomness
+        "max_decode_steps": 512,  # shorter maximum length
+        "batch_size": 1,
+        "num_decodes": 1,  # one decode per call
+        "num_servers": 1,  # single server
+    }
+
+    call_optimizer_server_func = functools.partial(
+        prompt_utils.call_openai_server_func,
+        model=optimizer_llm_name,
+        max_decode_steps=optimizer_llm_dict["max_decode_steps"],
+        temperature=optimizer_llm_dict["temperature"],
+    )
+  elif optimizer_llm_name.lower() == "local":
+    assert local_model_path, "Local model path must be provided."
+    optimizer_llm_dict = {
+        "model_type": optimizer_llm_name.lower(),
+        "temperature": 1.0,  # higher randomness
+        "max_decode_steps": 512,  # shorter maximum length
+        "batch_size": 8,
+        "num_decodes": 1,  # one decode per call
+        "num_servers": 8,  # number of local servers
+    }
+    call_optimizer_server_func = functools.partial(
+        prompt_utils.call_local_server_func,
+        local_model_path=local_model_path,
+        max_decode_steps=optimizer_llm_dict["max_decode_steps"],
+        temperature=optimizer_llm_dict["temperature"],
+    )
+
+  else:
+    raise ValueError(
" + "Must be one of: text-bison, gpt-3.5-turbo, gpt-4" + ) + +# if optimizer_llm_name.lower() == "text-bison": +# # when prompting text-bison with Cloud API +# optimizer_finetuned_palm_temperature = 1.0 +# optimizer_finetuned_palm_num_decodes = 8 +# optimizer_finetuned_palm_max_decode_steps = 1024 +# optimizer_finetuned_palm_batch_size = 1 +# optimizer_finetuned_palm_num_servers = 1 +# optimizer_finetuned_palm_dict = dict() +# optimizer_finetuned_palm_dict["temperature"] = ( +# optimizer_finetuned_palm_temperature +# ) +# optimizer_finetuned_palm_dict["num_decodes"] = ( +# optimizer_finetuned_palm_num_decodes +# ) +# optimizer_finetuned_palm_dict["batch_size"] = ( +# optimizer_finetuned_palm_batch_size +# ) +# optimizer_finetuned_palm_dict["num_servers"] = ( +# optimizer_finetuned_palm_num_servers +# ) +# optimizer_finetuned_palm_dict["max_decode_steps"] = ( +# optimizer_finetuned_palm_max_decode_steps +# ) + +# call_optimizer_finetuned_palm_server_func = functools.partial( +# prompt_utils.call_palm_server_from_cloud, +# model="text-bison-001", +# temperature=optimizer_finetuned_palm_dict["temperature"], +# max_decode_steps=optimizer_finetuned_palm_dict["max_decode_steps"], +# ) + +# optimizer_llm_dict = { +# "model_type": optimizer_llm_name.lower(), +# } +# optimizer_llm_dict.update(optimizer_finetuned_palm_dict) +# call_optimizer_server_func = call_optimizer_finetuned_palm_server_func + +# else: +# assert optimizer_llm_name in {"gpt-3.5-turbo", "gpt-4"} +# optimizer_gpt_max_decode_steps = 512 +# optimizer_gpt_temperature = 1.0 + +# optimizer_llm_dict = dict() +# optimizer_llm_dict["max_decode_steps"] = optimizer_gpt_max_decode_steps +# optimizer_llm_dict["temperature"] = optimizer_gpt_temperature +# optimizer_llm_dict["batch_size"] = 1 +# optimizer_llm_dict["num_decodes"] = 1 +# call_optimizer_server_func = functools.partial( +# prompt_utils.call_openai_server_func, +# model=optimizer_llm_name, +# max_decode_steps=optimizer_gpt_max_decode_steps, +# temperature=optimizer_gpt_temperature, +# ) + + # ====================== try calling the servers ============================ + print("\n======== testing the scorer and optimizer servers ===========") + scorer_test_output = call_scorer_server_func( + "Does the sun rise from the north? Just answer yes or no." + ) + print(f"number of scorer output decodes: {len(scorer_test_output)}") + print(f"scorer test output: {scorer_test_output}") + optimizer_test_output = call_optimizer_server_func( + "Does the sun rise from the north? 
+
+  # ====================== read data ============================
+  print("\n================ prompt optimization settings ==============")
+  # from https://github.com/hendrycks/test/blob/master/categories.py
+  subcategories = {
+      "abstract_algebra": ["math"],
+      "anatomy": ["health"],
+      "astronomy": ["physics"],
+      "business_ethics": ["business"],
+      "clinical_knowledge": ["health"],
+      "college_biology": ["biology"],
+      "college_chemistry": ["chemistry"],
+      "college_computer_science": ["computer science"],
+      "college_mathematics": ["math"],
+      "college_medicine": ["health"],
+      "college_physics": ["physics"],
+      "computer_security": ["computer science"],
+      "conceptual_physics": ["physics"],
+      "econometrics": ["economics"],
+      "electrical_engineering": ["engineering"],
+      "elementary_mathematics": ["math"],
+      "formal_logic": ["philosophy"],
+      "global_facts": ["other"],
+      "high_school_biology": ["biology"],
+      "high_school_chemistry": ["chemistry"],
+      "high_school_computer_science": ["computer science"],
+      "high_school_european_history": ["history"],
+      "high_school_geography": ["geography"],
+      "high_school_government_and_politics": ["politics"],
+      "high_school_macroeconomics": ["economics"],
+      "high_school_mathematics": ["math"],
+      "high_school_microeconomics": ["economics"],
+      "high_school_physics": ["physics"],
+      "high_school_psychology": ["psychology"],
+      "high_school_statistics": ["math"],
+      "high_school_us_history": ["history"],
+      "high_school_world_history": ["history"],
+      "human_aging": ["health"],
+      "human_sexuality": ["culture"],
+      "international_law": ["law"],
+      "jurisprudence": ["law"],
+      "logical_fallacies": ["philosophy"],
+      "machine_learning": ["computer science"],
+      "management": ["business"],
+      "marketing": ["business"],
+      "medical_genetics": ["health"],
+      "miscellaneous": ["other"],
+      "moral_disputes": ["philosophy"],
+      "moral_scenarios": ["philosophy"],
+      "nutrition": ["health"],
+      "philosophy": ["philosophy"],
+      "prehistory": ["history"],
+      "professional_accounting": ["other"],
+      "professional_law": ["law"],
+      "professional_medicine": ["health"],
+      "professional_psychology": ["psychology"],
+      "public_relations": ["politics"],
+      "security_studies": ["politics"],
+      "sociology": ["culture"],
+      "us_foreign_policy": ["politics"],
+      "virology": ["health"],
+      "world_religions": ["philosophy"],
+  }
+
+  categories = {
+      "STEM": [
+          "physics",
+          "chemistry",
+          "biology",
+          "computer science",
+          "math",
+          "engineering",
+      ],
+      "humanities": ["history", "philosophy", "law"],
+      "social sciences": [
+          "politics",
+          "culture",
+          "economics",
+          "geography",
+          "psychology",
+      ],
+      "other (business, health, misc.)": ["other", "business", "health"],
+  }
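+
+  # Editor's sketch: how a task name maps to a category under the two tables
+  # above, e.g. for the csv "college_physics_test.csv":
+  #
+  #     subject = subcategories["college_physics"][0]  # -> "physics"
+  #     assert subject in categories["STEM"]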
+
+  if dataset_name == "mmlu":
+    # EITHER: filter by category
+    # category_names = [
+    #     "STEM",
+    #     "humanities",
+    #     "social sciences",
+    #     "other (business, health, misc.)",
+    # ]
+    category_names = [task_name]
+    folder_name = "test"  # one of {'auxiliary_train', 'dev', 'val', 'test'}
+    task_names = []
+    for task_csv_name in os.listdir(
+        os.path.join(root_data_folder_path, folder_name)
+    ):
+      task_names.append(task_csv_name.split(".")[0])
+
+    tasks_in_category = []
+    for category_name in category_names:
+      for task_name in task_names:
+        for subname in subcategories:
+          if subname in task_name:
+            if subcategories[subname][0] in categories[category_name]:
+              tasks_in_category.append(task_name)
+            break
+
+    tasks_all = [(folder_name, task_name) for task_name in tasks_in_category]
+    multiple_choice_tasks = set([item[1] for item in tasks_all])
+    boolean_tasks = set()
+    numerical_output_tasks = set()
+
+    # OR: filter by task
+    # tasks_all = [
+    #     # ('test', 'abstract_algebra_test'),
+    #     # ('test', 'college_computer_science_test'),
+    #     # ('test', 'college_mathematics_test'),
+    #     # ('test', 'college_physics_test'),
+    #     # ('test', 'elementary_mathematics_test'),
+    #     # ('test', 'global_facts_test'),
+    #     # ('test', 'high_school_physics_test'),
+    #     # ('test', 'machine_learning_test'),
+    #     # ('test', 'management_test'),
+    #     # ('test', 'medical_genetics_test'),
+    #     # ('test', 'moral_scenarios_test'),
+    #     # ('test', 'professional_psychology_test'),
+    #     # ('test', 'public_relations_test'),
+    #     # ('test', 'professional_law_test'),
+    #     # ('test', 'high_school_psychology_test'),
+    #     # ('test', 'high_school_world_history_test'),
+    #     # ('test', 'human_aging_test'),
+    #     # ('test', 'miscellaneous_test'),
+    #     # ('test', 'moral_scenarios_test'),
+    #     ('test', 'professional_psychology_test'),
+    #     # ('test', 'security_studies_test'),
+    # ]
+
+  elif dataset_name == "bbh":
+    tasks_all = [task_name]
+    assert (
+        len(tasks_all) == 1
+    ), "for now only support prompt optimization on one BBH task"
+
+    # all BBH tasks are as below
+    # tasks_all = [
+    #     'boolean_expressions',
+    #     'causal_judgement',
+    #     'date_understanding',
+    #     'disambiguation_qa',
+    #     'dyck_languages',
+    #     'formal_fallacies',
+    #     'geometric_shapes',
+    #     'hyperbaton',
+    #     'logical_deduction_five_objects',
+    #     'logical_deduction_seven_objects',
+    #     'logical_deduction_three_objects',
+    #     'movie_recommendation',
+    #     'multistep_arithmetic_two',
+    #     'navigate',
+    #     'object_counting',
+    #     'penguins_in_a_table',
+    #     'reasoning_about_colored_objects',
+    #     'ruin_names',
+    #     'salient_translation_error_detection',
+    #     'snarks',
+    #     'sports_understanding',
+    #     'temporal_sequences',
+    #     'tracking_shuffled_objects_five_objects',
+    #     'tracking_shuffled_objects_seven_objects',
+    #     'tracking_shuffled_objects_three_objects',
+    #     'web_of_lies',
+    #     'word_sorting'
+    # ]
+    numerical_output_tasks = {
+        "object_counting",
+        "multistep_arithmetic_two",
+    }
+
+    multiple_choice_tasks = {
+        "date_understanding",
+        "disambiguation_qa",
+        "geometric_shapes",
+        "hyperbaton",
+        "logical_deduction_five_objects",
+        "logical_deduction_seven_objects",
+        "logical_deduction_three_objects",
+        "movie_recommendation",
+        "penguins_in_a_table",
+        "reasoning_about_colored_objects",
+        "ruin_names",
+        "salient_translation_error_detection",
+        "snarks",
+        "temporal_sequences",
+        "tracking_shuffled_objects_five_objects",
+        "tracking_shuffled_objects_seven_objects",
+        "tracking_shuffled_objects_three_objects",
+    }
+
+    boolean_tasks = {
+        "boolean_expressions",  # True or False
+        "causal_judgement",  # yes or no
+        "formal_fallacies",  # valid or invalid
+        "navigate",  # yes or no
+        "sports_understanding",  # yes or no
+        "web_of_lies",  # yes or no
+    }
+
+  else:
+    assert dataset_name in {"gsm8k"}
+    tasks_all = [task_name]
+    multiple_choice_tasks = set()
+    boolean_tasks = set()
+    numerical_output_tasks = set(tasks_all)
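+
+  # Editor's sketch: these sets drive how raw model output is later graded
+  # (the actual parsing lives in opt_utils); roughly:
+  #
+  #     def prediction_kind(task):
+  #         if task in numerical_output_tasks: return "number"
+  #         if task in boolean_tasks: return "bool"
+  #         if task in multiple_choice_tasks: return "choice"
+  #         return "free text"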
+
+  if dataset_name == "mmlu":
+    raw_data = pd.DataFrame()
+    prediction_treat_as_number = False
+    prediction_treat_as_bool = False
+  elif dataset_name == "bbh":
+    raw_data = []
+    prediction_treat_as_number = bool(
+        tasks_all[0] in numerical_output_tasks
+    )  # for now only check the first task
+    prediction_treat_as_bool = bool(
+        tasks_all[0] in boolean_tasks
+    )  # for now only check the first task
+    print(
+        f"prediction_treat_as_number: {prediction_treat_as_number},"
+        f" prediction_treat_as_bool: {prediction_treat_as_bool}"
+    )
+  else:
+    assert dataset_name == "gsm8k"
+    raw_data = pd.DataFrame()
+    prediction_treat_as_number = True
+    prediction_treat_as_bool = False
+
+  for t in tasks_all:
+    if dataset_name == "mmlu":
+      folder_name = t[0]
+      task_name = t[1]
+      single_task_df = pd.read_csv(
+          os.path.join(root_data_folder_path, f"{folder_name}/{task_name}.csv"),
+          index_col=None,
+          header=None,
+      )
+      raw_data = pd.concat([raw_data, single_task_df])
+    elif dataset_name == "bbh":
+      task_name = t
+      single_task_list = opt_utils.load_bbh_task_data(
+          task_name, base_dir=root_data_folder_path
+      )
+      raw_data += single_task_list
+    else:
+      assert dataset_name == "gsm8k"
+      task_name = t
+      f_gsm = os.path.join(root_data_folder_path, f"gsm_{task_name}.tsv")
+      single_task_df = pd.read_csv(f_gsm, sep="\t", header=None)
+      raw_data = pd.concat([raw_data, single_task_df])
+
+  if dataset_name == "mmlu":
+    num_examples = raw_data.shape[0]
+  elif dataset_name == "bbh":
+    num_examples = len(raw_data)
+  else:
+    assert dataset_name in {"gsm8k"}
+    num_examples = raw_data.shape[0]
+  print(f"number of examples in the current task: {num_examples}")
+
+  # ================ split data into train/val/test ==========================
+  if dataset_name == "mmlu":
+    train_ratio = 0.8
+    eval_ratio = 0.2
+  elif dataset_name == "gsm8k":
+    train_ratio = 0.035
+    eval_ratio = 0
+  else:
+    assert dataset_name == "bbh"
+    train_ratio = 0.2
+    eval_ratio = 0
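+
+  # Editor's note: with hypothetical num_examples = 1000 and the BBH ratios
+  # above, int(0.2 * 1000) = 200 indices are sampled for training without
+  # replacement and the remaining 800 all become test data, since
+  # eval_ratio = 0 leaves the validation split empty for BBH and GSM8K.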
+
+  # train-validation-test split
+  # It is important to sort the indices, as this ensures the is_multiple_choice
+  # Boolean variables match the data points.
+  assert train_ratio + eval_ratio <= 1
+  test_ratio = 1 - train_ratio - eval_ratio
+  print(
+      f"train_ratio: {train_ratio}, eval_ratio: {eval_ratio}, "
+      f"test_ratio: {test_ratio}"
+  )
+  np.random.seed(0)
+  train_index = np.sort(
+      np.array(
+          np.random.choice(
+              num_examples, size=int(train_ratio * num_examples), replace=False
+          )
+      )
+  )
+  eval_and_test_index = np.sort(
+      np.array(list(set(np.arange(num_examples)) - set(train_index)))
+  )
+  eval_index = np.sort(
+      np.array(
+          np.random.choice(
+              eval_and_test_index,
+              size=int(eval_ratio * num_examples),
+              replace=False,
+          )
+      )
+  )
+
+  # ========== set other optimization experiment hyperparameters ==============
+  if scorer_llm_name == "text-bison":
+    # keep the full history: nothing is filtered out, so even low-quality
+    # old instructions stay in the meta-prompt
+    old_instruction_score_threshold = 0.0
+    # old_instruction_score_threshold = 0.15  # for GSM8K
+  elif scorer_llm_name == "local":
+    old_instruction_score_threshold = 0.3
+  else:
+    assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}  # validate model name
+    old_instruction_score_threshold = 0.3  # filter out low-quality old instructions
+
+  if scorer_llm_name == "text-bison":
+    # whether to extract the final answer with a second prompt
+    # (e.g. pulling the key content out of a verbose response)
+    extract_final_answer_by_prompting_again = False
+    include_qa = False  # whether to include QA pairs in the meta-prompt
+    evaluate_in_parallel = False  # whether to evaluate in parallel
+  elif scorer_llm_name == "local":
+    extract_final_answer_by_prompting_again = True
+    include_qa = True
+    evaluate_in_parallel = True
+  else:
+    assert scorer_llm_name in {"gpt-3.5-turbo", "gpt-4"}
+    extract_final_answer_by_prompting_again = False
+    include_qa = False
+    evaluate_in_parallel = False
+
+  optimizer_llm_temperature = optimizer_llm_dict["temperature"]
+
+  # number of few-shot QA exemplars referenced in each instruction-refinement step
+  num_few_shot_questions_for_instruction_refinement = 3
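+
+  # Editor's sketch: old_instruction_score_threshold acts as a filter over the
+  # instruction history before it enters the meta-prompt, roughly:
+  #
+  #     kept = [(ins, s) for ins, s in history if s >= old_instruction_score_threshold]
+  #
+  # (the real filtering happens inside opt_utils.run_evolution).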
+
+  # To change the number of generated instructions in each step, one should
+  # edit the value of the variable below, instead of editing the number of
+  # decodes in model parameters, because those values are limited by model
+  # serving configs.
+  # number of candidate instructions generated in each search step
+  num_generated_instructions_in_each_step = 3
+  num_search_steps = 50  # total number of optimization iterations
+
+  initial_instructions = [
+      "Let's solve the problem.",
+      # "",
+      # "The answer is",
+  ]
+  # whether to use few-shot exemplars to guide instruction generation
+  few_shot_qa_pairs = True
+  # one of {'accumulative_most_frequent', 'current_most_frequent', 'random',
+  # 'constant'}; prefer "random" when diversity matters, a most_frequent
+  # variant when stability matters
+  few_shot_selection_criteria = "random"
+  # whether to evaluate generated instructions on the exemplars in meta-prompt;
+  # set True during development to debug instruction quality
+  evaluate_generated_ins_on_few_shot = False
+  # whether to evaluate old instructions on the exemplars in the meta-prompt;
+  # keep False in production runs for speed
+  evaluate_old_ins_on_few_shot = False
+  # every this number of steps, compute the accuracies of current-step
+  # instructions on the validation set
+  eval_interval = 3
+
+  # maximum number of past instructions kept in the meta-prompt
+  max_num_instructions = 20
+  # discretize continuous scores into N buckets (e.g. integers 0-100) so the
+  # optimizer model can read them more easily
+  num_score_buckets = 100
+  # whether to put old instructions and scores before the exemplars in the
+  # meta-prompt
+  meta_prompt_instructions_before_exemplars = True
+
+  # ===================== run prompt optimization ======================
+
+  assert few_shot_selection_criteria in {
+      "accumulative_most_frequent",
+      "current_most_frequent",
+      "random",
+      "constant",
+  }
+  evolution_kwargs = {
+      "num_search_steps": num_search_steps,
+      "old_instruction_score_threshold": old_instruction_score_threshold,
+      "scorer_llm_dict": scorer_llm_dict,
+      "optimizer_llm_dict": optimizer_llm_dict,
+      "extract_final_answer_by_prompting_again": (
+          extract_final_answer_by_prompting_again
+      ),
+      "include_qa": include_qa,
+      "evaluate_in_parallel": evaluate_in_parallel,
+      "tasks_all": tasks_all,
+      "train_ratio": train_ratio,
+      "eval_ratio": eval_ratio,
+      "test_ratio": test_ratio,
+      "train_index": train_index,
+      "eval_index": eval_index,
+      "dataset_name": dataset_name,
+      "task_name": task_name,
+      "num_examples": num_examples,
+      "root_data_folder_path": root_data_folder_path,
+      "optimizer_llm_temperature": optimizer_llm_temperature,
+      # "optimizer_llm_temperature_schedule": (
+      #     optimizer_llm_temperature_schedule
+      # ),
+      # "optimizer_llm_temperature_end": optimizer_llm_temperature_end,
+      "initial_instructions": initial_instructions,
+      "multiple_choice_tasks": multiple_choice_tasks,
+      "raw_data": raw_data,
+      "call_scorer_server_func": call_scorer_server_func,
+      "call_optimizer_server_func": call_optimizer_server_func,
+      "instruction_pos": instruction_pos,
+      "prediction_treat_as_number": prediction_treat_as_number,
+      "prediction_treat_as_bool": prediction_treat_as_bool,
+      "result_by_instruction_folder": result_by_instruction_folder,
+      "few_shot_qa_pairs": few_shot_qa_pairs,
+      "num_score_buckets": num_score_buckets,
+      "max_num_instructions": max_num_instructions,
+      "meta_prompt_type": meta_prompt_type,
+      "meta_prompt_instructions_before_exemplars": (
+          meta_prompt_instructions_before_exemplars
+      ),
+      "few_shot_selection_criteria": few_shot_selection_criteria,
+      "optimizer_llm_name": optimizer_llm_name,
+      "num_generated_instructions_in_each_step": (
+          num_generated_instructions_in_each_step
+      ),
+      "evaluate_generated_ins_on_few_shot": evaluate_generated_ins_on_few_shot,
+      "num_few_shot_questions_for_instruction_refinement": (
+          num_few_shot_questions_for_instruction_refinement
+      ),
+      "evaluate_old_ins_on_few_shot": evaluate_old_ins_on_few_shot,
+      "eval_interval": eval_interval,
+      "save_folder": save_folder,
+  }
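+
+  # Editor's sketch: num_score_buckets above discretizes accuracies before
+  # they are shown in the meta-prompt; assuming a simple rounding scheme:
+  #
+  #     def bucketize(score, num_score_buckets=100):
+  #         return round(score * num_score_buckets)  # e.g. 0.8473 -> 85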
+
+  opt_utils.run_evolution(**evolution_kwargs)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/outputs/user_feedback.jsonl b/outputs/user_feedback.jsonl
new file mode 100644
index 0000000..e504227
--- /dev/null
+++ b/outputs/user_feedback.jsonl
@@ -0,0 +1,2 @@
+{"session_id": "06e84e4b03c54df188961f4e033f9e77", "round": 1, "choice": "哪些奶茶含有较少的糖分?", "answer": "选择低糖奶茶时,可以通过品牌选择、糖分选项调整以及配料搭配来控制摄入的糖分。以下是一些具体建议和推荐:\n\n---\n\n### **一、低糖奶茶品牌及选项**\n1. **喜茶(HEYTEA)** \n - **推荐系列**:轻乳茶(如「轻乳茉莉奶绿」)、部分限定款(如「芝士奶盖茶」可选低糖)。 \n - **糖分控制**:默认糖度为50%,可选「三分糖」(约10%糖度)或「无糖」(使用代糖)。 \n - **注意**:奶盖本身含糖,建议选择「无奶盖」选项。\n\n2. **奈雪的茶** \n - **推荐系列**:鲜果茶(如「霸气橙子」)、「鲜果轻乳茶」系列。 \n - **糖分控制**:默认糖度为50%,可选「少糖」(约30%)或「无糖」(使用代糖)。 \n - **注意**:鲜果茶的天然果糖含量较高,需结合糖分选项综合判断。\n\n3. **CoCo都可** \n - **推荐系列**:「鲜果茶」(如「百香果鲜果茶」)、「轻乳茶」。 \n - **糖分控制**:默认全糖(约100%),可选「三分糖」(约30%)或「无糖」(代糖)。 \n - **注意**:部分经典奶茶(如「珍珠奶茶」)糖分较高,建议选择鲜果类饮品。\n\n4. **一点点** \n - **推荐系列**:「鲜果茶」(如「百香果鲜果茶」)、「轻乳茶」。 \n - **糖分控制**:默认全糖(约100%),可选「半糖」(50%)或「无糖」(代糖)。 \n - **注意**:部分奶盖茶含糖量较高,建议选择鲜果或轻乳系列。\n\n---\n\n### **二、如何选择低糖奶茶?**\n1. **糖分选项优先级** \n - **无糖/代糖**:使用赤藓糖醇、甜菊糖等代糖,几乎不含热量(但需注意部分人对代糖敏感)。 \n - **三分糖/半糖**:糖分约为全糖的30%-50%,适合控制摄入。 \n - **避免全糖**:全糖奶茶的糖分通常在"}
+{"session_id": "af370656137a4b8f97d8ffac2c60b191", "round": 1, "choice": "瑞幸咖啡有哪些值得推荐的饮品?", "answer": "瑞幸咖啡近年来凭借高性价比和创新口味,推出了许多受欢迎的饮品。以下是一些值得推荐的饮品,涵盖经典款、特色新品和季节限定,适合不同口味和需求:\n\n---\n\n### **一、爆款必试系列**\n1. **生椰拿铁** \n - **推荐理由**:瑞幸的“现象级”产品,椰香浓郁与咖啡的苦香完美融合,口感顺滑,常年占据销量榜首。 \n - **小贴士**:可尝试“生椰拿铁+巧克力糖浆”组合,变成“生椰摩卡”,风味更丰富。\n\n2. **丝绒拿铁** \n - **推荐理由**:以红茶为基底,搭配丝滑牛奶,茶香与奶香平衡,适合喜欢茶咖融合的人群。 \n - **特色**:选用锡兰红茶,口感更醇厚,冷热皆宜。\n\n3. **厚乳拿铁** \n - **推荐理由**:使用厚乳(高乳脂含量的牛奶),奶香更浓郁,适合追求绵密口感的爱好者。\n\n---\n\n### **二、果味与创意系列**\n1. **冰椰拿铁**(夏季限定) \n - **推荐理由**:生椰拿铁的冰饮版本,加入冰块和椰香糖浆,清爽解暑,适合夏天。\n\n2. **蓝莓生椰拿铁** \n - **推荐理由**:在生椰拿铁基础上加入蓝莓糖浆,果香与椰香交织,甜而不腻。\n\n3. **蜜桃生椰拿铁** \n - **推荐理由**:蜜桃风味糖浆与生椰拿1:1搭配,清新果香与咖啡的苦香碰撞,适合喜欢果味的人。\n\n---\n\n### **三、季节限定款**\n1. **桂花拿铁**(秋季限定) \n - **推荐理由**:桂花糖浆与拿铁结合,香气扑鼻,甜度适中,是"}
diff --git a/prompt_utils.py b/prompt_utils.py
new file mode 100644
index 0000000..d60aaf2
--- /dev/null
+++ b/prompt_utils.py
@@ -0,0 +1,220 @@
+# Copyright 2023 The OPRO Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The utility functions for prompting GPT and Google Cloud models.""" +import openai +import time +try: + import google.generativeai as palm +except Exception: + palm = None +try: + from vllm import LLM, SamplingParams +except Exception: + LLM = None + SamplingParams = None +from pathlib import Path + + +# 缓存 vLLM 实例,避免重复加载 +_llm_instance = None + +def get_llm(local_model_path): + if LLM is None: + raise RuntimeError("vLLM not available") + global _llm_instance + if _llm_instance is None: + assert local_model_path is not None, "model_path cannot be None" + local_model_path = str(Path(local_model_path).resolve()) + _llm_instance = LLM( + model=local_model_path, + dtype="bfloat16", + tensor_parallel_size=8, + max_num_batched_tokens=8192, + max_num_seqs=64, + gpu_memory_utilization=0.7, + enforce_eager=True, + block_size=16, + enable_chunked_prefill=True, + trust_remote_code=True, + ) + return _llm_instance + +def call_local_server_single_prompt(prompt, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs): + """ + 使用本地 vLLM 模型生成单个 prompt 的响应,替代原本 OpenAI API。 + """ + llm = get_llm(local_model_path) + sampling_params = SamplingParams( + temperature=temperature, + top_p=0.9, + max_tokens=max_decode_steps, + skip_special_tokens=True # 避免特殊字符触发协议错误 + ) + outputs = llm.generate([prompt], sampling_params) + return outputs[0].outputs[0].text + +def call_local_server_func(inputs, local_model_path=None, temperature=0.8, max_decode_steps=512, **kwargs): + """ + 批量处理多个输入 prompt。 + """ + + assert local_model_path is not None, "local_model_path must be provided" + # 强制类型检查 + if isinstance(inputs, bytes): + inputs = inputs.decode('utf-8') + outputs = [] + for input_str in inputs: + output = call_local_server_single_prompt( + input_str, + local_model_path=local_model_path, + temperature=temperature, + max_decode_steps=max_decode_steps + ) + outputs.append(output) + return outputs + + +def call_openai_server_single_prompt( + prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8 +): + """The function to call OpenAI server with an input string.""" + try: + completion = openai.ChatCompletion.create( + model=model, + temperature=temperature, + max_tokens=max_decode_steps, + messages=[ + {"role": "user", "content": prompt}, + ], + ) + return completion.choices[0].message.content +# 函数捕获了 ​​6 类常见异常​​,并在遇到错误时自动重试: + except openai.error.Timeout as e: # API 请求超时 + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"Timeout error occurred. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.RateLimitError as e: #请求频率超限(Rate Limit) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"Rate limit exceeded. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.APIError as e: # API 错误(如服务器错误) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"API error occurred. Retrying in {retry_time} seconds...") + time.sleep(retry_time) + return call_openai_server_single_prompt( + prompt, max_decode_steps=max_decode_steps, temperature=temperature + ) + + except openai.error.APIConnectionError as e: # API 连接错误(如网络问题) + retry_time = e.retry_after if hasattr(e, "retry_after") else 30 + print(f"API connection error occurred. 
+
+def call_openai_server_single_prompt(
+    prompt, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
+):
+    """The function to call OpenAI server with an input string."""
+    # Six common exception classes are caught below; each handler waits and
+    # then retries the call automatically.
+    try:
+        completion = openai.ChatCompletion.create(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_decode_steps,
+            messages=[
+                {"role": "user", "content": prompt},
+            ],
+        )
+        return completion.choices[0].message.content
+    except openai.error.Timeout as e:  # API request timed out
+        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
+        print(f"Timeout error occurred. Retrying in {retry_time} seconds...")
+        time.sleep(retry_time)
+        return call_openai_server_single_prompt(
+            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+    except openai.error.RateLimitError as e:  # request rate limit exceeded
+        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
+        print(f"Rate limit exceeded. Retrying in {retry_time} seconds...")
+        time.sleep(retry_time)
+        return call_openai_server_single_prompt(
+            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+    except openai.error.APIError as e:  # API error (e.g. server-side error)
+        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
+        print(f"API error occurred. Retrying in {retry_time} seconds...")
+        time.sleep(retry_time)
+        return call_openai_server_single_prompt(
+            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+    except openai.error.APIConnectionError as e:  # API connection error (e.g. network issues)
+        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
+        print(f"API connection error occurred. Retrying in {retry_time} seconds...")
+        time.sleep(retry_time)
+        return call_openai_server_single_prompt(
+            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+    except openai.error.ServiceUnavailableError as e:  # service unavailable (e.g. maintenance)
+        retry_time = e.retry_after if hasattr(e, "retry_after") else 30
+        print(f"Service unavailable. Retrying in {retry_time} seconds...")
+        time.sleep(retry_time)
+        return call_openai_server_single_prompt(
+            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+    except OSError as e:  # OS-level connection error (e.g. network interruption)
+        retry_time = 5  # Adjust the retry time as needed
+        print(
+            f"Connection error occurred: {e}. Retrying in {retry_time} seconds..."
+        )
+        time.sleep(retry_time)
+        return call_openai_server_single_prompt(
+            prompt, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+
+def call_openai_server_func(
+    inputs, model="gpt-3.5-turbo", max_decode_steps=20, temperature=0.8
+):
+    """Batch version: call the OpenAI server with a list of input strings and collect one output per prompt."""
+    if isinstance(inputs, str):  # wrap a single string in a list for uniform handling
+        inputs = [inputs]
+    outputs = []
+    for input_str in inputs:
+        output = call_openai_server_single_prompt(
+            input_str,
+            model=model,
+            max_decode_steps=max_decode_steps,
+            temperature=temperature,
+        )
+        outputs.append(output)
+    return outputs
+
+# Call the text-bison model through the Google PaLM Cloud API, with basic
+# error handling and automatic retry.
+def call_palm_server_from_cloud(
+    input_text, model="text-bison-001", max_decode_steps=20, temperature=0.8
+):
+    if palm is None:
+        raise RuntimeError("google.generativeai not available")
+    assert isinstance(input_text, str)
+    assert model == "text-bison-001"
+    all_model_names = [
+        m
+        for m in palm.list_models()
+        if "generateText" in m.supported_generation_methods
+    ]
+    model_name = all_model_names[0].name
+    try:
+        completion = palm.generate_text(
+            model=model_name,
+            prompt=input_text,
+            temperature=temperature,
+            max_output_tokens=max_decode_steps,
+        )
+        output_text = completion.result
+        return [output_text]
+    except Exception:
+        retry_time = 10
+        time.sleep(retry_time)
+        return call_palm_server_from_cloud(
+            input_text, model=model, max_decode_steps=max_decode_steps, temperature=temperature
+        )
+
+def refine_instruction(query: str) -> str:
+    # Build the Chinese rewriting prompt; the prompt text stays in Chinese
+    # because the downstream product is Chinese-facing.
+    return f"""
+你是一个“问题澄清与重写助手”。
+请根据用户的原始问题:
+【{query}】
+生成不少于20条多角度、可直接执行的问题改写,每行一条。
+"""
+
+def refine_instruction_with_history(query: str, rejected_list: list) -> str:
+    # Same as above, but also tells the model which rewrites were rejected.
+    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
+    return f"""
+你是一个“问题澄清与重写助手”。
+原始问题:
+{query}
+
+以下改写已被否定:
+{rejected_text}
+
+请从新的角度重新生成至少20条不同的改写问题,每条单独一行。
+"""
diff --git a/session_state.py b/session_state.py
new file mode 100644
index 0000000..08b78ad
--- /dev/null
+++ b/session_state.py
@@ -0,0 +1,26 @@
+import uuid
+
+# In-memory session store: maps session id -> session dict.
+SESSIONS = {}
+
+def create_session(query: str) -> str:
+    sid = uuid.uuid4().hex
+    SESSIONS[sid] = {
+        "original_query": query,
+        "round": 0,
+        "history_candidates": [],
+        "user_feedback": []
+    }
+    return sid
+
+def get_session(sid: str):
+    return SESSIONS.get(sid)
+
+def update_session_add_candidates(sid: str, candidates: list):
+    s = SESSIONS[sid]
+    s["round"] += 1
+    s["history_candidates"].extend(candidates)
+
+def log_user_choice(sid: str, choice: str):
+    SESSIONS[sid]["user_feedback"].append(
+        {"round": SESSIONS[sid]["round"], "choice": choice}
+    )
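+
+# Editor's sketch: a minimal, runnable round-trip through the in-memory store
+# (demo only; not part of the original module's public surface).
+if __name__ == "__main__":
+    sid = create_session("recommend a milk tea")
+    update_session_add_candidates(sid, ["which milk teas are low in sugar?"])
+    log_user_choice(sid, "which milk teas are low in sugar?")
+    print(get_session(sid))  # round == 1, one feedback entry for round 1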
diff --git a/user_prompt_optimizer.py b/user_prompt_optimizer.py
new file mode 100644
index 0000000..933b929
--- /dev/null
+++ b/user_prompt_optimizer.py
@@ -0,0 +1,52 @@
+import re
+import numpy as np
+from sklearn.cluster import AgglomerativeClustering
+from sklearn.metrics.pairwise import cosine_similarity
+
+from opro.ollama_client import call_qwen
+from opro.xinference_client import embed_texts
+from opro.prompt_utils import refine_instruction, refine_instruction_with_history
+
+def parse_candidates(raw: str) -> list:
+    # Split the raw model output into lines, strip list markers, drop short
+    # fragments, and de-duplicate while preserving order.
+    lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()]
+    cleaned = []
+    for l in lines:
+        l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip()
+        if len(l) >= 6:
+            cleaned.append(l)
+    return list(dict.fromkeys(cleaned))
+
+def cluster_and_select(candidates: list, top_k=5, distance_threshold=0.15):
+    if not candidates:
+        return []
+    if len(candidates) <= top_k:  # nothing to prune; skip the embedding call
+        return candidates
+    vecs = embed_texts(candidates)
+    X = np.array(vecs)
+
+    clustering = AgglomerativeClustering(n_clusters=None,
+                                         distance_threshold=distance_threshold,
+                                         metric="cosine",
+                                         linkage="average")
+    labels = clustering.fit_predict(X)
+
+    # Within each cluster, keep the candidate whose average similarity to the
+    # cluster members is highest (the most "central" phrasing).
+    selected_idx = []
+    for label in sorted(set(labels)):
+        idxs = [i for i, l in enumerate(labels) if l == label]
+        sims = cosine_similarity(X[idxs]).mean(axis=1)
+        rep = idxs[int(np.argmax(sims))]
+        selected_idx.append(rep)
+
+    selected = [candidates[i] for i in sorted(selected_idx)]
+    return selected[:top_k]
+
+def generate_candidates(query: str, rejected=None, top_k=5):
+    rejected = rejected or []
+    if rejected:
+        prompt = refine_instruction_with_history(query, rejected)
+    else:
+        prompt = refine_instruction(query)
+
+    raw = call_qwen(prompt, temperature=0.9, max_tokens=512)
+    all_candidates = parse_candidates(raw)
+    return cluster_and_select(all_candidates, top_k=top_k)
diff --git a/xinference_client.py b/xinference_client.py
new file mode 100644
index 0000000..dc9029f
--- /dev/null
+++ b/xinference_client.py
@@ -0,0 +1,11 @@
+import requests
+from typing import List
+
+XINFERENCE_EMBED_URL = "http://127.0.0.1:9997/models/bge-base-zh/embed"
+
+def embed_texts(texts: List[str]) -> List[List[float]]:
+    payload = {"inputs": texts}
+    resp = requests.post(XINFERENCE_EMBED_URL, json=payload, timeout=30)
+    resp.raise_for_status()
+    data = resp.json()
+    return data.get("embeddings", [])
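+
+# Editor's sketch: expected round-trip, assuming a running Xinference server
+# at the URL above whose response has the shape {"embeddings": [[...], ...]},
+# one vector per input text (shape is an assumption, not verified here):
+#
+#     vecs = embed_texts(["low-sugar milk tea", "sugar-free milk tea"])
+#     assert len(vecs) == 2 and len(vecs[0]) == len(vecs[1])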