feat: implement session-based architecture for OPRO

- Add session layer above runs to group related optimization tasks
- Sessions use first task description as name instead of 'Session 1'
- Simplified sidebar: show sessions without expansion
- Add '+ 新建任务' button in header to create runs within session
- Fix: reload sessions after creating new run
- Add debugging logs for candidate generation
- Backend: auto-update session name with first task description
This commit is contained in:
2025-12-06 21:26:24 +08:00
parent 1376d60ed5
commit da30a0999c
4 changed files with 380 additions and 42 deletions

View File

@@ -14,6 +14,7 @@ from .opro.session_state import USER_FEEDBACK_LOG
# True OPRO session management # True OPRO session management
from .opro.session_state import ( from .opro.session_state import (
create_opro_session, get_opro_session, list_opro_sessions,
create_opro_run, get_opro_run, update_opro_iteration, create_opro_run, get_opro_run, update_opro_iteration,
add_opro_evaluation, get_opro_trajectory, set_opro_test_cases, add_opro_evaluation, get_opro_trajectory, set_opro_test_cases,
complete_opro_run, list_opro_runs complete_opro_run, list_opro_runs
@@ -122,6 +123,7 @@ class CreateOPRORunReq(BaseModel):
task_description: str task_description: str
test_cases: Optional[List[TestCase]] = None test_cases: Optional[List[TestCase]] = None
model_name: Optional[str] = None model_name: Optional[str] = None
session_id: Optional[str] = None # Optional session to associate with
class OPROIterateReq(BaseModel): class OPROIterateReq(BaseModel):
@@ -360,12 +362,62 @@ def set_model(req: SetModelReq):
# TRUE OPRO ENDPOINTS (System Instruction Optimization) # TRUE OPRO ENDPOINTS (System Instruction Optimization)
# ============================================================================ # ============================================================================
# Session Management
@app.post("/opro/session/create", tags=["opro-true"])
def opro_create_session(session_name: str = None):
"""
Create a new OPRO session that can contain multiple runs.
"""
session_id = create_opro_session(session_name=session_name)
session = get_opro_session(session_id)
return ok({
"session_id": session_id,
"session_name": session["session_name"],
"num_runs": len(session["run_ids"])
})
@app.get("/opro/sessions", tags=["opro-true"])
def opro_list_sessions():
"""
List all OPRO sessions.
"""
sessions = list_opro_sessions()
return ok({"sessions": sessions})
@app.get("/opro/session/{session_id}", tags=["opro-true"])
def opro_get_session(session_id: str):
"""
Get detailed information about an OPRO session.
"""
session = get_opro_session(session_id)
if not session:
raise AppException(404, "Session not found", "SESSION_NOT_FOUND")
# Get all runs in this session
runs = list_opro_runs(session_id=session_id)
return ok({
"session_id": session_id,
"session_name": session["session_name"],
"created_at": session["created_at"],
"num_runs": len(session["run_ids"]),
"runs": runs
})
# Run Management
@app.post("/opro/create", tags=["opro-true"]) @app.post("/opro/create", tags=["opro-true"])
def opro_create_run(req: CreateOPRORunReq): def opro_create_run(req: CreateOPRORunReq):
""" """
Create a new OPRO optimization run. Create a new OPRO optimization run.
This starts a new system instruction optimization process for a given task. This starts a new system instruction optimization process for a given task.
Optionally can be associated with a session.
""" """
# Convert test cases from Pydantic models to tuples # Convert test cases from Pydantic models to tuples
test_cases = None test_cases = None
@@ -375,7 +427,8 @@ def opro_create_run(req: CreateOPRORunReq):
run_id = create_opro_run( run_id = create_opro_run(
task_description=req.task_description, task_description=req.task_description,
test_cases=test_cases, test_cases=test_cases,
model_name=req.model_name model_name=req.model_name,
session_id=req.session_id
) )
run = get_opro_run(run_id) run = get_opro_run(run_id)
@@ -385,7 +438,8 @@ def opro_create_run(req: CreateOPRORunReq):
"task_description": run["task_description"], "task_description": run["task_description"],
"num_test_cases": len(run["test_cases"]), "num_test_cases": len(run["test_cases"]),
"iteration": run["iteration"], "iteration": run["iteration"],
"status": run["status"] "status": run["status"],
"session_id": run.get("session_id")
}) })

View File

@@ -66,10 +66,57 @@ def set_session_model(sid: str, model_name: str | None):
# TRUE OPRO SESSION MANAGEMENT # TRUE OPRO SESSION MANAGEMENT
# ============================================================================ # ============================================================================
# Session storage (contains multiple runs)
OPRO_SESSIONS = {}
def create_opro_session(session_name: str = None) -> str:
"""
Create a new OPRO session that can contain multiple runs.
Args:
session_name: Optional name for the session
Returns:
session_id: Unique identifier for this session
"""
session_id = uuid.uuid4().hex
OPRO_SESSIONS[session_id] = {
"session_name": session_name or "新会话", # Will be updated with first task description
"created_at": uuid.uuid1().time,
"run_ids": [], # List of run IDs in this session
"chat_history": [] # Cross-run chat history
}
return session_id
def get_opro_session(session_id: str) -> Dict[str, Any]:
"""Get OPRO session by ID."""
return OPRO_SESSIONS.get(session_id)
def list_opro_sessions() -> List[Dict[str, Any]]:
"""
List all OPRO sessions with summary information.
Returns:
List of session summaries
"""
return [
{
"session_id": session_id,
"session_name": session["session_name"],
"num_runs": len(session["run_ids"]),
"created_at": session["created_at"]
}
for session_id, session in OPRO_SESSIONS.items()
]
def create_opro_run( def create_opro_run(
task_description: str, task_description: str,
test_cases: List[Tuple[str, str]] = None, test_cases: List[Tuple[str, str]] = None,
model_name: str = None model_name: str = None,
session_id: str = None
) -> str: ) -> str:
""" """
Create a new OPRO optimization run. Create a new OPRO optimization run.
@@ -78,6 +125,7 @@ def create_opro_run(
task_description: Description of the task to optimize for task_description: Description of the task to optimize for
test_cases: List of (input, expected_output) tuples for evaluation test_cases: List of (input, expected_output) tuples for evaluation
model_name: Optional model name to use model_name: Optional model name to use
session_id: Optional session ID to associate this run with
Returns: Returns:
run_id: Unique identifier for this OPRO run run_id: Unique identifier for this OPRO run
@@ -87,6 +135,7 @@ def create_opro_run(
"task_description": task_description, "task_description": task_description,
"test_cases": test_cases or [], "test_cases": test_cases or [],
"model_name": model_name, "model_name": model_name,
"session_id": session_id, # Link to parent session
"iteration": 0, "iteration": 0,
"trajectory": [], # List of (instruction, score) tuples "trajectory": [], # List of (instruction, score) tuples
"best_instruction": None, "best_instruction": None,
@@ -95,6 +144,14 @@ def create_opro_run(
"created_at": uuid.uuid1().time, "created_at": uuid.uuid1().time,
"status": "active" # active, completed, failed "status": "active" # active, completed, failed
} }
# Add run to session if session_id provided
if session_id and session_id in OPRO_SESSIONS:
OPRO_SESSIONS[session_id]["run_ids"].append(run_id)
# Update session name with first task description if it's still default
if OPRO_SESSIONS[session_id]["session_name"] == "新会话" and len(OPRO_SESSIONS[session_id]["run_ids"]) == 1:
OPRO_SESSIONS[session_id]["session_name"] = task_description
return run_id return run_id
@@ -206,13 +263,22 @@ def complete_opro_run(run_id: str):
run["status"] = "completed" run["status"] = "completed"
def list_opro_runs() -> List[Dict[str, Any]]: def list_opro_runs(session_id: str = None) -> List[Dict[str, Any]]:
""" """
List all OPRO runs with summary information. List all OPRO runs with summary information.
Args:
session_id: Optional session ID to filter runs by session
Returns: Returns:
List of run summaries List of run summaries
""" """
runs_to_list = OPRO_RUNS.items()
# Filter by session if provided
if session_id:
runs_to_list = [(rid, r) for rid, r in runs_to_list if r.get("session_id") == session_id]
return [ return [
{ {
"run_id": run_id, "run_id": run_id,
@@ -220,7 +286,8 @@ def list_opro_runs() -> List[Dict[str, Any]]:
"iteration": run["iteration"], "iteration": run["iteration"],
"best_score": run["best_score"], "best_score": run["best_score"],
"num_test_cases": len(run["test_cases"]), "num_test_cases": len(run["test_cases"]),
"status": run["status"] "status": run["status"],
"session_id": run.get("session_id")
} }
for run_id, run in OPRO_RUNS.items() for run_id, run in runs_to_list
] ]

View File

@@ -50,7 +50,9 @@
// Main App Component // Main App Component
function App() { function App() {
const [sidebarOpen, setSidebarOpen] = useState(false); const [sidebarOpen, setSidebarOpen] = useState(false);
const [runs, setRuns] = useState([]); const [sessions, setSessions] = useState([]);
const [currentSessionId, setCurrentSessionId] = useState(null);
const [currentSessionRuns, setCurrentSessionRuns] = useState([]);
const [currentRunId, setCurrentRunId] = useState(null); const [currentRunId, setCurrentRunId] = useState(null);
const [messages, setMessages] = useState([]); const [messages, setMessages] = useState([]);
const [inputValue, setInputValue] = useState(''); const [inputValue, setInputValue] = useState('');
@@ -59,9 +61,9 @@
const [selectedModel, setSelectedModel] = useState(''); const [selectedModel, setSelectedModel] = useState('');
const chatEndRef = useRef(null); const chatEndRef = useRef(null);
// Load runs and models on mount // Load sessions and models on mount
useEffect(() => { useEffect(() => {
loadRuns(); loadSessions();
loadModels(); loadModels();
}, []); }, []);
@@ -85,29 +87,77 @@
chatEndRef.current?.scrollIntoView({ behavior: 'smooth' }); chatEndRef.current?.scrollIntoView({ behavior: 'smooth' });
}, [messages]); }, [messages]);
async function loadRuns() { async function loadSessions() {
try { try {
const res = await fetch(`${API_BASE}/opro/runs`); const res = await fetch(`${API_BASE}/opro/sessions`);
const data = await res.json(); const data = await res.json();
if (data.success) { if (data.success) {
setRuns(data.data.runs || []); setSessions(data.data.sessions || []);
} }
} catch (err) { } catch (err) {
console.error('Failed to load runs:', err); console.error('Failed to load sessions:', err);
}
}
async function loadSessionRuns(sessionId) {
try {
const res = await fetch(`${API_BASE}/opro/session/${sessionId}`);
const data = await res.json();
if (data.success) {
setCurrentSessionRuns(data.data.runs || []);
}
} catch (err) {
console.error('Failed to load session runs:', err);
}
}
async function createNewSession() {
try {
const res = await fetch(`${API_BASE}/opro/session/create`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
const data = await res.json();
if (!data.success) {
throw new Error(data.error || 'Failed to create session');
}
const sessionId = data.data.session_id;
setCurrentSessionId(sessionId);
setCurrentSessionRuns([]);
setCurrentRunId(null);
setMessages([]);
// Reload sessions list
await loadSessions();
return sessionId;
} catch (err) {
alert('创建会话失败: ' + err.message);
return null;
} }
} }
async function createNewRun(taskDescription) { async function createNewRun(taskDescription) {
setLoading(true); setLoading(true);
try { try {
// Create run // Ensure we have a session
let sessionId = currentSessionId;
if (!sessionId) {
sessionId = await createNewSession();
if (!sessionId) return;
}
// Create run within session
const res = await fetch(`${API_BASE}/opro/create`, { const res = await fetch(`${API_BASE}/opro/create`, {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ body: JSON.stringify({
task_description: taskDescription, task_description: taskDescription,
test_cases: [], test_cases: [],
model_name: selectedModel || undefined model_name: selectedModel || undefined,
session_id: sessionId
}) })
}); });
const data = await res.json(); const data = await res.json();
@@ -125,10 +175,12 @@
// Generate and evaluate candidates // Generate and evaluate candidates
await generateCandidates(runId); await generateCandidates(runId);
// Reload runs list // Reload sessions and session runs
await loadRuns(); await loadSessions();
await loadSessionRuns(sessionId);
} catch (err) { } catch (err) {
alert('创建任务失败: ' + err.message); alert('创建任务失败: ' + err.message);
console.error('Error creating run:', err);
} finally { } finally {
setLoading(false); setLoading(false);
} }
@@ -137,6 +189,7 @@
async function generateCandidates(runId) { async function generateCandidates(runId) {
setLoading(true); setLoading(true);
try { try {
console.log('Generating candidates for run:', runId);
const res = await fetch(`${API_BASE}/opro/generate_and_evaluate`, { const res = await fetch(`${API_BASE}/opro/generate_and_evaluate`, {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
@@ -148,6 +201,8 @@
}); });
const data = await res.json(); const data = await res.json();
console.log('Generate candidates response:', data);
if (!data.success) { if (!data.success) {
throw new Error(data.error || 'Failed to generate candidates'); throw new Error(data.error || 'Failed to generate candidates');
} }
@@ -161,6 +216,7 @@
}]); }]);
} catch (err) { } catch (err) {
alert('生成候选指令失败: ' + err.message); alert('生成候选指令失败: ' + err.message);
console.error('Error generating candidates:', err);
} finally { } finally {
setLoading(false); setLoading(false);
} }
@@ -221,8 +277,7 @@
function handleExecute(instruction) { function handleExecute(instruction) {
if (loading) return; if (loading) return;
const userInput = prompt('请输入要处理的内容(可选):'); executeInstruction(instruction, '');
executeInstruction(instruction, userInput);
} }
function handleCopyInstruction(instruction) { function handleCopyInstruction(instruction) {
@@ -235,11 +290,31 @@
} }
function handleNewTask() { function handleNewTask() {
// Create new run within current session
setCurrentRunId(null); setCurrentRunId(null);
setMessages([]); setMessages([]);
setInputValue(''); setInputValue('');
} }
async function handleNewSession() {
// Create completely new session
const sessionId = await createNewSession();
if (sessionId) {
setCurrentSessionId(sessionId);
setCurrentSessionRuns([]);
setCurrentRunId(null);
setMessages([]);
setInputValue('');
}
}
async function handleSelectSession(sessionId) {
setCurrentSessionId(sessionId);
setCurrentRunId(null);
setMessages([]);
await loadSessionRuns(sessionId);
}
async function loadRun(runId) { async function loadRun(runId) {
setLoading(true); setLoading(true);
try { try {
@@ -301,22 +376,22 @@
// Content area // Content area
React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-2 flex flex-col' }, React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-2 flex flex-col' },
sidebarOpen ? React.createElement(React.Fragment, null, sidebarOpen ? React.createElement(React.Fragment, null,
// New task button (expanded) // New session button (expanded)
React.createElement('button', { React.createElement('button', {
onClick: handleNewTask, onClick: handleNewSession,
className: 'mb-3 px-4 py-2.5 bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors flex items-center justify-center gap-2 text-gray-700 font-medium' className: 'mb-3 px-4 py-2.5 bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors flex items-center justify-center gap-2 text-gray-700 font-medium'
}, },
React.createElement('span', { className: 'text-lg' }, '+'), React.createElement('span', { className: 'text-lg' }, '+'),
React.createElement('span', null, '新建会话') React.createElement('span', null, '新建会话')
), ),
// Sessions list // Sessions list
runs.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'), sessions.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'),
runs.map(run => sessions.map(session =>
React.createElement('div', { React.createElement('div', {
key: run.run_id, key: session.session_id,
onClick: () => loadRun(run.run_id), onClick: () => handleSelectSession(session.session_id),
className: `p-3 mb-1 rounded-lg cursor-pointer transition-colors flex items-center gap-2 ${ className: `p-3 mb-1 rounded-lg cursor-pointer transition-colors flex items-center gap-2 ${
currentRunId === run.run_id ? 'bg-gray-100' : 'hover:bg-gray-50' currentSessionId === session.session_id ? 'bg-gray-100' : 'hover:bg-gray-50'
}` }`
}, },
React.createElement('svg', { React.createElement('svg', {
@@ -331,12 +406,12 @@
React.createElement('path', { d: 'M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z' }) React.createElement('path', { d: 'M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z' })
), ),
React.createElement('div', { className: 'text-sm text-gray-800 truncate flex-1' }, React.createElement('div', { className: 'text-sm text-gray-800 truncate flex-1' },
run.task_description session.session_name
) )
) )
) )
) : React.createElement('button', { ) : React.createElement('button', {
onClick: handleNewTask, onClick: handleNewSession,
className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center', className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center',
title: '新建会话' title: '新建会话'
}, },
@@ -350,10 +425,19 @@
// Main Chat Area // Main Chat Area
React.createElement('div', { className: 'flex-1 flex flex-col bg-white' }, React.createElement('div', { className: 'flex-1 flex flex-col bg-white' },
// Header // Header
React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center gap-3' }, React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center justify-between' },
React.createElement('h1', { className: 'text-lg font-normal text-gray-800' }, React.createElement('div', { className: 'flex items-center gap-3' },
'OPRO' React.createElement('h1', { className: 'text-lg font-normal text-gray-800' },
) 'OPRO'
),
currentSessionId && React.createElement('div', { className: 'text-sm text-gray-500' },
sessions.find(s => s.session_id === currentSessionId)?.session_name || '当前会话'
)
),
currentSessionId && React.createElement('button', {
onClick: handleNewTask,
className: 'px-3 py-1.5 text-sm bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors text-gray-700'
}, '+ 新建任务')
), ),
// Chat Messages // Chat Messages
@@ -490,7 +574,9 @@
) )
), ),
!currentRunId && React.createElement('div', { className: 'text-xs text-gray-500 mt-3 px-4' }, !currentRunId && React.createElement('div', { className: 'text-xs text-gray-500 mt-3 px-4' },
'输入任务描述后AI 将为你生成优化的系统指令' currentSessionId
? '输入任务描述后AI 将为你生成优化的系统指令'
: '点击左侧"新建会话"开始,或输入任务描述自动创建会话'
) )
) )
) )

131
test_session_api.py Normal file
View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
"""
Test script for OPRO session-based API
"""
import requests
import json
BASE_URL = "http://127.0.0.1:8010"
def print_section(title):
print(f"\n{'='*60}")
print(f" {title}")
print(f"{'='*60}\n")
def test_session_workflow():
"""Test the complete session-based workflow."""
print_section("1. Create Session")
# Create a new session
response = requests.post(f"{BASE_URL}/opro/session/create")
result = response.json()
if not result.get("success"):
print(f"❌ Failed to create session: {result}")
return
session_id = result["data"]["session_id"]
print(f"✅ Session created: {session_id}")
print(f" Session name: {result['data']['session_name']}")
print_section("2. Create First Run in Session")
# Create first run
create_req = {
"task_description": "将中文翻译成英文",
"test_cases": [
{"input": "你好", "expected_output": "Hello"},
{"input": "谢谢", "expected_output": "Thank you"}
],
"session_id": session_id
}
response = requests.post(f"{BASE_URL}/opro/create", json=create_req)
result = response.json()
if not result.get("success"):
print(f"❌ Failed to create run: {result}")
return
run1_id = result["data"]["run_id"]
print(f"✅ Run 1 created: {run1_id}")
print(f" Task: {result['data']['task_description']}")
print_section("3. Create Second Run in Same Session")
# Create second run in same session
create_req2 = {
"task_description": "将英文翻译成中文",
"test_cases": [
{"input": "Hello", "expected_output": "你好"},
{"input": "Thank you", "expected_output": "谢谢"}
],
"session_id": session_id
}
response = requests.post(f"{BASE_URL}/opro/create", json=create_req2)
result = response.json()
if not result.get("success"):
print(f"❌ Failed to create run 2: {result}")
return
run2_id = result["data"]["run_id"]
print(f"✅ Run 2 created: {run2_id}")
print(f" Task: {result['data']['task_description']}")
print_section("4. Get Session Details")
response = requests.get(f"{BASE_URL}/opro/session/{session_id}")
result = response.json()
if not result.get("success"):
print(f"❌ Failed to get session: {result}")
return
print(f"✅ Session details:")
print(f" Session ID: {result['data']['session_id']}")
print(f" Session name: {result['data']['session_name']}")
print(f" Number of runs: {result['data']['num_runs']}")
print(f" Runs:")
for run in result['data']['runs']:
print(f" - {run['run_id'][:8]}... : {run['task_description']}")
print_section("5. List All Sessions")
response = requests.get(f"{BASE_URL}/opro/sessions")
result = response.json()
if not result.get("success"):
print(f"❌ Failed to list sessions: {result}")
return
print(f"✅ Total sessions: {len(result['data']['sessions'])}")
for session in result['data']['sessions']:
print(f" - {session['session_name']}: {session['num_runs']} runs")
print_section("✅ All Tests Passed!")
if __name__ == "__main__":
try:
# Check if server is running
response = requests.get(f"{BASE_URL}/health")
if response.status_code != 200:
print("❌ Server is not running. Please start it with:")
print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
exit(1)
test_session_workflow()
except requests.exceptions.ConnectionError:
print("❌ Cannot connect to server. Please start it with:")
print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
exit(1)
except Exception as e:
print(f"❌ Error: {e}")
import traceback
traceback.print_exc()
exit(1)