Original code
BIN	_qwen_xinference_demo/__pycache__/api.cpython-310.pyc (new file, binary not shown)
BIN	_qwen_xinference_demo/__pycache__/api.cpython-313.pyc (new file, binary not shown)
284	_qwen_xinference_demo/api.py (new file)
@@ -0,0 +1,284 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import config

from .opro.session_state import create_session, get_session, update_session_add_candidates, log_user_choice
from .opro.session_state import log_user_reject
from .opro.session_state import set_selected_prompt, log_chat_message
from .opro.session_state import set_session_model
from .opro.session_state import USER_FEEDBACK_LOG
from .opro.user_prompt_optimizer import generate_candidates
from .opro.ollama_client import call_qwen
from .opro.ollama_client import list_models

from fastapi.middleware.cors import CORSMiddleware

app = FastAPI(
    title=config.APP_TITLE,
    description=config.APP_DESCRIPTION,
    version=config.APP_VERSION,
    contact=config.APP_CONTACT,
    openapi_tags=[
        {"name": "health", "description": "Health checks"},
        {"name": "models", "description": "Model listing and selection"},
        {"name": "sessions", "description": "Session management"},
        {"name": "opro", "description": "Prompt-optimization candidate generation and select/reject"},
        {"name": "chat", "description": "Session chat"},
        {"name": "ui", "description": "Static pages"}
    ]
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Maximum candidate-generation rounds before /next returns a final answer
MAX_ROUNDS = 3

def ok(data=None):
    # Uniform success envelope: {"success": true, "data": ...}
    return JSONResponse({"success": True, "data": data})

class AppException(HTTPException):
    def __init__(self, status_code: int, detail: str, error_code: str):
        super().__init__(status_code=status_code, detail=detail)
        self.error_code = error_code

@app.exception_handler(AppException)
def _app_exc_handler(request: Request, exc: AppException):
    return JSONResponse(status_code=exc.status_code, content={
        "success": False,
        "error": str(exc.detail),
        "error_code": exc.error_code
    })

@app.exception_handler(HTTPException)
def _http_exc_handler(request: Request, exc: HTTPException):
    return JSONResponse(status_code=exc.status_code, content={
        "success": False,
        "error": str(exc.detail),
        "error_code": "HTTP_ERROR"
    })

@app.exception_handler(Exception)
def _generic_exc_handler(request: Request, exc: Exception):
    return JSONResponse(status_code=500, content={
        "success": False,
        "error": "internal error",
        "error_code": "INTERNAL_ERROR"
    })

class StartReq(BaseModel):
    query: str

class NextReq(BaseModel):
    session_id: str

class SelectReq(BaseModel):
    session_id: str
    choice: str

class RejectReq(BaseModel):
    session_id: str
    candidate: str
    reason: str | None = None

class SetModelReq(BaseModel):
    session_id: str
    model_name: str

@app.post("/start", tags=["opro"])
def start(req: StartReq):
    sid = create_session(req.query)
    cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "round": 0, "candidates": cands})

@app.post("/next", tags=["opro"])
def next_round(req: NextReq):
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    if s["round"] >= MAX_ROUNDS:
        # Note: answers with the default model, not the session's model_name
        ans = call_qwen(s["original_query"], temperature=0.3, max_tokens=512)
        return ok({"final": True, "answer": ans})

    cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})

@app.post("/select", tags=["opro"])
def select(req: SelectReq):
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    log_user_choice(req.session_id, req.choice)
    set_selected_prompt(req.session_id, req.choice)
    log_chat_message(req.session_id, "system", req.choice)
    try:
        ans = call_qwen(req.choice, temperature=0.2, max_tokens=1024, model_name=s.get("model_name"))
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
    log_chat_message(req.session_id, "assistant", ans)
    # Best-effort append to the on-disk feedback log; failures are non-fatal
    try:
        import os, json
        os.makedirs("outputs", exist_ok=True)
        with open("outputs/user_feedback.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps({
                "session_id": req.session_id,
                "round": s["round"],
                "choice": req.choice,
                "answer": ans
            }, ensure_ascii=False) + "\n")
    except Exception:
        pass
    return ok({"prompt": req.choice, "answer": ans})

@app.post("/reject", tags=["opro"])
def reject(req: RejectReq):
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    log_user_reject(req.session_id, req.candidate, req.reason)
    cands = generate_candidates(s["original_query"], s["history_candidates"] + [req.candidate], model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})

class QueryReq(BaseModel):
    query: str
    session_id: str | None = None

@app.post("/query", tags=["opro"])
def query(req: QueryReq):
    if req.session_id:
        s = get_session(req.session_id)
        if not s:
            raise AppException(404, "session not found", "SESSION_NOT_FOUND")
        cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
        update_session_add_candidates(req.session_id, cands)
        return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
    else:
        sid = create_session(req.query)
        log_chat_message(sid, "user", req.query)
        cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
        update_session_add_candidates(sid, cands)
        return ok({"session_id": sid, "round": 0, "candidates": cands})

app.mount("/ui", StaticFiles(directory="frontend", html=True), name="static")

@app.get("/", tags=["ui"])
def root():
    return RedirectResponse(url="/ui/")

@app.get("/health", tags=["health"])
def health():
    return ok({"status": "ok", "version": config.APP_VERSION})

@app.get("/version", tags=["health"])
def version():
    return ok({"version": config.APP_VERSION})

# @app.get("/ui/react", tags=["ui"])
# def ui_react():
#     return FileResponse("frontend/react/index.html")

# @app.get("/ui/offline", tags=["ui"])
# def ui_offline():
#     return FileResponse("frontend/ui_offline.html")

@app.get("/react", tags=["ui"])
def react_root():
    return FileResponse("frontend/react/index.html")

@app.get("/sessions", tags=["sessions"])
def sessions():
    from .opro.session_state import SESSIONS
    return ok({"sessions": [{
        "session_id": sid,
        "round": s.get("round", 0),
        "selected_prompt": s.get("selected_prompt"),
        "original_query": s.get("original_query")
    } for sid, s in SESSIONS.items()]})

@app.get("/session/{sid}", tags=["sessions"])
def session_detail(sid: str):
    s = get_session(sid)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    return ok({
        "session_id": sid,
        "round": s["round"],
        "original_query": s["original_query"],
        "selected_prompt": s["selected_prompt"],
        "candidates": s["history_candidates"],
        "user_feedback": s["user_feedback"],
        "rejected": s["rejected"],
        "history": s["chat_history"],
    })

class MessageReq(BaseModel):
    session_id: str
    message: str

@app.post("/message", tags=["chat"])
def message(req: MessageReq):
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    log_chat_message(req.session_id, "user", req.message)
    base_prompt = s.get("selected_prompt") or s["original_query"]
    full_prompt = base_prompt + "\n\n" + req.message
    try:
        ans = call_qwen(full_prompt, temperature=0.3, max_tokens=1024, model_name=s.get("model_name"))
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
    log_chat_message(req.session_id, "assistant", ans)
    return ok({"session_id": req.session_id, "answer": ans, "history": s["chat_history"]})

class QueryFromMsgReq(BaseModel):
    session_id: str

@app.post("/query_from_message", tags=["opro"])
def query_from_message(req: QueryFromMsgReq):
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    last_user = None
    for m in reversed(s.get("chat_history", [])):
        if m.get("role") == "user" and m.get("content"):
            last_user = m["content"]
            break
    base = last_user or s["original_query"]
    cands = generate_candidates(base, s["history_candidates"], model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})

class AnswerReq(BaseModel):
    query: str

@app.post("/answer", tags=["opro"])
def answer(req: AnswerReq):
    sid = create_session(req.query)
    log_chat_message(sid, "user", req.query)
    ans = call_qwen(req.query, temperature=0.2, max_tokens=1024)
    log_chat_message(sid, "assistant", ans)
    cands = generate_candidates(req.query, [])
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "answer": ans, "candidates": cands})

@app.get("/models", tags=["models"])
def models():
    return ok({"models": list_models()})

@app.post("/set_model", tags=["models"])
def set_model(req: SetModelReq):
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    avail = set(list_models() or [])
    if req.model_name not in avail:
        raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE")
    set_session_model(req.session_id, req.model_name)
    return ok({"session_id": req.session_id, "model_name": req.model_name})
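Note: a minimal client-side sketch of the intended flow above (/query to get candidates, /select to answer one, /message to continue the chat). The base URL, port, and example query are assumptions for illustration; the commit does not pin a host, and the sketch assumes the model returned at least one candidate.

    # Illustrative client for the endpoints above; host/port are assumptions.
    import requests

    BASE = "http://127.0.0.1:8000"

    r = requests.post(f"{BASE}/query", json={"query": "How do I speed up pandas joins?"}).json()
    sid, cands = r["data"]["session_id"], r["data"]["candidates"]

    # Pick the first rewrite; the server answers it and logs the feedback
    r = requests.post(f"{BASE}/select", json={"session_id": sid, "choice": cands[0]}).json()
    print(r["data"]["answer"])

    # Continue chatting in the same session
    r = requests.post(f"{BASE}/message", json={"session_id": sid, "message": "Give a short code example."}).json()
    print(r["data"]["answer"])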
10 additional binary files not shown.
52	_qwen_xinference_demo/opro/ollama_client.py (new file)
@@ -0,0 +1,52 @@
import requests
import re
import config

OLLAMA_URL = config.OLLAMA_GENERATE_URL
TAGS_URL = config.OLLAMA_TAGS_URL
MODEL_NAME = config.DEFAULT_CHAT_MODEL

def call_qwen(prompt: str, temperature: float = 0.8, max_tokens: int = 512, model_name: str | None = None) -> str:
    def _payload(m: str):
        return {
            "model": m,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens
            }
        }
    primary = model_name or MODEL_NAME
    try:
        resp = requests.post(OLLAMA_URL, json=_payload(primary), timeout=60)
        resp.raise_for_status()
        data = resp.json()
        return data.get("response", "") or data.get("text", "")
    except requests.HTTPError:
        # Fall back to the default model when the user-selected model fails
        if model_name and model_name != MODEL_NAME:
            try:
                resp = requests.post(OLLAMA_URL, json=_payload(MODEL_NAME), timeout=60)
                resp.raise_for_status()
                data = resp.json()
                return data.get("response", "") or data.get("text", "")
            except Exception:
                pass
        raise

def list_models() -> list[str]:
    try:
        r = requests.get(TAGS_URL, timeout=10)
        r.raise_for_status()
        data = r.json() or {}
        items = data.get("models") or []
        names = []
        for m in items:
            name = m.get("name") or m.get("model")
            if name:
                names.append(name)
        # Drop embedding/reranker models, which cannot serve chat requests
        names = [n for n in names if not re.search(r"embedding|rerank|reranker|bge", n, re.I)]
        return names
    except Exception:
        return [MODEL_NAME]
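For context, call_qwen targets Ollama's non-streaming /api/generate endpoint. Below, a standalone sketch of the same request shape, assuming a local Ollama on its default port 11434 and an already-pulled model; the model tag is illustrative, not part of the commit.

    # Same payload shape as _payload() above; model tag and port are assumptions.
    import requests

    payload = {
        "model": "qwen2.5:7b",
        "prompt": "Reply with one word: ready?",
        "stream": False,
        "options": {"temperature": 0.2, "num_predict": 32},
    }
    resp = requests.post("http://127.0.0.1:11434/api/generate", json=payload, timeout=60)
    resp.raise_for_status()
    print(resp.json().get("response", ""))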
20	_qwen_xinference_demo/opro/prompt_utils.py (new file)
@@ -0,0 +1,20 @@
def refine_instruction(query: str) -> str:
    return f"""
You are a "question clarification and rewriting assistant".
Based on the user's original question:
[{query}]
generate at least 20 directly usable rewrites of the question from multiple angles, one per line.
"""

def refine_instruction_with_history(query: str, rejected_list: list) -> str:
    rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
    return f"""
You are a "question clarification and rewriting assistant".
Original question:
{query}

The following rewrites have already been rejected:
{rejected_text}

Please generate at least 20 different rewrites from new angles, one per line.
"""
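A quick way to eyeball the templates, assuming the package imports under the path shown (the query and rejected item are illustrative):

    from _qwen_xinference_demo.opro.prompt_utils import refine_instruction_with_history

    # Prints the rewriting prompt with the rejected list rendered as "- ..." bullets
    print(refine_instruction_with_history(
        "How do I tune batch size?",
        ["Is a bigger batch always better?"],
    ))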
56	_qwen_xinference_demo/opro/session_state.py (new file)
@@ -0,0 +1,56 @@
import uuid

SESSIONS = {}
USER_FEEDBACK_LOG = []

def create_session(query: str) -> str:
    sid = uuid.uuid4().hex
    SESSIONS[sid] = {
        "original_query": query,
        "round": 0,
        "history_candidates": [],
        "user_feedback": [],
        "rejected": [],
        "selected_prompt": None,
        "chat_history": [],
        "model_name": None
    }
    return sid

def get_session(sid: str):
    return SESSIONS.get(sid)

def update_session_add_candidates(sid: str, candidates: list):
    s = SESSIONS[sid]
    s["round"] += 1
    s["history_candidates"].extend(candidates)

def log_user_choice(sid: str, choice: str):
    SESSIONS[sid]["user_feedback"].append(
        {"round": SESSIONS[sid]["round"], "choice": choice}
    )
    USER_FEEDBACK_LOG.append({
        "session_id": sid,
        "round": SESSIONS[sid]["round"],
        "choice": choice
    })

def log_user_reject(sid: str, candidate: str, reason: str | None = None):
    SESSIONS[sid]["rejected"].append(candidate)
    USER_FEEDBACK_LOG.append({
        "session_id": sid,
        "round": SESSIONS[sid]["round"],
        "reject": candidate,
        "reason": reason or ""
    })

def set_selected_prompt(sid: str, prompt: str):
    SESSIONS[sid]["selected_prompt"] = prompt

def log_chat_message(sid: str, role: str, content: str):
    SESSIONS[sid]["chat_history"].append({"role": role, "content": content})

def set_session_model(sid: str, model_name: str | None):
    s = SESSIONS.get(sid)
    if s is not None:
        s["model_name"] = model_name
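Worth noting: SESSIONS and USER_FEEDBACK_LOG are plain module-level containers, so state is per-process, lost on restart, and not shared across multiple uvicorn workers. A minimal lifecycle sketch (values illustrative):

    from _qwen_xinference_demo.opro import session_state as st

    sid = st.create_session("original question")
    st.update_session_add_candidates(sid, ["rewrite A", "rewrite B"])  # increments round to 1
    st.log_user_choice(sid, "rewrite A")
    st.set_selected_prompt(sid, "rewrite A")

    s = st.get_session(sid)
    assert s["round"] == 1 and s["selected_prompt"] == "rewrite A"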
55	_qwen_xinference_demo/opro/user_prompt_optimizer.py (new file)
@@ -0,0 +1,55 @@
import re
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
import config

from .ollama_client import call_qwen
from .xinference_client import embed_texts
from .prompt_utils import refine_instruction, refine_instruction_with_history

def parse_candidates(raw: str) -> list:
    lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()]
    cleaned = []
    for l in lines:
        # Strip leading list markers such as "-", "*", "1." or "2)"
        l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip()
        if len(l) >= 6:
            cleaned.append(l)
    # Deduplicate while preserving order
    return list(dict.fromkeys(cleaned))

def cluster_and_select(candidates: list, top_k=config.TOP_K, distance_threshold=config.CLUSTER_DISTANCE_THRESHOLD):
    if not candidates:
        return []
    if len(candidates) <= top_k:
        return candidates
    vecs = embed_texts(candidates)
    if not vecs or len(vecs) != len(candidates):
        # No usable embeddings: degrade to plain truncation
        return candidates[:top_k]
    X = np.array(vecs)

    # "metric" (rather than "affinity") requires scikit-learn >= 1.2
    clustering = AgglomerativeClustering(n_clusters=None,
                                         distance_threshold=distance_threshold,
                                         metric="cosine",
                                         linkage="average")
    labels = clustering.fit_predict(X)

    selected_idx = []
    for label in sorted(set(labels)):
        idxs = [i for i, l in enumerate(labels) if l == label]
        # Keep the member with the highest mean similarity to its cluster
        sims = cosine_similarity(X[idxs]).mean(axis=1)
        rep = idxs[int(np.argmax(sims))]
        selected_idx.append(rep)

    selected = [candidates[i] for i in sorted(selected_idx)]
    return selected[:top_k]

def generate_candidates(query: str, rejected=None, top_k=config.TOP_K, model_name=None):
    rejected = rejected or []
    if rejected:
        prompt = refine_instruction_with_history(query, rejected)
    else:
        prompt = refine_instruction(query)

    raw = call_qwen(prompt, temperature=0.9, max_tokens=1024, model_name=model_name)
    all_candidates = parse_candidates(raw)
    return cluster_and_select(all_candidates, top_k=top_k)
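A small illustration of parse_candidates' marker stripping and order-preserving dedup (inputs illustrative); cluster_and_select then keeps one representative per embedding cluster:

    raw = """1. How can I speed up pandas merges?
    2) How do I make DataFrame joins faster?
    - How can I speed up pandas merges?
    """
    print(parse_candidates(raw))
    # -> ['How can I speed up pandas merges?', 'How do I make DataFrame joins faster?']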
29	_qwen_xinference_demo/opro/xinference_client.py (new file)
@@ -0,0 +1,29 @@
import requests
from typing import List
import config

XINFERENCE_EMBED_URL = config.XINFERENCE_EMBED_URL
OLLAMA_EMBED_URL = config.OLLAMA_HOST + "/api/embeddings"

def embed_texts(texts: List[str]) -> List[List[float]]:
    payload = {"inputs": texts}
    try:
        resp = requests.post(XINFERENCE_EMBED_URL, json=payload, timeout=10)
        resp.raise_for_status()
        data = resp.json()
        embs = data.get("embeddings", [])
        if embs:
            return embs
    except Exception:
        pass

    # Fall back to the Ollama embeddings endpoint
    try:
        payload2 = {"model": config.DEFAULT_EMBED_MODEL, "input": texts}
        resp2 = requests.post(OLLAMA_EMBED_URL, json=payload2, timeout=15)
        resp2.raise_for_status()
        data2 = resp2.json()
        if isinstance(data2, dict) and "data" in data2:
            return [item.get("embedding", []) for item in data2["data"]]
        return data2.get("embeddings", [])
    except Exception:
        return []
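Failure semantics matter here: if both the Xinference and Ollama embedding backends are unreachable, embed_texts returns an empty list, and cluster_and_select above degrades to a plain candidates[:top_k] truncation. A quick check (texts illustrative):

    vecs = embed_texts(["first sentence", "second sentence"])
    if vecs:
        print(len(vecs), "vectors of dimension", len(vecs[0]))
    else:
        print("no embedding backend available; clustering will be skipped")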