285 lines
10 KiB
Python
285 lines
10 KiB
Python
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
import config
|
|
|
|
from .opro.session_state import create_session, get_session, update_session_add_candidates, log_user_choice
|
|
from .opro.session_state import log_user_reject
|
|
from .opro.session_state import set_selected_prompt, log_chat_message
|
|
from .opro.session_state import set_session_model
|
|
from .opro.session_state import USER_FEEDBACK_LOG
|
|
from .opro.user_prompt_optimizer import generate_candidates
|
|
from .opro.ollama_client import call_qwen
|
|
from .opro.ollama_client import list_models
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# Application instance; title/description/version/contact come from the project config module.
app = FastAPI(
    title=config.APP_TITLE,
    description=config.APP_DESCRIPTION,
    version=config.APP_VERSION,
    contact=config.APP_CONTACT,
    # OpenAPI tag groups used by the route decorators below.
    # The trailing comments give English glosses of the (user-facing) descriptions.
    openapi_tags=[
        {"name": tag, "description": text}
        for tag, text in (
            ("health", "健康检查"),  # health check
            ("models", "模型列表与设置"),  # model list and settings
            ("sessions", "会话管理"),  # session management
            ("opro", "提示优化候选生成与选择/拒绝"),  # prompt-optimisation candidates: generate / select / reject
            ("chat", "会话聊天"),  # session chat
            ("ui", "静态页面"),  # static pages
        )
    ],
)

# Fully permissive CORS.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True makes Starlette
# echo the caller's Origin on credentialed requests, i.e. any site may make credentialed
# calls. Narrow the origin list if cookies/auth headers are ever used -- confirm deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Once a session's round counter reaches this value, /next stops generating candidates
# and answers the original query directly.
MAX_ROUNDS = 3
|
|
|
|
def ok(data=None):
    """Wrap *data* in the API's success envelope ``{"success": True, "data": data}``."""
    envelope = {"success": True, "data": data}
    return JSONResponse(envelope)
|
|
|
|
class AppException(HTTPException):
    """HTTP error that also carries a machine-readable application ``error_code``.

    It is rendered by the ``AppException`` handler into the JSON error envelope
    ``{"success": False, "error": ..., "error_code": ...}``.
    """

    def __init__(self, status_code: int, detail: str, error_code: str):
        # Attach the application code first; HTTPException.__init__ does not touch it.
        self.error_code = error_code
        super().__init__(status_code, detail)
|
|
|
|
@app.exception_handler(AppException)
def _app_exc_handler(request: Request, exc: AppException):
    """Render an ``AppException`` as the error envelope, keeping its specific error code."""
    body = {
        "success": False,
        "error": str(exc.detail),
        "error_code": exc.error_code,
    }
    return JSONResponse(content=body, status_code=exc.status_code)
|
|
|
|
# Register on Starlette's base HTTPException rather than FastAPI's subclass. Framework-raised
# errors (unknown-route 404, 405 Method Not Allowed) are Starlette HTTPExceptions; with the
# FastAPI subclass as the key they fell through to FastAPI's default {"detail": ...} handler
# and skipped the envelope. FastAPI's own HTTPException (and AppException, which keeps its
# more specific handler via MRO lookup) are subclasses, so they are still covered.
from starlette.exceptions import HTTPException as StarletteHTTPException


@app.exception_handler(StarletteHTTPException)
def _http_exc_handler(request: Request, exc: StarletteHTTPException):
    """Render any non-application HTTP error as the generic ``HTTP_ERROR`` envelope."""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            "error": str(exc.detail),
            "error_code": "HTTP_ERROR",
        },
        # Preserve headers attached to the exception (e.g. ``Allow`` on a 405).
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
@app.exception_handler(Exception)
def _generic_exc_handler(request: Request, exc: Exception):
    """Catch-all for unexpected errors: answer with an opaque 500 envelope."""
    # Deliberately generic: the exception text is never echoed to the client.
    return JSONResponse(
        content={"success": False, "error": "internal error", "error_code": "INTERNAL_ERROR"},
        status_code=500,
    )
|
|
|
|
class StartReq(BaseModel):
    """Request body for POST /start."""

    # Raw user query that seeds the new session.
    query: str
|
|
|
|
class NextReq(BaseModel):
    """Request body for POST /next."""

    # Identifier returned by /start (or /query, /answer).
    session_id: str
|
|
|
|
class SelectReq(BaseModel):
    """Request body for POST /select."""

    session_id: str
    # Prompt text the user picked; it is sent to the model as-is.
    choice: str
|
|
|
|
class RejectReq(BaseModel):
    """Request body for POST /reject."""

    session_id: str
    # Candidate prompt text being rejected.
    candidate: str
    # Optional free-text explanation; passed through to log_user_reject.
    reason: str | None = None
|
|
|
|
class SetModelReq(BaseModel):
    """Request body for POST /set_model."""

    session_id: str
    # Must be one of the names returned by list_models(), otherwise the request is rejected.
    model_name: str
|
|
|
|
@app.post("/start", tags=["opro"])
def start(req: StartReq):
    """Open a new optimisation session for the query and return its first candidates."""
    session_id = create_session(req.query)
    model = get_session(session_id).get("model_name")
    # A fresh session has no candidate history yet.
    candidates = generate_candidates(req.query, [], model_name=model)
    update_session_add_candidates(session_id, candidates)
    return ok({"session_id": session_id, "round": 0, "candidates": candidates})
|
|
|
|
@app.post("/next", tags=["opro"])
def next_round(req: NextReq):
    """Advance a session by one optimisation round.

    Before the round limit is reached, a new batch of candidates is generated
    from the original query and the candidate history. Once the limit is
    reached, the original query is answered directly and
    ``{"final": true, "answer": ...}`` is returned.
    """
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    if s["round"] >= MAX_ROUNDS:
        # Use the session's chosen model (as /select and /message do); previously the
        # default model was always used here. Map model failures to OLLAMA_ERROR for
        # consistency with the other model-calling endpoints.
        try:
            ans = call_qwen(
                s["original_query"],
                temperature=0.3,
                max_tokens=512,
                model_name=s.get("model_name"),
            )
        except Exception as e:
            raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e
        return ok({"final": True, "answer": ans})

    cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    # NOTE(review): "round" is read after the update; whether it reflects the new round
    # depends on update_session_add_candidates mutating this same dict -- confirm.
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
|
|
|
|
def _append_feedback_record(record: dict) -> None:
    """Best-effort append of one JSON line to ``outputs/user_feedback.jsonl``.

    Failures are logged and swallowed so that a disk or serialisation problem
    never fails the user-facing request.
    """
    import json
    import logging
    import os

    try:
        os.makedirs("outputs", exist_ok=True)
        line = json.dumps(record, ensure_ascii=False) + "\n"
        with open("outputs/user_feedback.jsonl", "a", encoding="utf-8") as f:
            f.write(line)
    except Exception:
        # NOTE(review): USER_FEEDBACK_LOG is imported from session_state but never used;
        # confirm whether it should be the target path instead of the literal above.
        logging.getLogger(__name__).warning("failed to append user feedback record", exc_info=True)


@app.post("/select", tags=["opro"])
def select(req: SelectReq):
    """Adopt the chosen prompt for the session and return the model's answer to it."""
    s = get_session(req.session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    log_user_choice(req.session_id, req.choice)
    set_selected_prompt(req.session_id, req.choice)
    # The chosen prompt is recorded in the chat history with role "system".
    log_chat_message(req.session_id, "system", req.choice)
    try:
        ans = call_qwen(req.choice, temperature=0.2, max_tokens=1024, model_name=s.get("model_name"))
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e
    log_chat_message(req.session_id, "assistant", ans)

    # .get() so a session without "round" cannot abort this best-effort write.
    _append_feedback_record({
        "session_id": req.session_id,
        "round": s.get("round"),
        "choice": req.choice,
        "answer": ans,
    })
    return ok({"prompt": req.choice, "answer": ans})
|
|
|
|
@app.post("/reject", tags=["opro"])
def reject(req: RejectReq):
    """Record a rejected candidate and return a fresh batch of candidates."""
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    log_user_reject(req.session_id, req.candidate, req.reason)

    query_text = session["original_query"]
    # The rejected text is appended to the history handed to the generator,
    # presumably so it is not proposed again -- confirm in generate_candidates.
    previous = session["history_candidates"] + [req.candidate]
    fresh = generate_candidates(query_text, previous, model_name=session.get("model_name"))
    update_session_add_candidates(req.session_id, fresh)
    return ok({"session_id": req.session_id, "round": session["round"], "candidates": fresh})
|
|
class QueryReq(BaseModel):
    """Request body for POST /query."""

    # Raw user query; only used when no session_id is given.
    query: str
    # Existing session to continue; omit to start a new session.
    session_id: str | None = None
|
|
|
|
@app.post("/query", tags=["opro"])
def query(req: QueryReq):
    """Generate prompt candidates, either for a new session or for an existing one."""
    if not req.session_id:
        # New session: like /start, but the query is also logged as a user chat message.
        new_sid = create_session(req.query)
        log_chat_message(new_sid, "user", req.query)
        model = get_session(new_sid).get("model_name")
        fresh = generate_candidates(req.query, [], model_name=model)
        update_session_add_candidates(new_sid, fresh)
        return ok({"session_id": new_sid, "round": 0, "candidates": fresh})

    # Existing session: regenerate from the stored original query.
    # NOTE(review): req.query is ignored on this path -- confirm that is intended.
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    fresh = generate_candidates(
        session["original_query"],
        session["history_candidates"],
        model_name=session.get("model_name"),
    )
    update_session_add_candidates(req.session_id, fresh)
    return ok({"session_id": req.session_id, "round": session["round"], "candidates": fresh})
|
|
# Serve the static frontend under /ui. The relative "frontend" directory is resolved
# against the process's working directory, so the server must be started from the repo root.
app.mount(path="/ui", app=StaticFiles(directory="frontend", html=True), name="static")
|
|
|
|
@app.get("/", tags=["ui"])
def root():
    """Redirect the bare root URL to the static UI."""
    target = "/ui/"
    return RedirectResponse(url=target)
|
|
|
|
@app.get("/health", tags=["health"])
def health():
    """Health check: always reports status "ok" together with the configured app version."""
    payload = {"status": "ok", "version": config.APP_VERSION}
    return ok(payload)
|
|
|
|
@app.get("/version", tags=["health"])
def version():
    """Return the configured application version."""
    return ok(dict(version=config.APP_VERSION))
|
|
|
|
# @app.get("/ui/react", tags=["ui"])
|
|
# def ui_react():
|
|
# return FileResponse("frontend/react/index.html")
|
|
|
|
# @app.get("/ui/offline", tags=["ui"])
|
|
# def ui_offline():
|
|
# return FileResponse("frontend/ui_offline.html")
|
|
|
|
@app.get("/react", tags=["ui"])
def react_root():
    """Serve the React frontend's entry page."""
    # Relative to the process's working directory, like the /ui static mount.
    index_path = "frontend/react/index.html"
    return FileResponse(index_path)
|
|
|
|
@app.get("/sessions", tags=["sessions"])
def sessions():
    """List a short summary (id, round, selected prompt, original query) of every session."""
    from .opro.session_state import SESSIONS

    # Plain-`def` endpoints run concurrently in FastAPI's threadpool, so another request
    # (e.g. /start -> create_session) may insert into SESSIONS while we build the list.
    # Snapshot the items first to avoid "dictionary changed size during iteration".
    snapshot = list(SESSIONS.items())
    return ok({"sessions": [
        {
            "session_id": sid,
            "round": s.get("round", 0),
            "selected_prompt": s.get("selected_prompt"),
            "original_query": s.get("original_query"),
        }
        for sid, s in snapshot
    ]})
|
|
|
|
@app.get("/session/{sid}", tags=["sessions"])
def session_detail(sid: str):
    """Return the full stored state of one session."""
    session = get_session(sid)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    # (response key, session key) pairs; the tuple order fixes the JSON key order.
    field_map = (
        ("round", "round"),
        ("original_query", "original_query"),
        ("selected_prompt", "selected_prompt"),
        ("candidates", "history_candidates"),
        ("user_feedback", "user_feedback"),
        ("rejected", "rejected"),
        ("history", "chat_history"),
    )
    detail = {"session_id": sid}
    detail.update((out_key, session[src_key]) for out_key, src_key in field_map)
    return ok(detail)
|
|
|
|
class MessageReq(BaseModel):
    """Request body for POST /message."""

    session_id: str
    # Follow-up chat message from the user.
    message: str
|
|
|
|
@app.post("/message", tags=["chat"])
def message(req: MessageReq):
    """Send a chat message within a session and return the answer plus the chat history."""
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    log_chat_message(req.session_id, "user", req.message)

    # The model sees the selected prompt (or the original query if none was selected)
    # followed by this message; earlier chat turns are not included.
    prefix = session.get("selected_prompt") or session["original_query"]
    prompt = "\n\n".join((prefix, req.message))

    try:
        reply = call_qwen(prompt, temperature=0.3, max_tokens=1024, model_name=session.get("model_name"))
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")

    log_chat_message(req.session_id, "assistant", reply)
    return ok({"session_id": req.session_id, "answer": reply, "history": session["chat_history"]})
|
|
|
|
class QueryFromMsgReq(BaseModel):
    """Request body for POST /query_from_message."""

    # Session whose most recent user chat message seeds the new candidates.
    session_id: str
|
|
|
|
@app.post("/query_from_message", tags=["opro"])
def query_from_message(req: QueryFromMsgReq):
    """Generate candidates from the session's latest user message (or its original query)."""
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    history = session.get("chat_history", [])
    # Newest non-empty message with role "user", or None if there is none.
    latest_user_text = next(
        (
            entry["content"]
            for entry in reversed(history)
            if entry.get("role") == "user" and entry.get("content")
        ),
        None,
    )
    seed = latest_user_text or session["original_query"]

    fresh = generate_candidates(seed, session["history_candidates"], model_name=session.get("model_name"))
    update_session_add_candidates(req.session_id, fresh)
    return ok({"session_id": req.session_id, "round": session["round"], "candidates": fresh})
|
|
|
|
class AnswerReq(BaseModel):
    """Request body for POST /answer."""

    # Raw user query; answered directly and also used to seed prompt candidates.
    query: str
|
|
|
|
@app.post("/answer", tags=["opro"])
def answer(req: AnswerReq):
    """Answer a query directly in a new session and also return prompt candidates for it."""
    sid = create_session(req.query)
    # Use the session's model for both model calls, consistent with /start; previously
    # both calls here fell back to their own defaults.
    model_name = get_session(sid).get("model_name")
    log_chat_message(sid, "user", req.query)
    try:
        ans = call_qwen(req.query, temperature=0.2, max_tokens=1024, model_name=model_name)
    except Exception as e:
        # Same error mapping as /select and /message instead of a generic 500.
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e
    log_chat_message(sid, "assistant", ans)
    cands = generate_candidates(req.query, [], model_name=model_name)
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "answer": ans, "candidates": cands})
|
|
|
|
@app.get("/models", tags=["models"])
def models():
    """Return the model names reported by the Ollama client."""
    available = list_models()
    return ok({"models": available})
|
|
|
|
@app.post("/set_model", tags=["models"])
def set_model(req: SetModelReq):
    """Switch a session to another model, provided the model is currently available."""
    session = get_session(req.session_id)
    if not session:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")

    # list_models() may return a falsy value; treat that as "no models available".
    known_models = frozenset(list_models() or [])
    if req.model_name not in known_models:
        raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE")

    set_session_model(req.session_id, req.model_name)
    return ok({"session_id": req.session_id, "model_name": req.model_name})
|