"""FastAPI application for interactive OPRO prompt optimization.

Exposes endpoints to create sessions, generate / select / reject prompt
candidates, chat with the session's model, choose that model, and serve the
static frontend. Successful JSON responses are wrapped as
``{"success": True, "data": ...}``; errors raised as ``AppException`` or
``HTTPException`` (and any otherwise unhandled exception) are rendered as
``{"success": False, "error": ..., "error_code": ...}``.
"""
import json
import logging
import os

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

import config
from .opro.session_state import (
    create_session,
    get_session,
    update_session_add_candidates,
    log_user_choice,
    log_user_reject,
    set_selected_prompt,
    log_chat_message,
    set_session_model,
    USER_FEEDBACK_LOG,  # noqa: F401 -- unused in this module; kept to preserve the module namespace
)
from .opro.user_prompt_optimizer import generate_candidates
from .opro.ollama_client import call_qwen, list_models
from fastapi.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)

app = FastAPI(
    title=config.APP_TITLE,
    description=config.APP_DESCRIPTION,
    version=config.APP_VERSION,
    contact=config.APP_CONTACT,
    openapi_tags=[
        {"name": "health", "description": "健康检查"},
        {"name": "models", "description": "模型列表与设置"},
        {"name": "sessions", "description": "会话管理"},
        {"name": "opro", "description": "提示优化候选生成与选择/拒绝"},
        {"name": "chat", "description": "会话聊天"},
        {"name": "ui", "description": "静态页面"},
    ],
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; tighten this if the API ever
# starts relying on cookies or other credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Once a session's round counter reaches this value, /next stops generating
# candidates and answers the original query directly.
MAX_ROUNDS = 3

# Append-only JSON Lines file that /select writes one record to per choice.
FEEDBACK_OUTPUT_DIR = "outputs"
FEEDBACK_OUTPUT_FILE = os.path.join(FEEDBACK_OUTPUT_DIR, "user_feedback.jsonl")


def ok(data=None):
    """Wrap *data* in the standard success envelope."""
    return JSONResponse({"success": True, "data": data})


class AppException(HTTPException):
    """HTTPException carrying a machine-readable ``error_code`` for clients."""

    def __init__(self, status_code: int, detail: str, error_code: str):
        super().__init__(status_code=status_code, detail=detail)
        self.error_code = error_code


@app.exception_handler(AppException)
def _app_exc_handler(request: Request, exc: AppException):
    """Render an ``AppException`` as the error envelope with its own code."""
    return JSONResponse(status_code=exc.status_code, content={
        "success": False,
        "error": str(exc.detail),
        "error_code": exc.error_code,
    })


@app.exception_handler(HTTPException)
def _http_exc_handler(request: Request, exc: HTTPException):
    """Render a plain ``HTTPException`` as the error envelope.

    Any headers attached to the exception (e.g. ``WWW-Authenticate``) are
    forwarded instead of being dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            "error": str(exc.detail),
            "error_code": "HTTP_ERROR",
        },
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
def _generic_exc_handler(request: Request, exc: Exception):
    """Hide details of unexpected errors behind a generic 500 envelope."""
    return JSONResponse(status_code=500, content={
        "success": False,
        "error": "internal error",
        "error_code": "INTERNAL_ERROR",
    })


def _require_session(session_id: str):
    """Return the session for *session_id* or raise a 404 ``AppException``."""
    s = get_session(session_id)
    if not s:
        raise AppException(404, "session not found", "SESSION_NOT_FOUND")
    return s


def _ask_model(prompt: str, *, temperature: float, max_tokens: int, model_name=None):
    """Call the model via ``call_qwen``, mapping any failure to ``OLLAMA_ERROR``.

    Raises:
        AppException: status 400 with code ``OLLAMA_ERROR`` if the call fails.
    """
    try:
        return call_qwen(prompt, temperature=temperature, max_tokens=max_tokens,
                         model_name=model_name)
    except Exception as e:
        raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR") from e


def _append_feedback_record(record: dict) -> None:
    """Best-effort append of *record* as one JSON line to the feedback file.

    Failures are logged and swallowed: a broken feedback log must not fail the
    user-facing request that produced the record.
    """
    try:
        os.makedirs(FEEDBACK_OUTPUT_DIR, exist_ok=True)
        with open(FEEDBACK_OUTPUT_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    except Exception:
        logger.exception("failed to append user feedback record to %s", FEEDBACK_OUTPUT_FILE)


class StartReq(BaseModel):
    query: str


class NextReq(BaseModel):
    session_id: str


class SelectReq(BaseModel):
    session_id: str
    choice: str


class RejectReq(BaseModel):
    session_id: str
    candidate: str
    reason: str | None = None


class SetModelReq(BaseModel):
    session_id: str
    model_name: str


@app.post("/start", tags=["opro"])
def start(req: StartReq):
    """Create a session for the query and return its first candidates."""
    sid = create_session(req.query)
    cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "round": 0, "candidates": cands})


@app.post("/next", tags=["opro"])
def next_round(req: NextReq):
    """Generate another round of candidates, or a final answer past MAX_ROUNDS."""
    s = _require_session(req.session_id)
    if s["round"] >= MAX_ROUNDS:
        # Use the session's chosen model here too, like every other model call.
        ans = _ask_model(s["original_query"], temperature=0.3, max_tokens=512,
                         model_name=s.get("model_name"))
        return ok({"final": True, "answer": ans})
    cands = generate_candidates(s["original_query"], s["history_candidates"],
                                model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})


@app.post("/select", tags=["opro"])
def select(req: SelectReq):
    """Adopt the chosen candidate as the session prompt and answer with it."""
    s = _require_session(req.session_id)
    log_user_choice(req.session_id, req.choice)
    set_selected_prompt(req.session_id, req.choice)
    log_chat_message(req.session_id, "system", req.choice)
    ans = _ask_model(req.choice, temperature=0.2, max_tokens=1024,
                     model_name=s.get("model_name"))
    log_chat_message(req.session_id, "assistant", ans)
    _append_feedback_record({
        "session_id": req.session_id,
        "round": s["round"],
        "choice": req.choice,
        "answer": ans,
    })
    return ok({"prompt": req.choice, "answer": ans})


@app.post("/reject", tags=["opro"])
def reject(req: RejectReq):
    """Record a rejected candidate and generate replacements that avoid it."""
    s = _require_session(req.session_id)
    log_user_reject(req.session_id, req.candidate, req.reason)
    cands = generate_candidates(s["original_query"], s["history_candidates"] + [req.candidate],
                                model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})


class QueryReq(BaseModel):
    query: str
    session_id: str | None = None


@app.post("/query", tags=["opro"])
def query(req: QueryReq):
    """Generate candidates for an existing session, or start a new one.

    NOTE(review): when ``session_id`` is given, ``req.query`` is ignored and
    candidates are regenerated from the session's original query.
    """
    if req.session_id:
        s = _require_session(req.session_id)
        cands = generate_candidates(s["original_query"], s["history_candidates"],
                                    model_name=s.get("model_name"))
        update_session_add_candidates(req.session_id, cands)
        return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})

    sid = create_session(req.query)
    log_chat_message(sid, "user", req.query)
    cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "round": 0, "candidates": cands})


app.mount("/ui", StaticFiles(directory="frontend", html=True), name="static")


@app.get("/", tags=["ui"])
def root():
    """Redirect the site root to the static UI."""
    return RedirectResponse(url="/ui/")


@app.get("/health", tags=["health"])
def health():
    """Liveness probe including the app version."""
    return ok({"status": "ok", "version": config.APP_VERSION})


@app.get("/version", tags=["health"])
def version():
    """Return the app version."""
    return ok({"version": config.APP_VERSION})


@app.get("/react", tags=["ui"])
def react_root():
    """Serve the React frontend entry page."""
    return FileResponse("frontend/react/index.html")


@app.get("/sessions", tags=["sessions"])
def sessions():
    """List a summary of every known session."""
    # Imported at call time so the current module-level SESSIONS object is used.
    from .opro.session_state import SESSIONS
    # Snapshot before iterating: sync endpoints run in a threadpool, so another
    # request may add a session while this one is building the listing.
    snapshot = list(SESSIONS.items())
    return ok({"sessions": [{
        "session_id": sid,
        "round": s.get("round", 0),
        "selected_prompt": s.get("selected_prompt"),
        "original_query": s.get("original_query"),
    } for sid, s in snapshot]})


@app.get("/session/{sid}", tags=["sessions"])
def session_detail(sid: str):
    """Return the full state of one session."""
    s = _require_session(sid)
    return ok({
        "session_id": sid,
        "round": s["round"],
        "original_query": s["original_query"],
        "selected_prompt": s["selected_prompt"],
        "candidates": s["history_candidates"],
        "user_feedback": s["user_feedback"],
        "rejected": s["rejected"],
        "history": s["chat_history"],
    })


class MessageReq(BaseModel):
    session_id: str
    message: str


@app.post("/message", tags=["chat"])
def message(req: MessageReq):
    """Answer a chat message, prefixed by the selected (or original) prompt."""
    s = _require_session(req.session_id)
    log_chat_message(req.session_id, "user", req.message)
    base_prompt = s.get("selected_prompt") or s["original_query"]
    full_prompt = base_prompt + "\n\n" + req.message
    ans = _ask_model(full_prompt, temperature=0.3, max_tokens=1024,
                     model_name=s.get("model_name"))
    log_chat_message(req.session_id, "assistant", ans)
    return ok({"session_id": req.session_id, "answer": ans, "history": s["chat_history"]})


class QueryFromMsgReq(BaseModel):
    session_id: str
    message: str | None = None


@app.post("/query_from_message", tags=["opro"])
def query_from_message(req: QueryFromMsgReq):
    """Generate candidates from a message.

    Uses the given message if present (and logs it as a user message).
    Otherwise it falls back to the latest non-empty user message in the chat
    history, and finally to the session's original query.
    """
    s = _require_session(req.session_id)
    base = None
    if req.message:
        log_chat_message(req.session_id, "user", req.message)
        base = req.message
    else:
        base = next(
            (m["content"] for m in reversed(s.get("chat_history", []))
             if m.get("role") == "user" and m.get("content")),
            None,
        )
    base = base or s["original_query"]
    cands = generate_candidates(base, s["history_candidates"], model_name=s.get("model_name"))
    update_session_add_candidates(req.session_id, cands)
    return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})


class AnswerReq(BaseModel):
    query: str


@app.post("/answer", tags=["opro"])
def answer(req: AnswerReq):
    """Create a session, answer the query directly, and also return candidates."""
    sid = create_session(req.query)
    # Same model lookup /start and /query use for a freshly created session.
    model_name = get_session(sid).get("model_name")
    log_chat_message(sid, "user", req.query)
    ans = _ask_model(req.query, temperature=0.2, max_tokens=1024, model_name=model_name)
    log_chat_message(sid, "assistant", ans)
    cands = generate_candidates(req.query, [], model_name=model_name)
    update_session_add_candidates(sid, cands)
    return ok({"session_id": sid, "answer": ans, "candidates": cands})


@app.get("/models", tags=["models"])
def models():
    """List the models reported by ``list_models``."""
    return ok({"models": list_models()})


@app.post("/set_model", tags=["models"])
def set_model(req: SetModelReq):
    """Set the session's model after checking it is in the available list."""
    _require_session(req.session_id)
    avail = set(list_models() or [])
    if req.model_name not in avail:
        raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE")
    set_session_model(req.session_id, req.model_name)
    return ok({"session_id": req.session_id, "model_name": req.model_name})