原始代码

This commit is contained in:
xxm
2025-12-05 07:11:25 +00:00
parent 045e777a11
commit dd5339de32
46 changed files with 5848 additions and 0 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,284 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import config
from .opro.session_state import create_session, get_session, update_session_add_candidates, log_user_choice
from .opro.session_state import log_user_reject
from .opro.session_state import set_selected_prompt, log_chat_message
from .opro.session_state import set_session_model
from .opro.session_state import USER_FEEDBACK_LOG
from .opro.user_prompt_optimizer import generate_candidates
from .opro.ollama_client import call_qwen
from .opro.ollama_client import list_models
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(
title=config.APP_TITLE,
description=config.APP_DESCRIPTION,
version=config.APP_VERSION,
contact=config.APP_CONTACT,
openapi_tags=[
{"name": "health", "description": "健康检查"},
{"name": "models", "description": "模型列表与设置"},
{"name": "sessions", "description": "会话管理"},
{"name": "opro", "description": "提示优化候选生成与选择/拒绝"},
{"name": "chat", "description": "会话聊天"},
{"name": "ui", "description": "静态页面"}
]
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
MAX_ROUNDS = 3
def ok(data=None):
return JSONResponse({"success": True, "data": data})
class AppException(HTTPException):
def __init__(self, status_code: int, detail: str, error_code: str):
super().__init__(status_code=status_code, detail=detail)
self.error_code = error_code
@app.exception_handler(AppException)
def _app_exc_handler(request: Request, exc: AppException):
return JSONResponse(status_code=exc.status_code, content={
"success": False,
"error": str(exc.detail),
"error_code": exc.error_code
})
@app.exception_handler(HTTPException)
def _http_exc_handler(request: Request, exc: HTTPException):
return JSONResponse(status_code=exc.status_code, content={
"success": False,
"error": str(exc.detail),
"error_code": "HTTP_ERROR"
})
@app.exception_handler(Exception)
def _generic_exc_handler(request: Request, exc: Exception):
return JSONResponse(status_code=500, content={
"success": False,
"error": "internal error",
"error_code": "INTERNAL_ERROR"
})
class StartReq(BaseModel):
query: str
class NextReq(BaseModel):
session_id: str
class SelectReq(BaseModel):
session_id: str
choice: str
class RejectReq(BaseModel):
session_id: str
candidate: str
reason: str | None = None
class SetModelReq(BaseModel):
session_id: str
model_name: str
@app.post("/start", tags=["opro"])
def start(req: StartReq):
sid = create_session(req.query)
cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
update_session_add_candidates(sid, cands)
return ok({"session_id": sid, "round": 0, "candidates": cands})
@app.post("/next", tags=["opro"])
def next_round(req: NextReq):
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
if s["round"] >= MAX_ROUNDS:
ans = call_qwen(s["original_query"], temperature=0.3, max_tokens=512)
return ok({"final": True, "answer": ans})
cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
update_session_add_candidates(req.session_id, cands)
return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
@app.post("/select", tags=["opro"])
def select(req: SelectReq):
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
log_user_choice(req.session_id, req.choice)
set_selected_prompt(req.session_id, req.choice)
log_chat_message(req.session_id, "system", req.choice)
try:
ans = call_qwen(req.choice, temperature=0.2, max_tokens=1024, model_name=s.get("model_name"))
except Exception as e:
raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
log_chat_message(req.session_id, "assistant", ans)
try:
import os, json
os.makedirs("outputs", exist_ok=True)
with open("outputs/user_feedback.jsonl", "a", encoding="utf-8") as f:
f.write(json.dumps({
"session_id": req.session_id,
"round": s["round"],
"choice": req.choice,
"answer": ans
}, ensure_ascii=False) + "\n")
except Exception:
pass
return ok({"prompt": req.choice, "answer": ans})
@app.post("/reject", tags=["opro"])
def reject(req: RejectReq):
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
log_user_reject(req.session_id, req.candidate, req.reason)
cands = generate_candidates(s["original_query"], s["history_candidates"] + [req.candidate], model_name=s.get("model_name"))
update_session_add_candidates(req.session_id, cands)
return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
class QueryReq(BaseModel):
query: str
session_id: str | None = None
@app.post("/query", tags=["opro"])
def query(req: QueryReq):
if req.session_id:
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
cands = generate_candidates(s["original_query"], s["history_candidates"], model_name=s.get("model_name"))
update_session_add_candidates(req.session_id, cands)
return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
else:
sid = create_session(req.query)
log_chat_message(sid, "user", req.query)
cands = generate_candidates(req.query, [], model_name=get_session(sid).get("model_name"))
update_session_add_candidates(sid, cands)
return ok({"session_id": sid, "round": 0, "candidates": cands})
app.mount("/ui", StaticFiles(directory="frontend", html=True), name="static")
@app.get("/", tags=["ui"])
def root():
return RedirectResponse(url="/ui/")
@app.get("/health", tags=["health"])
def health():
return ok({"status": "ok", "version": config.APP_VERSION})
@app.get("/version", tags=["health"])
def version():
return ok({"version": config.APP_VERSION})
# @app.get("/ui/react", tags=["ui"])
# def ui_react():
# return FileResponse("frontend/react/index.html")
# @app.get("/ui/offline", tags=["ui"])
# def ui_offline():
# return FileResponse("frontend/ui_offline.html")
@app.get("/react", tags=["ui"])
def react_root():
return FileResponse("frontend/react/index.html")
@app.get("/sessions", tags=["sessions"])
def sessions():
from .opro.session_state import SESSIONS
return ok({"sessions": [{
"session_id": sid,
"round": s.get("round", 0),
"selected_prompt": s.get("selected_prompt"),
"original_query": s.get("original_query")
} for sid, s in SESSIONS.items()]})
@app.get("/session/{sid}", tags=["sessions"])
def session_detail(sid: str):
s = get_session(sid)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
return ok({
"session_id": sid,
"round": s["round"],
"original_query": s["original_query"],
"selected_prompt": s["selected_prompt"],
"candidates": s["history_candidates"],
"user_feedback": s["user_feedback"],
"rejected": s["rejected"],
"history": s["chat_history"],
})
class MessageReq(BaseModel):
session_id: str
message: str
@app.post("/message", tags=["chat"])
def message(req: MessageReq):
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
log_chat_message(req.session_id, "user", req.message)
base_prompt = s.get("selected_prompt") or s["original_query"]
full_prompt = base_prompt + "\n\n" + req.message
try:
ans = call_qwen(full_prompt, temperature=0.3, max_tokens=1024, model_name=s.get("model_name"))
except Exception as e:
raise AppException(400, f"ollama error: {e}", "OLLAMA_ERROR")
log_chat_message(req.session_id, "assistant", ans)
return ok({"session_id": req.session_id, "answer": ans, "history": s["chat_history"]})
class QueryFromMsgReq(BaseModel):
session_id: str
@app.post("/query_from_message", tags=["opro"])
def query_from_message(req: QueryFromMsgReq):
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
last_user = None
for m in reversed(s.get("chat_history", [])):
if m.get("role") == "user" and m.get("content"):
last_user = m["content"]
break
base = last_user or s["original_query"]
cands = generate_candidates(base, s["history_candidates"], model_name=s.get("model_name"))
update_session_add_candidates(req.session_id, cands)
return ok({"session_id": req.session_id, "round": s["round"], "candidates": cands})
class AnswerReq(BaseModel):
query: str
@app.post("/answer", tags=["opro"])
def answer(req: AnswerReq):
sid = create_session(req.query)
log_chat_message(sid, "user", req.query)
ans = call_qwen(req.query, temperature=0.2, max_tokens=1024)
log_chat_message(sid, "assistant", ans)
cands = generate_candidates(req.query, [])
update_session_add_candidates(sid, cands)
return ok({"session_id": sid, "answer": ans, "candidates": cands})
@app.get("/models", tags=["models"])
def models():
return ok({"models": list_models()})
@app.post("/set_model", tags=["models"])
def set_model(req: SetModelReq):
s = get_session(req.session_id)
if not s:
raise AppException(404, "session not found", "SESSION_NOT_FOUND")
avail = set(list_models() or [])
if req.model_name not in avail:
raise AppException(400, f"model not available: {req.model_name}", "MODEL_NOT_AVAILABLE")
set_session_model(req.session_id, req.model_name)
return ok({"session_id": req.session_id, "model_name": req.model_name})

View File

@@ -0,0 +1,52 @@
import requests
import re
import config
OLLAMA_URL = config.OLLAMA_GENERATE_URL
TAGS_URL = config.OLLAMA_TAGS_URL
MODEL_NAME = config.DEFAULT_CHAT_MODEL
def call_qwen(prompt: str, temperature: float = 0.8, max_tokens: int = 512, model_name: str | None = None) -> str:
def _payload(m: str):
return {
"model": m,
"prompt": prompt,
"stream": False,
"options": {
"temperature": temperature,
"num_predict": max_tokens
}
}
primary = model_name or MODEL_NAME
try:
resp = requests.post(OLLAMA_URL, json=_payload(primary), timeout=60)
resp.raise_for_status()
data = resp.json()
return data.get("response", "") or data.get("text", "")
except requests.HTTPError as e:
# Try fallback to default when user-selected model fails
if model_name and model_name != MODEL_NAME:
try:
resp = requests.post(OLLAMA_URL, json=_payload(MODEL_NAME), timeout=60)
resp.raise_for_status()
data = resp.json()
return data.get("response", "") or data.get("text", "")
except Exception:
pass
raise
def list_models() -> list[str]:
try:
r = requests.get(TAGS_URL, timeout=10)
r.raise_for_status()
data = r.json() or {}
items = data.get("models") or []
names = []
for m in items:
name = m.get("name") or m.get("model")
if name:
names.append(name)
names = [n for n in names if not re.search(r"embedding|rerank|reranker|bge", n, re.I)]
return names
except Exception:
return [MODEL_NAME]

View File

@@ -0,0 +1,20 @@
def refine_instruction(query: str) -> str:
return f"""
你是一个“问题澄清与重写助手”。
请根据用户的原始问题:
{query}
生成不少于20条多角度、可直接执行的问题改写每行一条。
"""
def refine_instruction_with_history(query: str, rejected_list: list) -> str:
rejected_text = "\n".join(f"- {r}" for r in rejected_list) if rejected_list else ""
return f"""
你是一个“问题澄清与重写助手”。
原始问题:
{query}
以下改写已被否定:
{rejected_text}
请从新的角度重新生成至少20条不同的改写问题每条单独一行。
"""

View File

@@ -0,0 +1,56 @@
import uuid
SESSIONS = {}
USER_FEEDBACK_LOG = []
def create_session(query: str) -> str:
sid = uuid.uuid4().hex
SESSIONS[sid] = {
"original_query": query,
"round": 0,
"history_candidates": [],
"user_feedback": [],
"rejected": [],
"selected_prompt": None,
"chat_history": [],
"model_name": None
}
return sid
def get_session(sid: str):
return SESSIONS.get(sid)
def update_session_add_candidates(sid: str, candidates: list):
s = SESSIONS[sid]
s["round"] += 1
s["history_candidates"].extend(candidates)
def log_user_choice(sid: str, choice: str):
SESSIONS[sid]["user_feedback"].append(
{"round": SESSIONS[sid]["round"], "choice": choice}
)
USER_FEEDBACK_LOG.append({
"session_id": sid,
"round": SESSIONS[sid]["round"],
"choice": choice
})
def log_user_reject(sid: str, candidate: str, reason: str | None = None):
SESSIONS[sid]["rejected"].append(candidate)
USER_FEEDBACK_LOG.append({
"session_id": sid,
"round": SESSIONS[sid]["round"],
"reject": candidate,
"reason": reason or ""
})
def set_selected_prompt(sid: str, prompt: str):
SESSIONS[sid]["selected_prompt"] = prompt
def log_chat_message(sid: str, role: str, content: str):
SESSIONS[sid]["chat_history"].append({"role": role, "content": content})
def set_session_model(sid: str, model_name: str | None):
s = SESSIONS.get(sid)
if s is not None:
s["model_name"] = model_name

View File

@@ -0,0 +1,55 @@
import re
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
import config
from .ollama_client import call_qwen
from .xinference_client import embed_texts
from .prompt_utils import refine_instruction, refine_instruction_with_history
def parse_candidates(raw: str) -> list:
lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()]
cleaned = []
for l in lines:
l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip()
if len(l) >= 6:
cleaned.append(l)
return list(dict.fromkeys(cleaned))
def cluster_and_select(candidates: list, top_k=config.TOP_K, distance_threshold=config.CLUSTER_DISTANCE_THRESHOLD):
if not candidates:
return []
if len(candidates) <= top_k:
return candidates
vecs = embed_texts(candidates)
if not vecs or len(vecs) != len(candidates):
return candidates[:top_k]
X = np.array(vecs)
clustering = AgglomerativeClustering(n_clusters=None,
distance_threshold=distance_threshold,
metric="cosine",
linkage="average")
labels = clustering.fit_predict(X)
selected_idx = []
for label in sorted(set(labels)):
idxs = [i for i,l in enumerate(labels) if l == label]
sims = cosine_similarity(X[idxs]).mean(axis=1)
rep = idxs[int(np.argmax(sims))]
selected_idx.append(rep)
selected = [candidates[i] for i in sorted(selected_idx)]
return selected[:top_k]
def generate_candidates(query: str, rejected=None, top_k=config.TOP_K, model_name=None):
rejected = rejected or []
if rejected:
prompt = refine_instruction_with_history(query, rejected)
else:
prompt = refine_instruction(query)
raw = call_qwen(prompt, temperature=0.9, max_tokens=1024, model_name=model_name)
all_candidates = parse_candidates(raw)
return cluster_and_select(all_candidates, top_k=top_k)

View File

@@ -0,0 +1,29 @@
import requests
from typing import List
import config
XINFERENCE_EMBED_URL = config.XINFERENCE_EMBED_URL
OLLAMA_EMBED_URL = config.OLLAMA_HOST + "/api/embeddings"
def embed_texts(texts: List[str]) -> List[List[float]]:
payload = {"inputs": texts}
try:
resp = requests.post(XINFERENCE_EMBED_URL, json=payload, timeout=10)
resp.raise_for_status()
data = resp.json()
embs = data.get("embeddings", [])
if embs:
return embs
except Exception:
pass
try:
payload2 = {"model": config.DEFAULT_EMBED_MODEL, "input": texts}
resp2 = requests.post(OLLAMA_EMBED_URL, json=payload2, timeout=15)
resp2.raise_for_status()
data2 = resp2.json()
if isinstance(data2, dict) and "data" in data2:
return [item.get("embedding", []) for item in data2["data"]]
return data2.get("embeddings", [])
except Exception:
return []