Compare commits
2 Commits
1376d60ed5
...
602875b08c
| Author | SHA1 | Date | |
|---|---|---|---|
| 602875b08c | |||
| da30a0999c |
@@ -14,6 +14,7 @@ from .opro.session_state import USER_FEEDBACK_LOG
|
|||||||
|
|
||||||
# True OPRO session management
|
# True OPRO session management
|
||||||
from .opro.session_state import (
|
from .opro.session_state import (
|
||||||
|
create_opro_session, get_opro_session, list_opro_sessions,
|
||||||
create_opro_run, get_opro_run, update_opro_iteration,
|
create_opro_run, get_opro_run, update_opro_iteration,
|
||||||
add_opro_evaluation, get_opro_trajectory, set_opro_test_cases,
|
add_opro_evaluation, get_opro_trajectory, set_opro_test_cases,
|
||||||
complete_opro_run, list_opro_runs
|
complete_opro_run, list_opro_runs
|
||||||
@@ -122,6 +123,7 @@ class CreateOPRORunReq(BaseModel):
|
|||||||
task_description: str
|
task_description: str
|
||||||
test_cases: Optional[List[TestCase]] = None
|
test_cases: Optional[List[TestCase]] = None
|
||||||
model_name: Optional[str] = None
|
model_name: Optional[str] = None
|
||||||
|
session_id: Optional[str] = None # Optional session to associate with
|
||||||
|
|
||||||
|
|
||||||
class OPROIterateReq(BaseModel):
|
class OPROIterateReq(BaseModel):
|
||||||
@@ -360,12 +362,62 @@ def set_model(req: SetModelReq):
|
|||||||
# TRUE OPRO ENDPOINTS (System Instruction Optimization)
|
# TRUE OPRO ENDPOINTS (System Instruction Optimization)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
# Session Management
|
||||||
|
|
||||||
|
@app.post("/opro/session/create", tags=["opro-true"])
|
||||||
|
def opro_create_session(session_name: str = None):
|
||||||
|
"""
|
||||||
|
Create a new OPRO session that can contain multiple runs.
|
||||||
|
"""
|
||||||
|
session_id = create_opro_session(session_name=session_name)
|
||||||
|
session = get_opro_session(session_id)
|
||||||
|
|
||||||
|
return ok({
|
||||||
|
"session_id": session_id,
|
||||||
|
"session_name": session["session_name"],
|
||||||
|
"num_runs": len(session["run_ids"])
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/opro/sessions", tags=["opro-true"])
|
||||||
|
def opro_list_sessions():
|
||||||
|
"""
|
||||||
|
List all OPRO sessions.
|
||||||
|
"""
|
||||||
|
sessions = list_opro_sessions()
|
||||||
|
return ok({"sessions": sessions})
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/opro/session/{session_id}", tags=["opro-true"])
|
||||||
|
def opro_get_session(session_id: str):
|
||||||
|
"""
|
||||||
|
Get detailed information about an OPRO session.
|
||||||
|
"""
|
||||||
|
session = get_opro_session(session_id)
|
||||||
|
if not session:
|
||||||
|
raise AppException(404, "Session not found", "SESSION_NOT_FOUND")
|
||||||
|
|
||||||
|
# Get all runs in this session
|
||||||
|
runs = list_opro_runs(session_id=session_id)
|
||||||
|
|
||||||
|
return ok({
|
||||||
|
"session_id": session_id,
|
||||||
|
"session_name": session["session_name"],
|
||||||
|
"created_at": session["created_at"],
|
||||||
|
"num_runs": len(session["run_ids"]),
|
||||||
|
"runs": runs
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# Run Management
|
||||||
|
|
||||||
@app.post("/opro/create", tags=["opro-true"])
|
@app.post("/opro/create", tags=["opro-true"])
|
||||||
def opro_create_run(req: CreateOPRORunReq):
|
def opro_create_run(req: CreateOPRORunReq):
|
||||||
"""
|
"""
|
||||||
Create a new OPRO optimization run.
|
Create a new OPRO optimization run.
|
||||||
|
|
||||||
This starts a new system instruction optimization process for a given task.
|
This starts a new system instruction optimization process for a given task.
|
||||||
|
Optionally can be associated with a session.
|
||||||
"""
|
"""
|
||||||
# Convert test cases from Pydantic models to tuples
|
# Convert test cases from Pydantic models to tuples
|
||||||
test_cases = None
|
test_cases = None
|
||||||
@@ -375,7 +427,8 @@ def opro_create_run(req: CreateOPRORunReq):
|
|||||||
run_id = create_opro_run(
|
run_id = create_opro_run(
|
||||||
task_description=req.task_description,
|
task_description=req.task_description,
|
||||||
test_cases=test_cases,
|
test_cases=test_cases,
|
||||||
model_name=req.model_name
|
model_name=req.model_name,
|
||||||
|
session_id=req.session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
run = get_opro_run(run_id)
|
run = get_opro_run(run_id)
|
||||||
@@ -385,7 +438,8 @@ def opro_create_run(req: CreateOPRORunReq):
|
|||||||
"task_description": run["task_description"],
|
"task_description": run["task_description"],
|
||||||
"num_test_cases": len(run["test_cases"]),
|
"num_test_cases": len(run["test_cases"]),
|
||||||
"iteration": run["iteration"],
|
"iteration": run["iteration"],
|
||||||
"status": run["status"]
|
"status": run["status"],
|
||||||
|
"session_id": run.get("session_id")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@@ -433,15 +487,14 @@ def opro_evaluate(req: OPROEvaluateReq):
|
|||||||
Evaluate a system instruction on the test cases.
|
Evaluate a system instruction on the test cases.
|
||||||
|
|
||||||
This scores the instruction and updates the performance trajectory.
|
This scores the instruction and updates the performance trajectory.
|
||||||
|
If no test cases are defined, uses a default score of 0.5 to indicate user selection.
|
||||||
"""
|
"""
|
||||||
run = get_opro_run(req.run_id)
|
run = get_opro_run(req.run_id)
|
||||||
if not run:
|
if not run:
|
||||||
raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")
|
raise AppException(404, "OPRO run not found", "RUN_NOT_FOUND")
|
||||||
|
|
||||||
if not run["test_cases"]:
|
# Evaluate the instruction if test cases exist
|
||||||
raise AppException(400, "No test cases defined for this run", "NO_TEST_CASES")
|
if run["test_cases"] and len(run["test_cases"]) > 0:
|
||||||
|
|
||||||
# Evaluate the instruction
|
|
||||||
try:
|
try:
|
||||||
score = evaluate_system_instruction(
|
score = evaluate_system_instruction(
|
||||||
system_instruction=req.instruction,
|
system_instruction=req.instruction,
|
||||||
@@ -450,6 +503,10 @@ def opro_evaluate(req: OPROEvaluateReq):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AppException(500, f"Evaluation failed: {e}", "EVALUATION_ERROR")
|
raise AppException(500, f"Evaluation failed: {e}", "EVALUATION_ERROR")
|
||||||
|
else:
|
||||||
|
# No test cases - use default score to indicate user selection
|
||||||
|
# This allows the trajectory to track which instructions the user preferred
|
||||||
|
score = 0.5
|
||||||
|
|
||||||
# Add to trajectory
|
# Add to trajectory
|
||||||
add_opro_evaluation(req.run_id, req.instruction, score)
|
add_opro_evaluation(req.run_id, req.instruction, score)
|
||||||
@@ -462,7 +519,8 @@ def opro_evaluate(req: OPROEvaluateReq):
|
|||||||
"instruction": req.instruction,
|
"instruction": req.instruction,
|
||||||
"score": score,
|
"score": score,
|
||||||
"best_score": run["best_score"],
|
"best_score": run["best_score"],
|
||||||
"is_new_best": score == run["best_score"] and score > 0
|
"is_new_best": score == run["best_score"] and score > 0,
|
||||||
|
"has_test_cases": len(run["test_cases"]) > 0
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -66,10 +66,57 @@ def set_session_model(sid: str, model_name: str | None):
|
|||||||
# TRUE OPRO SESSION MANAGEMENT
|
# TRUE OPRO SESSION MANAGEMENT
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
# Session storage (contains multiple runs)
|
||||||
|
OPRO_SESSIONS = {}
|
||||||
|
|
||||||
|
def create_opro_session(session_name: str = None) -> str:
|
||||||
|
"""
|
||||||
|
Create a new OPRO session that can contain multiple runs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_name: Optional name for the session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
session_id: Unique identifier for this session
|
||||||
|
"""
|
||||||
|
session_id = uuid.uuid4().hex
|
||||||
|
OPRO_SESSIONS[session_id] = {
|
||||||
|
"session_name": session_name or "新会话", # Will be updated with first task description
|
||||||
|
"created_at": uuid.uuid1().time,
|
||||||
|
"run_ids": [], # List of run IDs in this session
|
||||||
|
"chat_history": [] # Cross-run chat history
|
||||||
|
}
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_opro_session(session_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get OPRO session by ID."""
|
||||||
|
return OPRO_SESSIONS.get(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def list_opro_sessions() -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
List all OPRO sessions with summary information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of session summaries
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"session_id": session_id,
|
||||||
|
"session_name": session["session_name"],
|
||||||
|
"num_runs": len(session["run_ids"]),
|
||||||
|
"created_at": session["created_at"]
|
||||||
|
}
|
||||||
|
for session_id, session in OPRO_SESSIONS.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_opro_run(
|
def create_opro_run(
|
||||||
task_description: str,
|
task_description: str,
|
||||||
test_cases: List[Tuple[str, str]] = None,
|
test_cases: List[Tuple[str, str]] = None,
|
||||||
model_name: str = None
|
model_name: str = None,
|
||||||
|
session_id: str = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Create a new OPRO optimization run.
|
Create a new OPRO optimization run.
|
||||||
@@ -78,6 +125,7 @@ def create_opro_run(
|
|||||||
task_description: Description of the task to optimize for
|
task_description: Description of the task to optimize for
|
||||||
test_cases: List of (input, expected_output) tuples for evaluation
|
test_cases: List of (input, expected_output) tuples for evaluation
|
||||||
model_name: Optional model name to use
|
model_name: Optional model name to use
|
||||||
|
session_id: Optional session ID to associate this run with
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
run_id: Unique identifier for this OPRO run
|
run_id: Unique identifier for this OPRO run
|
||||||
@@ -87,6 +135,7 @@ def create_opro_run(
|
|||||||
"task_description": task_description,
|
"task_description": task_description,
|
||||||
"test_cases": test_cases or [],
|
"test_cases": test_cases or [],
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
|
"session_id": session_id, # Link to parent session
|
||||||
"iteration": 0,
|
"iteration": 0,
|
||||||
"trajectory": [], # List of (instruction, score) tuples
|
"trajectory": [], # List of (instruction, score) tuples
|
||||||
"best_instruction": None,
|
"best_instruction": None,
|
||||||
@@ -95,6 +144,14 @@ def create_opro_run(
|
|||||||
"created_at": uuid.uuid1().time,
|
"created_at": uuid.uuid1().time,
|
||||||
"status": "active" # active, completed, failed
|
"status": "active" # active, completed, failed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add run to session if session_id provided
|
||||||
|
if session_id and session_id in OPRO_SESSIONS:
|
||||||
|
OPRO_SESSIONS[session_id]["run_ids"].append(run_id)
|
||||||
|
# Update session name with first task description if it's still default
|
||||||
|
if OPRO_SESSIONS[session_id]["session_name"] == "新会话" and len(OPRO_SESSIONS[session_id]["run_ids"]) == 1:
|
||||||
|
OPRO_SESSIONS[session_id]["session_name"] = task_description
|
||||||
|
|
||||||
return run_id
|
return run_id
|
||||||
|
|
||||||
|
|
||||||
@@ -206,13 +263,22 @@ def complete_opro_run(run_id: str):
|
|||||||
run["status"] = "completed"
|
run["status"] = "completed"
|
||||||
|
|
||||||
|
|
||||||
def list_opro_runs() -> List[Dict[str, Any]]:
|
def list_opro_runs(session_id: str = None) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
List all OPRO runs with summary information.
|
List all OPRO runs with summary information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: Optional session ID to filter runs by session
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of run summaries
|
List of run summaries
|
||||||
"""
|
"""
|
||||||
|
runs_to_list = OPRO_RUNS.items()
|
||||||
|
|
||||||
|
# Filter by session if provided
|
||||||
|
if session_id:
|
||||||
|
runs_to_list = [(rid, r) for rid, r in runs_to_list if r.get("session_id") == session_id]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
@@ -220,7 +286,8 @@ def list_opro_runs() -> List[Dict[str, Any]]:
|
|||||||
"iteration": run["iteration"],
|
"iteration": run["iteration"],
|
||||||
"best_score": run["best_score"],
|
"best_score": run["best_score"],
|
||||||
"num_test_cases": len(run["test_cases"]),
|
"num_test_cases": len(run["test_cases"]),
|
||||||
"status": run["status"]
|
"status": run["status"],
|
||||||
|
"session_id": run.get("session_id")
|
||||||
}
|
}
|
||||||
for run_id, run in OPRO_RUNS.items()
|
for run_id, run in runs_to_list
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -50,18 +50,22 @@
|
|||||||
// Main App Component
|
// Main App Component
|
||||||
function App() {
|
function App() {
|
||||||
const [sidebarOpen, setSidebarOpen] = useState(false);
|
const [sidebarOpen, setSidebarOpen] = useState(false);
|
||||||
const [runs, setRuns] = useState([]);
|
const [sessions, setSessions] = useState([]);
|
||||||
|
const [currentSessionId, setCurrentSessionId] = useState(null);
|
||||||
|
const [currentSessionRuns, setCurrentSessionRuns] = useState([]);
|
||||||
const [currentRunId, setCurrentRunId] = useState(null);
|
const [currentRunId, setCurrentRunId] = useState(null);
|
||||||
const [messages, setMessages] = useState([]);
|
const [messages, setMessages] = useState([]);
|
||||||
|
const [sessionMessages, setSessionMessages] = useState({}); // Store messages per session
|
||||||
|
const [sessionLastRunId, setSessionLastRunId] = useState({}); // Store last run ID per session
|
||||||
const [inputValue, setInputValue] = useState('');
|
const [inputValue, setInputValue] = useState('');
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [models, setModels] = useState([]);
|
const [models, setModels] = useState([]);
|
||||||
const [selectedModel, setSelectedModel] = useState('');
|
const [selectedModel, setSelectedModel] = useState('');
|
||||||
const chatEndRef = useRef(null);
|
const chatEndRef = useRef(null);
|
||||||
|
|
||||||
// Load runs and models on mount
|
// Load sessions and models on mount
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
loadRuns();
|
loadSessions();
|
||||||
loadModels();
|
loadModels();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
@@ -85,29 +89,78 @@
|
|||||||
chatEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
chatEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
||||||
}, [messages]);
|
}, [messages]);
|
||||||
|
|
||||||
async function loadRuns() {
|
async function loadSessions() {
|
||||||
try {
|
try {
|
||||||
const res = await fetch(`${API_BASE}/opro/runs`);
|
const res = await fetch(`${API_BASE}/opro/sessions`);
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
if (data.success) {
|
if (data.success) {
|
||||||
setRuns(data.data.runs || []);
|
setSessions(data.data.sessions || []);
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Failed to load runs:', err);
|
console.error('Failed to load sessions:', err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadSessionRuns(sessionId) {
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_BASE}/opro/session/${sessionId}`);
|
||||||
|
const data = await res.json();
|
||||||
|
if (data.success) {
|
||||||
|
setCurrentSessionRuns(data.data.runs || []);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to load session runs:', err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function createNewSession() {
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${API_BASE}/opro/session/create`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' }
|
||||||
|
});
|
||||||
|
const data = await res.json();
|
||||||
|
|
||||||
|
if (!data.success) {
|
||||||
|
throw new Error(data.error || 'Failed to create session');
|
||||||
|
}
|
||||||
|
|
||||||
|
const sessionId = data.data.session_id;
|
||||||
|
setCurrentSessionId(sessionId);
|
||||||
|
setCurrentSessionRuns([]);
|
||||||
|
setCurrentRunId(null);
|
||||||
|
setMessages([]);
|
||||||
|
setSessionMessages(prev => ({ ...prev, [sessionId]: [] })); // Initialize empty messages for new session
|
||||||
|
|
||||||
|
// Reload sessions list
|
||||||
|
await loadSessions();
|
||||||
|
|
||||||
|
return sessionId;
|
||||||
|
} catch (err) {
|
||||||
|
alert('创建会话失败: ' + err.message);
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function createNewRun(taskDescription) {
|
async function createNewRun(taskDescription) {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
// Create run
|
// Ensure we have a session
|
||||||
|
let sessionId = currentSessionId;
|
||||||
|
if (!sessionId) {
|
||||||
|
sessionId = await createNewSession();
|
||||||
|
if (!sessionId) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create run within session
|
||||||
const res = await fetch(`${API_BASE}/opro/create`, {
|
const res = await fetch(`${API_BASE}/opro/create`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
task_description: taskDescription,
|
task_description: taskDescription,
|
||||||
test_cases: [],
|
test_cases: [],
|
||||||
model_name: selectedModel || undefined
|
model_name: selectedModel || undefined,
|
||||||
|
session_id: sessionId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
@@ -119,16 +172,33 @@
|
|||||||
const runId = data.data.run_id;
|
const runId = data.data.run_id;
|
||||||
setCurrentRunId(runId);
|
setCurrentRunId(runId);
|
||||||
|
|
||||||
// Add user message
|
// Save this as the last run for this session
|
||||||
setMessages([{ role: 'user', content: taskDescription }]);
|
setSessionLastRunId(prev => ({
|
||||||
|
...prev,
|
||||||
|
[sessionId]: runId
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Add user message to existing messages (keep chat history)
|
||||||
|
const newUserMessage = { role: 'user', content: taskDescription };
|
||||||
|
setMessages(prev => {
|
||||||
|
const updated = [...prev, newUserMessage];
|
||||||
|
// Save to session messages
|
||||||
|
setSessionMessages(prevSessions => ({
|
||||||
|
...prevSessions,
|
||||||
|
[sessionId]: updated
|
||||||
|
}));
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
|
||||||
// Generate and evaluate candidates
|
// Generate and evaluate candidates
|
||||||
await generateCandidates(runId);
|
await generateCandidates(runId);
|
||||||
|
|
||||||
// Reload runs list
|
// Reload sessions and session runs
|
||||||
await loadRuns();
|
await loadSessions();
|
||||||
|
await loadSessionRuns(sessionId);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
alert('创建任务失败: ' + err.message);
|
alert('创建任务失败: ' + err.message);
|
||||||
|
console.error('Error creating run:', err);
|
||||||
} finally {
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
@@ -137,6 +207,7 @@
|
|||||||
async function generateCandidates(runId) {
|
async function generateCandidates(runId) {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
|
console.log('Generating candidates for run:', runId);
|
||||||
const res = await fetch(`${API_BASE}/opro/generate_and_evaluate`, {
|
const res = await fetch(`${API_BASE}/opro/generate_and_evaluate`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
@@ -148,19 +219,33 @@
|
|||||||
});
|
});
|
||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
|
|
||||||
|
console.log('Generate candidates response:', data);
|
||||||
|
|
||||||
if (!data.success) {
|
if (!data.success) {
|
||||||
throw new Error(data.error || 'Failed to generate candidates');
|
throw new Error(data.error || 'Failed to generate candidates');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add assistant message with candidates
|
// Add assistant message with candidates
|
||||||
setMessages(prev => [...prev, {
|
const newAssistantMessage = {
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
type: 'candidates',
|
type: 'candidates',
|
||||||
candidates: data.data.candidates,
|
candidates: data.data.candidates,
|
||||||
iteration: data.data.iteration
|
iteration: data.data.iteration
|
||||||
}]);
|
};
|
||||||
|
setMessages(prev => {
|
||||||
|
const updated = [...prev, newAssistantMessage];
|
||||||
|
// Save to session messages
|
||||||
|
if (currentSessionId) {
|
||||||
|
setSessionMessages(prevSessions => ({
|
||||||
|
...prevSessions,
|
||||||
|
[currentSessionId]: updated
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
alert('生成候选指令失败: ' + err.message);
|
alert('生成候选指令失败: ' + err.message);
|
||||||
|
console.error('Error generating candidates:', err);
|
||||||
} finally {
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
@@ -185,12 +270,23 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add execution result
|
// Add execution result
|
||||||
setMessages(prev => [...prev, {
|
const newExecutionMessage = {
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
type: 'execution',
|
type: 'execution',
|
||||||
instruction: instruction,
|
instruction: instruction,
|
||||||
response: data.data.response
|
response: data.data.response
|
||||||
}]);
|
};
|
||||||
|
setMessages(prev => {
|
||||||
|
const updated = [...prev, newExecutionMessage];
|
||||||
|
// Save to session messages
|
||||||
|
if (currentSessionId) {
|
||||||
|
setSessionMessages(prevSessions => ({
|
||||||
|
...prevSessions,
|
||||||
|
[currentSessionId]: updated
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
alert('执行失败: ' + err.message);
|
alert('执行失败: ' + err.message);
|
||||||
} finally {
|
} finally {
|
||||||
@@ -204,25 +300,50 @@
|
|||||||
|
|
||||||
setInputValue('');
|
setInputValue('');
|
||||||
|
|
||||||
if (!currentRunId) {
|
// Always create a new run with the message as task description
|
||||||
// Create new run with task description
|
|
||||||
createNewRun(msg);
|
createNewRun(msg);
|
||||||
} else {
|
}
|
||||||
// Continue optimization or execute
|
|
||||||
// For now, just show message
|
async function handleContinueOptimize(selectedInstruction, selectedScore) {
|
||||||
setMessages(prev => [...prev, { role: 'user', content: msg }]);
|
if (!currentRunId || loading) return;
|
||||||
|
|
||||||
|
// First, evaluate the selected instruction to add it to trajectory
|
||||||
|
if (selectedInstruction) {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
// Add the selected instruction to trajectory
|
||||||
|
const res = await fetch(`${API_BASE}/opro/evaluate`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
run_id: currentRunId,
|
||||||
|
instruction: selectedInstruction
|
||||||
|
})
|
||||||
|
});
|
||||||
|
const data = await res.json();
|
||||||
|
|
||||||
|
if (!data.success) {
|
||||||
|
throw new Error(data.error || 'Failed to evaluate instruction');
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('Evaluated instruction, score:', data.data.score);
|
||||||
|
} catch (err) {
|
||||||
|
alert('评估指令失败: ' + err.message);
|
||||||
|
console.error('Error evaluating instruction:', err);
|
||||||
|
setLoading(false);
|
||||||
|
return;
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleContinueOptimize() {
|
// Then generate new candidates based on updated trajectory
|
||||||
if (!currentRunId || loading) return;
|
await generateCandidates(currentRunId);
|
||||||
generateCandidates(currentRunId);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleExecute(instruction) {
|
function handleExecute(instruction) {
|
||||||
if (loading) return;
|
if (loading) return;
|
||||||
const userInput = prompt('请输入要处理的内容(可选):');
|
executeInstruction(instruction, '');
|
||||||
executeInstruction(instruction, userInput);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleCopyInstruction(instruction) {
|
function handleCopyInstruction(instruction) {
|
||||||
@@ -235,11 +356,33 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
function handleNewTask() {
|
function handleNewTask() {
|
||||||
|
// Create new run within current session
|
||||||
setCurrentRunId(null);
|
setCurrentRunId(null);
|
||||||
setMessages([]);
|
setMessages([]);
|
||||||
setInputValue('');
|
setInputValue('');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleNewSession() {
|
||||||
|
// Create completely new session
|
||||||
|
const sessionId = await createNewSession();
|
||||||
|
if (sessionId) {
|
||||||
|
setCurrentSessionId(sessionId);
|
||||||
|
setCurrentSessionRuns([]);
|
||||||
|
setCurrentRunId(null);
|
||||||
|
setMessages([]);
|
||||||
|
setInputValue('');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleSelectSession(sessionId) {
|
||||||
|
setCurrentSessionId(sessionId);
|
||||||
|
// Restore the last run ID for this session
|
||||||
|
setCurrentRunId(sessionLastRunId[sessionId] || null);
|
||||||
|
// Load messages from session storage
|
||||||
|
setMessages(sessionMessages[sessionId] || []);
|
||||||
|
await loadSessionRuns(sessionId);
|
||||||
|
}
|
||||||
|
|
||||||
async function loadRun(runId) {
|
async function loadRun(runId) {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
try {
|
try {
|
||||||
@@ -301,22 +444,22 @@
|
|||||||
// Content area
|
// Content area
|
||||||
React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-2 flex flex-col' },
|
React.createElement('div', { className: 'flex-1 overflow-y-auto scrollbar-hide p-2 flex flex-col' },
|
||||||
sidebarOpen ? React.createElement(React.Fragment, null,
|
sidebarOpen ? React.createElement(React.Fragment, null,
|
||||||
// New task button (expanded)
|
// New session button (expanded)
|
||||||
React.createElement('button', {
|
React.createElement('button', {
|
||||||
onClick: handleNewTask,
|
onClick: handleNewSession,
|
||||||
className: 'mb-3 px-4 py-2.5 bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors flex items-center justify-center gap-2 text-gray-700 font-medium'
|
className: 'mb-3 px-4 py-2.5 bg-white border border-gray-300 hover:bg-gray-50 rounded-lg transition-colors flex items-center justify-center gap-2 text-gray-700 font-medium'
|
||||||
},
|
},
|
||||||
React.createElement('span', { className: 'text-lg' }, '+'),
|
React.createElement('span', { className: 'text-lg' }, '+'),
|
||||||
React.createElement('span', null, '新建会话')
|
React.createElement('span', null, '新建会话')
|
||||||
),
|
),
|
||||||
// Sessions list
|
// Sessions list
|
||||||
runs.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'),
|
sessions.length > 0 && React.createElement('div', { className: 'text-xs text-gray-500 mb-2 px-2' }, '会话列表'),
|
||||||
runs.map(run =>
|
sessions.map(session =>
|
||||||
React.createElement('div', {
|
React.createElement('div', {
|
||||||
key: run.run_id,
|
key: session.session_id,
|
||||||
onClick: () => loadRun(run.run_id),
|
onClick: () => handleSelectSession(session.session_id),
|
||||||
className: `p-3 mb-1 rounded-lg cursor-pointer transition-colors flex items-center gap-2 ${
|
className: `p-3 mb-1 rounded-lg cursor-pointer transition-colors flex items-center gap-2 ${
|
||||||
currentRunId === run.run_id ? 'bg-gray-100' : 'hover:bg-gray-50'
|
currentSessionId === session.session_id ? 'bg-gray-100' : 'hover:bg-gray-50'
|
||||||
}`
|
}`
|
||||||
},
|
},
|
||||||
React.createElement('svg', {
|
React.createElement('svg', {
|
||||||
@@ -331,12 +474,12 @@
|
|||||||
React.createElement('path', { d: 'M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z' })
|
React.createElement('path', { d: 'M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z' })
|
||||||
),
|
),
|
||||||
React.createElement('div', { className: 'text-sm text-gray-800 truncate flex-1' },
|
React.createElement('div', { className: 'text-sm text-gray-800 truncate flex-1' },
|
||||||
run.task_description
|
session.session_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
) : React.createElement('button', {
|
) : React.createElement('button', {
|
||||||
onClick: handleNewTask,
|
onClick: handleNewSession,
|
||||||
className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center',
|
className: 'p-2 text-gray-600 hover:bg-gray-100 rounded-lg transition-colors flex items-center justify-center',
|
||||||
title: '新建会话'
|
title: '新建会话'
|
||||||
},
|
},
|
||||||
@@ -353,6 +496,9 @@
|
|||||||
React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center gap-3' },
|
React.createElement('div', { className: 'px-4 py-3 border-b border-gray-200 bg-white flex items-center gap-3' },
|
||||||
React.createElement('h1', { className: 'text-lg font-normal text-gray-800' },
|
React.createElement('h1', { className: 'text-lg font-normal text-gray-800' },
|
||||||
'OPRO'
|
'OPRO'
|
||||||
|
),
|
||||||
|
currentSessionId && React.createElement('div', { className: 'text-sm text-gray-500' },
|
||||||
|
sessions.find(s => s.session_id === currentSessionId)?.session_name || '当前会话'
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
|
||||||
@@ -391,7 +537,7 @@
|
|||||||
),
|
),
|
||||||
React.createElement('div', { className: 'flex gap-2' },
|
React.createElement('div', { className: 'flex gap-2' },
|
||||||
React.createElement('button', {
|
React.createElement('button', {
|
||||||
onClick: handleContinueOptimize,
|
onClick: () => handleContinueOptimize(cand.instruction, cand.score),
|
||||||
disabled: loading,
|
disabled: loading,
|
||||||
className: 'px-4 py-2 bg-white border border-gray-300 text-gray-700 rounded-lg hover:bg-gray-50 disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed transition-colors text-sm font-medium'
|
className: 'px-4 py-2 bg-white border border-gray-300 text-gray-700 rounded-lg hover:bg-gray-50 disabled:bg-gray-100 disabled:text-gray-400 disabled:cursor-not-allowed transition-colors text-sm font-medium'
|
||||||
}, '继续优化'),
|
}, '继续优化'),
|
||||||
@@ -404,12 +550,7 @@
|
|||||||
React.createElement('path', { d: 'M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1' })
|
React.createElement('path', { d: 'M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1' })
|
||||||
),
|
),
|
||||||
'复制'
|
'复制'
|
||||||
),
|
)
|
||||||
React.createElement('button', {
|
|
||||||
onClick: () => handleExecute(cand.instruction),
|
|
||||||
disabled: loading,
|
|
||||||
className: 'px-4 py-2 bg-gray-900 text-white rounded-lg hover:bg-gray-800 disabled:bg-gray-300 disabled:cursor-not-allowed transition-colors text-sm font-medium'
|
|
||||||
}, '执行此指令')
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -452,7 +593,7 @@
|
|||||||
handleSendMessage();
|
handleSendMessage();
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
placeholder: currentRunId ? '输入消息...' : '在此输入提示词',
|
placeholder: '输入任务描述,创建新的优化任务...',
|
||||||
disabled: loading,
|
disabled: loading,
|
||||||
rows: 3,
|
rows: 3,
|
||||||
className: 'w-full px-5 pt-4 pb-2 bg-transparent focus:outline-none disabled:bg-transparent text-gray-800 placeholder-gray-500 resize-none'
|
className: 'w-full px-5 pt-4 pb-2 bg-transparent focus:outline-none disabled:bg-transparent text-gray-800 placeholder-gray-500 resize-none'
|
||||||
@@ -489,8 +630,10 @@
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
!currentRunId && React.createElement('div', { className: 'text-xs text-gray-500 mt-3 px-4' },
|
React.createElement('div', { className: 'text-xs text-gray-500 mt-3 px-4' },
|
||||||
'输入任务描述后,AI 将为你生成优化的系统指令'
|
currentSessionId
|
||||||
|
? '输入任务描述,AI 将为你生成优化的系统指令'
|
||||||
|
: '点击左侧"新建会话"开始,或直接输入任务描述自动创建会话'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
131
test_session_api.py
Normal file
131
test_session_api.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script for OPRO session-based API
|
||||||
|
"""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
BASE_URL = "http://127.0.0.1:8010"
|
||||||
|
|
||||||
|
def print_section(title):
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f" {title}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
def test_session_workflow():
|
||||||
|
"""Test the complete session-based workflow."""
|
||||||
|
|
||||||
|
print_section("1. Create Session")
|
||||||
|
|
||||||
|
# Create a new session
|
||||||
|
response = requests.post(f"{BASE_URL}/opro/session/create")
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
print(f"❌ Failed to create session: {result}")
|
||||||
|
return
|
||||||
|
|
||||||
|
session_id = result["data"]["session_id"]
|
||||||
|
print(f"✅ Session created: {session_id}")
|
||||||
|
print(f" Session name: {result['data']['session_name']}")
|
||||||
|
|
||||||
|
print_section("2. Create First Run in Session")
|
||||||
|
|
||||||
|
# Create first run
|
||||||
|
create_req = {
|
||||||
|
"task_description": "将中文翻译成英文",
|
||||||
|
"test_cases": [
|
||||||
|
{"input": "你好", "expected_output": "Hello"},
|
||||||
|
{"input": "谢谢", "expected_output": "Thank you"}
|
||||||
|
],
|
||||||
|
"session_id": session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(f"{BASE_URL}/opro/create", json=create_req)
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
print(f"❌ Failed to create run: {result}")
|
||||||
|
return
|
||||||
|
|
||||||
|
run1_id = result["data"]["run_id"]
|
||||||
|
print(f"✅ Run 1 created: {run1_id}")
|
||||||
|
print(f" Task: {result['data']['task_description']}")
|
||||||
|
|
||||||
|
print_section("3. Create Second Run in Same Session")
|
||||||
|
|
||||||
|
# Create second run in same session
|
||||||
|
create_req2 = {
|
||||||
|
"task_description": "将英文翻译成中文",
|
||||||
|
"test_cases": [
|
||||||
|
{"input": "Hello", "expected_output": "你好"},
|
||||||
|
{"input": "Thank you", "expected_output": "谢谢"}
|
||||||
|
],
|
||||||
|
"session_id": session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(f"{BASE_URL}/opro/create", json=create_req2)
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
print(f"❌ Failed to create run 2: {result}")
|
||||||
|
return
|
||||||
|
|
||||||
|
run2_id = result["data"]["run_id"]
|
||||||
|
print(f"✅ Run 2 created: {run2_id}")
|
||||||
|
print(f" Task: {result['data']['task_description']}")
|
||||||
|
|
||||||
|
print_section("4. Get Session Details")
|
||||||
|
|
||||||
|
response = requests.get(f"{BASE_URL}/opro/session/{session_id}")
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
print(f"❌ Failed to get session: {result}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"✅ Session details:")
|
||||||
|
print(f" Session ID: {result['data']['session_id']}")
|
||||||
|
print(f" Session name: {result['data']['session_name']}")
|
||||||
|
print(f" Number of runs: {result['data']['num_runs']}")
|
||||||
|
print(f" Runs:")
|
||||||
|
for run in result['data']['runs']:
|
||||||
|
print(f" - {run['run_id'][:8]}... : {run['task_description']}")
|
||||||
|
|
||||||
|
print_section("5. List All Sessions")
|
||||||
|
|
||||||
|
response = requests.get(f"{BASE_URL}/opro/sessions")
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
print(f"❌ Failed to list sessions: {result}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"✅ Total sessions: {len(result['data']['sessions'])}")
|
||||||
|
for session in result['data']['sessions']:
|
||||||
|
print(f" - {session['session_name']}: {session['num_runs']} runs")
|
||||||
|
|
||||||
|
print_section("✅ All Tests Passed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
# Check if server is running
|
||||||
|
response = requests.get(f"{BASE_URL}/health")
|
||||||
|
if response.status_code != 200:
|
||||||
|
print("❌ Server is not running. Please start it with:")
|
||||||
|
print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
test_session_workflow()
|
||||||
|
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
print("❌ Cannot connect to server. Please start it with:")
|
||||||
|
print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
exit(1)
|
||||||
|
|
||||||
Reference in New Issue
Block a user