diff --git a/_qwen_xinference_demo/api.py b/_qwen_xinference_demo/api.py index 4cb1f96..b39740f 100644 --- a/_qwen_xinference_demo/api.py +++ b/_qwen_xinference_demo/api.py @@ -14,6 +14,7 @@ from .opro.session_state import USER_FEEDBACK_LOG # True OPRO session management from .opro.session_state import ( + create_opro_session, get_opro_session, list_opro_sessions, create_opro_run, get_opro_run, update_opro_iteration, add_opro_evaluation, get_opro_trajectory, set_opro_test_cases, complete_opro_run, list_opro_runs @@ -122,6 +123,7 @@ class CreateOPRORunReq(BaseModel): task_description: str test_cases: Optional[List[TestCase]] = None model_name: Optional[str] = None + session_id: Optional[str] = None # Optional session to associate with class OPROIterateReq(BaseModel): @@ -360,12 +362,62 @@ def set_model(req: SetModelReq): # TRUE OPRO ENDPOINTS (System Instruction Optimization) # ============================================================================ +# Session Management + +@app.post("/opro/session/create", tags=["opro-true"]) +def opro_create_session(session_name: str = None): + """ + Create a new OPRO session that can contain multiple runs. + """ + session_id = create_opro_session(session_name=session_name) + session = get_opro_session(session_id) + + return ok({ + "session_id": session_id, + "session_name": session["session_name"], + "num_runs": len(session["run_ids"]) + }) + + +@app.get("/opro/sessions", tags=["opro-true"]) +def opro_list_sessions(): + """ + List all OPRO sessions. + """ + sessions = list_opro_sessions() + return ok({"sessions": sessions}) + + +@app.get("/opro/session/{session_id}", tags=["opro-true"]) +def opro_get_session(session_id: str): + """ + Get detailed information about an OPRO session. + """ + session = get_opro_session(session_id) + if not session: + raise AppException(404, "Session not found", "SESSION_NOT_FOUND") + + # Get all runs in this session + runs = list_opro_runs(session_id=session_id) + + return ok({ + "session_id": session_id, + "session_name": session["session_name"], + "created_at": session["created_at"], + "num_runs": len(session["run_ids"]), + "runs": runs + }) + + +# Run Management + @app.post("/opro/create", tags=["opro-true"]) def opro_create_run(req: CreateOPRORunReq): """ Create a new OPRO optimization run. This starts a new system instruction optimization process for a given task. + Optionally can be associated with a session. """ # Convert test cases from Pydantic models to tuples test_cases = None @@ -375,7 +427,8 @@ def opro_create_run(req: CreateOPRORunReq): run_id = create_opro_run( task_description=req.task_description, test_cases=test_cases, - model_name=req.model_name + model_name=req.model_name, + session_id=req.session_id ) run = get_opro_run(run_id) @@ -385,7 +438,8 @@ def opro_create_run(req: CreateOPRORunReq): "task_description": run["task_description"], "num_test_cases": len(run["test_cases"]), "iteration": run["iteration"], - "status": run["status"] + "status": run["status"], + "session_id": run.get("session_id") }) diff --git a/_qwen_xinference_demo/opro/session_state.py b/_qwen_xinference_demo/opro/session_state.py index 5ff87b7..3442d73 100644 --- a/_qwen_xinference_demo/opro/session_state.py +++ b/_qwen_xinference_demo/opro/session_state.py @@ -66,10 +66,57 @@ def set_session_model(sid: str, model_name: str | None): # TRUE OPRO SESSION MANAGEMENT # ============================================================================ +# Session storage (contains multiple runs) +OPRO_SESSIONS = {} + +def create_opro_session(session_name: str = None) -> str: + """ + Create a new OPRO session that can contain multiple runs. + + Args: + session_name: Optional name for the session + + Returns: + session_id: Unique identifier for this session + """ + session_id = uuid.uuid4().hex + OPRO_SESSIONS[session_id] = { + "session_name": session_name or "新会话", # Will be updated with first task description + "created_at": uuid.uuid1().time, + "run_ids": [], # List of run IDs in this session + "chat_history": [] # Cross-run chat history + } + return session_id + + +def get_opro_session(session_id: str) -> Dict[str, Any]: + """Get OPRO session by ID.""" + return OPRO_SESSIONS.get(session_id) + + +def list_opro_sessions() -> List[Dict[str, Any]]: + """ + List all OPRO sessions with summary information. + + Returns: + List of session summaries + """ + return [ + { + "session_id": session_id, + "session_name": session["session_name"], + "num_runs": len(session["run_ids"]), + "created_at": session["created_at"] + } + for session_id, session in OPRO_SESSIONS.items() + ] + + def create_opro_run( task_description: str, test_cases: List[Tuple[str, str]] = None, - model_name: str = None + model_name: str = None, + session_id: str = None ) -> str: """ Create a new OPRO optimization run. @@ -78,6 +125,7 @@ def create_opro_run( task_description: Description of the task to optimize for test_cases: List of (input, expected_output) tuples for evaluation model_name: Optional model name to use + session_id: Optional session ID to associate this run with Returns: run_id: Unique identifier for this OPRO run @@ -87,6 +135,7 @@ def create_opro_run( "task_description": task_description, "test_cases": test_cases or [], "model_name": model_name, + "session_id": session_id, # Link to parent session "iteration": 0, "trajectory": [], # List of (instruction, score) tuples "best_instruction": None, @@ -95,6 +144,14 @@ def create_opro_run( "created_at": uuid.uuid1().time, "status": "active" # active, completed, failed } + + # Add run to session if session_id provided + if session_id and session_id in OPRO_SESSIONS: + OPRO_SESSIONS[session_id]["run_ids"].append(run_id) + # Update session name with first task description if it's still default + if OPRO_SESSIONS[session_id]["session_name"] == "新会话" and len(OPRO_SESSIONS[session_id]["run_ids"]) == 1: + OPRO_SESSIONS[session_id]["session_name"] = task_description + return run_id @@ -206,13 +263,22 @@ def complete_opro_run(run_id: str): run["status"] = "completed" -def list_opro_runs() -> List[Dict[str, Any]]: +def list_opro_runs(session_id: str = None) -> List[Dict[str, Any]]: """ List all OPRO runs with summary information. + Args: + session_id: Optional session ID to filter runs by session + Returns: List of run summaries """ + runs_to_list = OPRO_RUNS.items() + + # Filter by session if provided + if session_id: + runs_to_list = [(rid, r) for rid, r in runs_to_list if r.get("session_id") == session_id] + return [ { "run_id": run_id, @@ -220,7 +286,8 @@ def list_opro_runs() -> List[Dict[str, Any]]: "iteration": run["iteration"], "best_score": run["best_score"], "num_test_cases": len(run["test_cases"]), - "status": run["status"] + "status": run["status"], + "session_id": run.get("session_id") } - for run_id, run in OPRO_RUNS.items() + for run_id, run in runs_to_list ] diff --git a/frontend/opro.html b/frontend/opro.html index ddf03cc..86ee588 100644 --- a/frontend/opro.html +++ b/frontend/opro.html @@ -50,7 +50,9 @@ // Main App Component function App() { const [sidebarOpen, setSidebarOpen] = useState(false); - const [runs, setRuns] = useState([]); + const [sessions, setSessions] = useState([]); + const [currentSessionId, setCurrentSessionId] = useState(null); + const [currentSessionRuns, setCurrentSessionRuns] = useState([]); const [currentRunId, setCurrentRunId] = useState(null); const [messages, setMessages] = useState([]); const [inputValue, setInputValue] = useState(''); @@ -59,9 +61,9 @@ const [selectedModel, setSelectedModel] = useState(''); const chatEndRef = useRef(null); - // Load runs and models on mount + // Load sessions and models on mount useEffect(() => { - loadRuns(); + loadSessions(); loadModels(); }, []); @@ -79,56 +81,106 @@ console.error('Failed to load models:', err); } } - + // Auto-scroll chat useEffect(() => { chatEndRef.current?.scrollIntoView({ behavior: 'smooth' }); }, [messages]); - - async function loadRuns() { + + async function loadSessions() { try { - const res = await fetch(`${API_BASE}/opro/runs`); + const res = await fetch(`${API_BASE}/opro/sessions`); const data = await res.json(); if (data.success) { - setRuns(data.data.runs || []); + setSessions(data.data.sessions || []); } } catch (err) { - console.error('Failed to load runs:', err); + console.error('Failed to load sessions:', err); + } + } + + async function loadSessionRuns(sessionId) { + try { + const res = await fetch(`${API_BASE}/opro/session/${sessionId}`); + const data = await res.json(); + if (data.success) { + setCurrentSessionRuns(data.data.runs || []); + } + } catch (err) { + console.error('Failed to load session runs:', err); } } + async function createNewSession() { + try { + const res = await fetch(`${API_BASE}/opro/session/create`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' } + }); + const data = await res.json(); + + if (!data.success) { + throw new Error(data.error || 'Failed to create session'); + } + + const sessionId = data.data.session_id; + setCurrentSessionId(sessionId); + setCurrentSessionRuns([]); + setCurrentRunId(null); + setMessages([]); + + // Reload sessions list + await loadSessions(); + + return sessionId; + } catch (err) { + alert('创建会话失败: ' + err.message); + return null; + } + } + async function createNewRun(taskDescription) { setLoading(true); try { - // Create run + // Ensure we have a session + let sessionId = currentSessionId; + if (!sessionId) { + sessionId = await createNewSession(); + if (!sessionId) return; + } + + // Create run within session const res = await fetch(`${API_BASE}/opro/create`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ task_description: taskDescription, test_cases: [], - model_name: selectedModel || undefined + model_name: selectedModel || undefined, + session_id: sessionId }) }); const data = await res.json(); - + if (!data.success) { throw new Error(data.error || 'Failed to create run'); } - + const runId = data.data.run_id; setCurrentRunId(runId); - + // Add user message setMessages([{ role: 'user', content: taskDescription }]); - + // Generate and evaluate candidates await generateCandidates(runId); - - // Reload runs list - await loadRuns(); + + // Reload sessions and session runs + await loadSessions(); + await loadSessionRuns(sessionId); } catch (err) { alert('创建任务失败: ' + err.message); + console.error('Error creating run:', err); } finally { setLoading(false); } @@ -137,6 +189,7 @@ async function generateCandidates(runId) { setLoading(true); try { + console.log('Generating candidates for run:', runId); const res = await fetch(`${API_BASE}/opro/generate_and_evaluate`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -147,11 +200,13 @@ }) }); const data = await res.json(); - + + console.log('Generate candidates response:', data); + if (!data.success) { throw new Error(data.error || 'Failed to generate candidates'); } - + // Add assistant message with candidates setMessages(prev => [...prev, { role: 'assistant', @@ -161,6 +216,7 @@ }]); } catch (err) { alert('生成候选指令失败: ' + err.message); + console.error('Error generating candidates:', err); } finally { setLoading(false); } @@ -221,8 +277,7 @@ function handleExecute(instruction) { if (loading) return; - const userInput = prompt('请输入要处理的内容(可选):'); - executeInstruction(instruction, userInput); + executeInstruction(instruction, ''); } function handleCopyInstruction(instruction) { @@ -235,11 +290,31 @@ } function handleNewTask() { + // Create new run within current session setCurrentRunId(null); setMessages([]); setInputValue(''); } + async function handleNewSession() { + // Create completely new session + const sessionId = await createNewSession(); + if (sessionId) { + setCurrentSessionId(sessionId); + setCurrentSessionRuns([]); + setCurrentRunId(null); + setMessages([]); + setInputValue(''); + } + } + + async function handleSelectSession(sessionId) { + setCurrentSessionId(sessionId); + setCurrentRunId(null); + setMessages([]); + await loadSessionRuns(sessionId); + } + async function loadRun(runId) { setLoading(true); try { @@ -301,22 +376,22 @@ // Content area React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-2 flex flex-col' }, sidebarOpen ? React.createElement(React.Fragment, null, - // New task button (expanded) + // New session button (expanded) React.createElement('button', { - onClick: handleNewTask, + onClick: handleNewSession, className: 'mb-3 px-4 py-2.5 bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors flex items-center justify-center gap-2 text-gray-700 font-medium' }, React.createElement('span', { className: 'text-lg' }, '+'), React.createElement('span', null, '新建会话') ), // Sessions list - runs.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'), - runs.map(run => + sessions.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'), + sessions.map(session => React.createElement('div', { - key: run.run_id, - onClick: () => loadRun(run.run_id), + key: session.session_id, + onClick: () => handleSelectSession(session.session_id), className: `p-3 mb-1 rounded-lg cursor-pointer transition-colors flex items-center gap-2 ${ - currentRunId === run.run_id ? 'bg-gray-100' : 'hover:bg-gray-50' + currentSessionId === session.session_id ? 'bg-gray-100' : 'hover:bg-gray-50' }` }, React.createElement('svg', { @@ -331,12 +406,12 @@ React.createElement('path', { d: 'M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z' }) ), React.createElement('div', { className: 'text-sm text-gray-800 truncate flex-1' }, - run.task_description + session.session_name ) ) ) ) : React.createElement('button', { - onClick: handleNewTask, + onClick: handleNewSession, className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center', title: '新建会话' }, @@ -350,10 +425,19 @@ // Main Chat Area React.createElement('div', { className: 'flex-1 flex flex-col bg-white' }, // Header - React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center gap-3' }, - React.createElement('h1', { className: 'text-lg font-normal text-gray-800' }, - 'OPRO' - ) + React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center justify-between' }, + React.createElement('div', { className: 'flex items-center gap-3' }, + React.createElement('h1', { className: 'text-lg font-normal text-gray-800' }, + 'OPRO' + ), + currentSessionId && React.createElement('div', { className: 'text-sm text-gray-500' }, + sessions.find(s => s.session_id === currentSessionId)?.session_name || '当前会话' + ) + ), + currentSessionId && React.createElement('button', { + onClick: handleNewTask, + className: 'px-3 py-1.5 text-sm bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors text-gray-700' + }, '+ 新建任务') ), // Chat Messages @@ -490,7 +574,9 @@ ) ), !currentRunId && React.createElement('div', { className: 'text-xs text-gray-500 mt-3 px-4' }, - '输入任务描述后,AI 将为你生成优化的系统指令' + currentSessionId + ? '输入任务描述后,AI 将为你生成优化的系统指令' + : '点击左侧"新建会话"开始,或输入任务描述自动创建会话' ) ) ) diff --git a/test_session_api.py b/test_session_api.py new file mode 100644 index 0000000..637016d --- /dev/null +++ b/test_session_api.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +""" +Test script for OPRO session-based API +""" + +import requests +import json + +BASE_URL = "http://127.0.0.1:8010" + +def print_section(title): + print(f"\n{'='*60}") + print(f" {title}") + print(f"{'='*60}\n") + +def test_session_workflow(): + """Test the complete session-based workflow.""" + + print_section("1. Create Session") + + # Create a new session + response = requests.post(f"{BASE_URL}/opro/session/create") + result = response.json() + + if not result.get("success"): + print(f"❌ Failed to create session: {result}") + return + + session_id = result["data"]["session_id"] + print(f"✅ Session created: {session_id}") + print(f" Session name: {result['data']['session_name']}") + + print_section("2. Create First Run in Session") + + # Create first run + create_req = { + "task_description": "将中文翻译成英文", + "test_cases": [ + {"input": "你好", "expected_output": "Hello"}, + {"input": "谢谢", "expected_output": "Thank you"} + ], + "session_id": session_id + } + + response = requests.post(f"{BASE_URL}/opro/create", json=create_req) + result = response.json() + + if not result.get("success"): + print(f"❌ Failed to create run: {result}") + return + + run1_id = result["data"]["run_id"] + print(f"✅ Run 1 created: {run1_id}") + print(f" Task: {result['data']['task_description']}") + + print_section("3. Create Second Run in Same Session") + + # Create second run in same session + create_req2 = { + "task_description": "将英文翻译成中文", + "test_cases": [ + {"input": "Hello", "expected_output": "你好"}, + {"input": "Thank you", "expected_output": "谢谢"} + ], + "session_id": session_id + } + + response = requests.post(f"{BASE_URL}/opro/create", json=create_req2) + result = response.json() + + if not result.get("success"): + print(f"❌ Failed to create run 2: {result}") + return + + run2_id = result["data"]["run_id"] + print(f"✅ Run 2 created: {run2_id}") + print(f" Task: {result['data']['task_description']}") + + print_section("4. Get Session Details") + + response = requests.get(f"{BASE_URL}/opro/session/{session_id}") + result = response.json() + + if not result.get("success"): + print(f"❌ Failed to get session: {result}") + return + + print(f"✅ Session details:") + print(f" Session ID: {result['data']['session_id']}") + print(f" Session name: {result['data']['session_name']}") + print(f" Number of runs: {result['data']['num_runs']}") + print(f" Runs:") + for run in result['data']['runs']: + print(f" - {run['run_id'][:8]}... : {run['task_description']}") + + print_section("5. List All Sessions") + + response = requests.get(f"{BASE_URL}/opro/sessions") + result = response.json() + + if not result.get("success"): + print(f"❌ Failed to list sessions: {result}") + return + + print(f"✅ Total sessions: {len(result['data']['sessions'])}") + for session in result['data']['sessions']: + print(f" - {session['session_name']}: {session['num_runs']} runs") + + print_section("✅ All Tests Passed!") + +if __name__ == "__main__": + try: + # Check if server is running + response = requests.get(f"{BASE_URL}/health") + if response.status_code != 200: + print("❌ Server is not running. Please start it with:") + print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010") + exit(1) + + test_session_workflow() + + except requests.exceptions.ConnectionError: + print("❌ Cannot connect to server. Please start it with:") + print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010") + exit(1) + except Exception as e: + print(f"❌ Error: {e}") + import traceback + traceback.print_exc() + exit(1) +