Files
opro_demo/_qwen_xinference_demo/api.py

285 lines
10 KiB
Python
Raw Normal View History

2025-12-05 07:11:25 +00:00
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import config
from .opro.session_state import create_session, get_session, update_session_add_candidates, log_user_choice
from .opro.session_state import log_user_reject
from .opro.session_state import set_selected_prompt, log_chat_message
from .opro.session_state import set_session_model
from .opro.session_state import USER_FEEDBACK_LOG
from .opro.user_prompt_optimizer import generate_candidates
from .opro.ollama_client import call_qwen
from .opro.ollama_client import list_models
from fastapi.middleware.cors import CORSMiddleware
# OpenAPI tag metadata shown on the generated /docs page. The descriptions are
# user-facing text and are kept verbatim; in English they read: health check,
# model list and settings, session management, prompt-optimization candidate
# generation and select/reject, session chat, static pages.
_OPENAPI_TAGS = [
    {"name": "health", "description": "健康检查"},
    {"name": "models", "description": "模型列表与设置"},
    {"name": "sessions", "description": "会话管理"},
    {"name": "opro", "description": "提示优化候选生成与选择/拒绝"},
    {"name": "chat", "description": "会话聊天"},
    {"name": "ui", "description": "静态页面"},
]

# Application metadata (title, description, version, contact) comes from config.
app = FastAPI(
    title=config.APP_TITLE,
    description=config.APP_DESCRIPTION,
    version=config.APP_VERSION,
    contact=config.APP_CONTACT,
    openapi_tags=_OPENAPI_TAGS,
)
# CORS: accept requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True means
# Starlette's CORSMiddleware reflects the caller's Origin on credentialed
# requests, i.e. any site may make such calls. Likely fine for a local demo;
# confirm before exposing this service more widely.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Once a session's round counter reaches this value, /next stops generating
# candidates and answers the original query directly.
MAX_ROUNDS: int = 3
def ok(data=None):
    """Wrap *data* in the API's standard success envelope.

    Returns a ``JSONResponse`` whose body is ``{"success": True, "data": data}``.
    """
    envelope = {"success": True, "data": data}
    return JSONResponse(envelope)
class AppException(HTTPException):
    """HTTP error that also carries a machine-readable ``error_code``.

    Rendered by the ``AppException`` handler registered on ``app`` as
    ``{"success": False, "error": detail, "error_code": error_code}``.
    """

    def __init__(self, status_code: int, detail: str, error_code: str):
        # Stable identifier clients can branch on, e.g. "SESSION_NOT_FOUND".
        self.error_code = error_code
        super().__init__(status_code=status_code, detail=detail)
@app.exception_handler(AppException)
def _app_exc_handler(request: Request, exc: AppException):
    """Render an ``AppException`` as the error envelope, keeping its own code."""
    body = {
        "success": False,
        "error": str(exc.detail),
        "error_code": exc.error_code,
    }
    return JSONResponse(content=body, status_code=exc.status_code)
@app.exception_handler(HTTPException)
def _http_exc_handler(request: Request, exc: HTTPException):
    """Render a plain ``HTTPException`` as the standard error envelope.

    The body is ``{"success": False, "error": detail, "error_code": "HTTP_ERROR"}``.
    Any headers attached to the exception (e.g. ``WWW-Authenticate``) are
    forwarded to the response, as FastAPI's default handler does.

    NOTE(review): this is registered for FastAPI's ``HTTPException``. Errors
    raised by Starlette itself (e.g. unknown routes) use Starlette's base class
    and presumably bypass this handler. Confirm whether they should also get the
    envelope.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            "error": str(exc.detail),
            "error_code": "HTTP_ERROR",
        },
        # getattr keeps this safe for exception objects without a .headers attribute.
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
def _generic_exc_handler(request: Request, exc: Exception):
    """Catch-all for unhandled errors: respond 500 with a generic envelope.

    The exception's message is not included in the response body.
    """
    body = dict(success=False, error="internal error", error_code="INTERNAL_ERROR")
    return JSONResponse(content=body, status_code=500)
class StartReq(BaseModel):
    """Body of ``POST /start``."""
    # User query used to create the session and generate the first candidates.
    query: str
class NextReq(BaseModel):
    """Body of ``POST /next``: identifies the session to advance by one round."""
    session_id: str
class SelectReq(BaseModel):
    """Body of ``POST /select``."""
    session_id: str
    # Text of the candidate prompt the user picked; it is sent to the model as-is.
    choice: str
class RejectReq(BaseModel):
    """Body of ``POST /reject``: a candidate the user turned down."""
    session_id: str
    # Text of the rejected candidate; it is appended to the history passed to
    # the candidate generator when producing the replacement batch.
    candidate: str
    # Optional free-text explanation, forwarded to ``log_user_reject``.
    reason: str | None = None
class SetModelReq(BaseModel):
    """Body of ``POST /set_model``."""
    session_id: str
    # Must be one of the entries returned by ``list_models()``; otherwise the
    # endpoint responds 400 ``MODEL_NOT_AVAILABLE``.
    model_name: str
@app.post("/start", tags=["opro"])
def start(req: StartReq):
    """Create a session for ``req.query`` and return its first candidate batch.

    Response data: ``session_id``, ``round`` (always 0) and ``candidates``.
    """
    session_id = create_session(req.query)
    model_name = get_session(session_id).get("model_name")
    candidates = generate_candidates(req.query, [], model_name=model_name)
    update_session_add_candidates(session_id, candidates)
    return ok({"session_id": session_id, "round": 0, "candidates": candidates})
@app.post("/next", tags=["opro"])
def next_round(req: NextReq):
    """Advance a session by one optimization round.

    When the session's ``round`` has reached ``MAX_ROUNDS``, no further
    candidates are generated. Instead, the original query is answered directly
    with the session's model and the response data is
    ``{"final": True, "answer": ...}``.

    Otherwise, a new candidate batch is generated from the original query and
    the session's candidate history, then stored on the session.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session, or
            400 ``OLLAMA_ERROR`` if the final model call fails.
    """
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    model_name = s.get("model_name")
    if s["round"] >= MAX_ROUNDS:
        # Use the session's model, as every other endpoint does, so that a
        # model chosen via /set_model is also used for the final answer.
        try:
            ans = call_qwen(
                s["original_query"],
                temperature=0.3,
                max_tokens=512,
                model_name=model_name,
            )
        except Exception as e:
            raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e
        return ok({"final": True, "answer": ans})
    cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=model_name)
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
@app.post("/select", tags=["opro"])
def select(req: SelectReq):
    """Accept a candidate prompt, answer it with the session's model, and log it.

    The chosen prompt is recorded as the user's choice, stored as the session's
    selected prompt, and logged as a ``system`` chat message. The model's
    answer is logged as an ``assistant`` message. A feedback record is then
    appended to ``outputs/user_feedback.jsonl`` on a best-effort basis: a
    failed write is logged as a warning and does not fail the request.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session, or
            400 ``OLLAMA_ERROR`` when the model call fails (the selection has
            already been recorded by then).
    """
    import json
    import logging
    import os

    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    log_user_choice(req.session_id, req.choice)
    set_selected_prompt(req.session_id, req.choice)
    log_chat_message(req.session_id, "system", req.choice)
    try:
        ans = call_qwen(req.choice, temperature=0.2, max_tokens=1024, model_name=s.get("model_name"))
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e
    log_chat_message(req.session_id, "assistant", ans)
    # Best-effort feedback log: a failure must not fail the request, but it
    # should be visible in the server logs rather than silently discarded.
    try:
        record = {
            "session_id": req.session_id,
            "round": s["round"],
            "choice": req.choice,
            "answer": ans,
        }
        os.makedirs("outputs", exist_ok=True)
        with open("outputs/user_feedback.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception:
        logging.getLogger(__name__).warning(
            "failed to append feedback record for session %s",
            req.session_id,
            exc_info=True,
        )
    return ok({"prompt": req.choice, "answer": ans})
@app.post("/reject", tags=["opro"])
def reject(req: RejectReq):
    """Record a rejected candidate and return a freshly generated batch.

    The rejection (with its optional reason) is logged. The rejected text is
    then appended to the session's candidate history before that history is
    passed to the generator. The new batch is stored on the session.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session.
    """
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    log_user_reject(req.session_id, req.candidate, req.reason)
    history = session["history_candidates"] + [req.candidate]
    fresh = generate_candidates(
        session["original_query"],
        history,
        model_name=session.get("model_name"),
    )
    update_session_add_candidates(req.session_id, fresh)
    payload = {"session_id": req.session_id, "round": session["round"], "candidates": fresh}
    return ok(payload)
class QueryReq(BaseModel):
    """Body of ``POST /query``."""
    # Query text; used only when no session_id is given (a new session is
    # created from it).
    query: str
    # Existing session to generate more candidates for; omit to start a new session.
    session_id: str | None = None
@app.post("/query", tags=["opro"])
def query(req: QueryReq):
    """Generate prompt candidates for a new query or for an existing session.

    Without a (truthy) ``session_id``, a new session is created from
    ``req.query``. The query is logged as a ``user`` chat message and round-0
    candidates are returned.

    With ``session_id``, another batch is generated for that session's original
    query and candidate history.

    NOTE(review): ``req.query`` is ignored in the session branch. Confirm this
    is intended.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session.
    """
    if not req.session_id:
        new_sid = create_session(req.query)
        log_chat_message(new_sid, "user", req.query)
        first_batch = generate_candidates(req.query, [], model_name=get_session(new_sid).get("model_name"))
        update_session_add_candidates(new_sid, first_batch)
        return ok({"session_id": new_sid, "round": 0, "candidates": first_batch})

    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    batch = generate_candidates(
        session["original_query"],
        session["history_candidates"],
        model_name=session.get("model_name"),
    )
    update_session_add_candidates(req.session_id, batch)
    return ok({"session_id": req.session_id, "round": session["round"], "candidates": batch})
# Serve the static frontend from the "frontend" directory under /ui.
# html=True makes directory requests serve index.html. The relative
# directory path resolves against the process's working directory.
app.mount("/ui", app=StaticFiles(directory="frontend", html=True), name="static")
@app.get("/", tags=["ui"])
def root():
    """Redirect the bare root path to the static UI at ``/ui/``."""
    ui_entry = "/ui/"
    return RedirectResponse(url=ui_entry)
@app.get("/health", tags=["health"])
def health():
    """Health check: always reports ``status: "ok"`` plus the app version."""
    status = {"status": "ok", "version": config.APP_VERSION}
    return ok(status)
@app.get("/version", tags=["health"])
def version():
    """Return the application version taken from ``config.APP_VERSION``."""
    return ok(dict(version=config.APP_VERSION))
# No separate /ui/react or /ui/offline routes are defined here. Any path under
# /ui is matched by the StaticFiles mount, which is registered before these
# routes, so such routes would be shadowed. A file like
# frontend/ui_offline.html is reachable through that mount as
# /ui/ui_offline.html. The React build's entry page is served at /react.
@app.get("/react", tags=["ui"])
def react_root():
    """Serve the React frontend's entry page (``frontend/react/index.html``)."""
    index_html = "frontend/react/index.html"
    return FileResponse(index_html)
@app.get("/sessions", tags=["sessions"])
def sessions():
    """List every session in ``session_state.SESSIONS`` as a short summary.

    Each summary has ``session_id``, ``round`` (0 if absent),
    ``selected_prompt`` and ``original_query``.
    """
    from .opro.session_state import SESSIONS

    def _summarize(session_id, state):
        return {
            "session_id": session_id,
            "round": state.get("round", 0),
            "selected_prompt": state.get("selected_prompt"),
            "original_query": state.get("original_query"),
        }

    summaries = [_summarize(key, value) for key, value in SESSIONS.items()]
    return ok({"sessions": summaries})
@app.get("/session/{sid}", tags=["sessions"])
def session_detail(sid: str):
    """Return the full recorded state of a single session.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` if ``sid`` is unknown.
    """
    s = get_session(sid)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    # Response key -> session-state key. Some are renamed for the API, and the
    # order here fixes the key order of the response.
    exposed = (
        ("round", "round"),
        ("original_query", "original_query"),
        ("selected_prompt", "selected_prompt"),
        ("candidates", "history_candidates"),
        ("user_feedback", "user_feedback"),
        ("rejected", "rejected"),
        ("history", "chat_history"),
    )
    detail = {"session_id": sid}
    for out_key, state_key in exposed:
        detail[out_key] = s[state_key]
    return ok(detail)
class MessageReq(BaseModel):
    """Body of ``POST /message``."""
    session_id: str
    # Chat text; appended after a blank line to the session's base prompt
    # before the model is called.
    message: str
@app.post("/message", tags=["chat"])
def message(req: MessageReq):
    """Send a chat message within a session and return the model's reply.

    The prompt sent to the session's model is the selected prompt (or, if none
    is selected, the original query), a blank line, then the new message.
    Earlier chat turns are not included in the prompt.

    Both the user message and the reply are logged. The response also carries
    the session's full ``chat_history``.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session, or
            400 ``OLLAMA_ERROR`` when the model call fails.
    """
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    log_chat_message(req.session_id, "user", req.message)
    context = session.get("selected_prompt") or session["original_query"]
    prompt = "\n\n".join([context, req.message])
    try:
        reply = call_qwen(prompt, temperature=0.3, max_tokens=1024, model_name=session.get("model_name"))
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
    log_chat_message(req.session_id, "assistant", reply)
    return ok({"session_id": req.session_id, "answer": reply, "history": session["chat_history"]})
class QueryFromMsgReq(BaseModel):
    """Body of ``POST /query_from_message``.

    Only the session is needed: the text to optimize is taken from the
    session's most recent non-empty user chat message.
    """
    session_id: str
@app.post("/query_from_message", tags=["opro"])
def query_from_message(req: QueryFromMsgReq):
    """Generate candidates from the session's latest user chat message.

    Scans the chat history from newest to oldest for a ``user`` message with
    non-empty content. If none is found, the session's original query is used.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session.
    """
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    newest_user_text = next(
        (
            entry["content"]
            for entry in reversed(session.get("chat_history", []))
            if entry.get("role") == "user" and entry.get("content")
        ),
        None,
    )
    source_text = newest_user_text or session["original_query"]
    batch = generate_candidates(source_text, session["history_candidates"], model_name=session.get("model_name"))
    update_session_add_candidates(req.session_id, batch)
    return ok({"session_id": req.session_id, "round": session["round"], "candidates": batch})
class AnswerReq(BaseModel):
    """Body of ``POST /answer``."""
    # Query to answer directly; a new session is created for it.
    query: str
@app.post("/answer", tags=["opro"])
def answer(req: AnswerReq):
    """Answer a query directly and also return optimization candidates for it.

    A new session is created for ``req.query``. The query and the model's answer
    are logged as chat messages, and a first candidate batch is generated and
    stored on the session. Like ``/start`` and ``/query``, the new session's
    ``model_name`` is passed to the model and the candidate generator.

    Raises:
        AppException: 400 ``OLLAMA_ERROR`` when the model call fails. This is
            consistent with ``/select`` and ``/message``, instead of surfacing
            as a generic 500.
    """
    sid = create_session(req.query)
    model_name = get_session(sid).get("model_name")
    log_chat_message(sid, "user", req.query)
    try:
        ans = call_qwen(req.query, temperature=0.2, max_tokens=1024, model_name=model_name)
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e
    log_chat_message(sid, "assistant", ans)
    cands = generate_candidates(req.query, [], model_name=model_name)
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "answer": ans, "candidates": cands})
@app.get("/models", tags=["models"])
def models():
    """Return whatever ``list_models()`` reports under the ``models`` key."""
    available = list_models()
    return ok({"models": available})
@app.post("/set_model", tags=["models"])
def set_model(req: SetModelReq):
    """Switch a session to another model after checking that it is available.

    Raises:
        AppException: 404 ``SESSION_NOT_FOUND`` for an unknown session, or
            400 ``MODEL_NOT_AVAILABLE`` if ``req.model_name`` is not among the
            entries returned by ``list_models()``.
    """
    if not get_session(req.session_id):
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    known_models = set(list_models() or [])
    wanted = req.model_name
    if wanted not in known_models:
        raise AppException(400, f"model not available: {wanted}", "MODEL_NOT_AVAILABLE")
    set_session_model(req.session_id, wanted)
    return ok({"session_id": req.session_id, "model_name": wanted})