opro_demo/_qwen_xinference_demo/opro/session_state.py

import time
import uuid
from typing import List, Tuple, Dict, Any
# Legacy session storage (for query rewriting)
SESSIONS = {}
USER_FEEDBACK_LOG = []
# OPRO session storage (for system instruction optimization)
OPRO_RUNS = {}
OPRO_RUN_LOG = []


def create_session(query: str) -> str:
    sid = uuid.uuid4().hex
    SESSIONS[sid] = {
        "original_query": query,
        "round": 0,
        "history_candidates": [],
        "user_feedback": [],
        "rejected": [],
        "selected_prompt": None,
        "chat_history": [],
        "model_name": None
    }
    return sid


def get_session(sid: str):
    return SESSIONS.get(sid)


def update_session_add_candidates(sid: str, candidates: list):
    s = SESSIONS[sid]
    s["round"] += 1
    s["history_candidates"].extend(candidates)


def log_user_choice(sid: str, choice: str):
    SESSIONS[sid]["user_feedback"].append(
        {"round": SESSIONS[sid]["round"], "choice": choice}
    )
    USER_FEEDBACK_LOG.append({
        "session_id": sid,
        "round": SESSIONS[sid]["round"],
        "choice": choice
    })


def log_user_reject(sid: str, candidate: str, reason: str | None = None):
    SESSIONS[sid]["rejected"].append(candidate)
    USER_FEEDBACK_LOG.append({
        "session_id": sid,
        "round": SESSIONS[sid]["round"],
        "reject": candidate,
        "reason": reason or ""
    })


def set_selected_prompt(sid: str, prompt: str):
    SESSIONS[sid]["selected_prompt"] = prompt


def log_chat_message(sid: str, role: str, content: str):
    SESSIONS[sid]["chat_history"].append({"role": role, "content": content})


def set_session_model(sid: str, model_name: str | None):
    s = SESSIONS.get(sid)
    if s is not None:
        s["model_name"] = model_name
# ============================================================================
# TRUE OPRO SESSION MANAGEMENT
# ============================================================================

def create_opro_run(
    task_description: str,
    test_cases: List[Tuple[str, str]] | None = None,
    model_name: str | None = None
) -> str:
    """
    Create a new OPRO optimization run.

    Args:
        task_description: Description of the task to optimize for
        test_cases: List of (input, expected_output) tuples for evaluation
        model_name: Optional model name to use

    Returns:
        run_id: Unique identifier for this OPRO run
    """
    run_id = uuid.uuid4().hex
    OPRO_RUNS[run_id] = {
        "task_description": task_description,
        "test_cases": test_cases or [],
        "model_name": model_name,
        "iteration": 0,
        "trajectory": [],  # List of (instruction, score) tuples
        "best_instruction": None,
        "best_score": 0.0,
        "current_candidates": [],
        "created_at": time.time(),  # Unix timestamp of run creation
        "status": "active"  # active, completed, failed
    }
    return run_id


def get_opro_run(run_id: str) -> Dict[str, Any] | None:
    """Get OPRO run by ID."""
    return OPRO_RUNS.get(run_id)


def update_opro_iteration(
    run_id: str,
    candidates: List[str],
    scores: List[float] | None = None
):
    """
    Update OPRO run with new iteration results.

    Args:
        run_id: OPRO run identifier
        candidates: List of system instruction candidates
        scores: Optional list of scores (if evaluated)
    """
    run = OPRO_RUNS.get(run_id)
    if not run:
        return
    run["iteration"] += 1
    run["current_candidates"] = candidates
    # If scores provided, update trajectory
    if scores and len(scores) == len(candidates):
        for candidate, score in zip(candidates, scores):
            run["trajectory"].append((candidate, score))
            # Update best if this is better
            if score > run["best_score"]:
                run["best_score"] = score
                run["best_instruction"] = candidate
    # Log the iteration
    OPRO_RUN_LOG.append({
        "run_id": run_id,
        "iteration": run["iteration"],
        "num_candidates": len(candidates),
        "best_score": run["best_score"]
    })


def add_opro_evaluation(
    run_id: str,
    instruction: str,
    score: float
):
    """
    Add a single evaluation result to OPRO run.

    Args:
        run_id: OPRO run identifier
        instruction: System instruction that was evaluated
        score: Performance score
    """
    run = OPRO_RUNS.get(run_id)
    if not run:
        return
    # Add to trajectory
    run["trajectory"].append((instruction, score))
    # Update best if this is better
    if score > run["best_score"]:
        run["best_score"] = score
        run["best_instruction"] = instruction


def get_opro_trajectory(run_id: str) -> List[Tuple[str, float]]:
    """
    Get the performance trajectory for an OPRO run.

    Returns:
        List of (instruction, score) tuples sorted by score (highest first)
    """
    run = OPRO_RUNS.get(run_id)
    if not run:
        return []
    trajectory = run["trajectory"]
    return sorted(trajectory, key=lambda x: x[1], reverse=True)


def set_opro_test_cases(
    run_id: str,
    test_cases: List[Tuple[str, str]]
):
    """
    Set or update test cases for an OPRO run.

    Args:
        run_id: OPRO run identifier
        test_cases: List of (input, expected_output) tuples
    """
    run = OPRO_RUNS.get(run_id)
    if run:
        run["test_cases"] = test_cases


def complete_opro_run(run_id: str):
    """Mark an OPRO run as completed."""
    run = OPRO_RUNS.get(run_id)
    if run:
        run["status"] = "completed"


def list_opro_runs() -> List[Dict[str, Any]]:
    """
    List all OPRO runs with summary information.

    Returns:
        List of run summaries
    """
    return [
        {
            "run_id": run_id,
            "task_description": (
                run["task_description"][:100] + "..."
                if len(run["task_description"]) > 100
                else run["task_description"]
            ),
            "iteration": run["iteration"],
            "best_score": run["best_score"],
            "num_test_cases": len(run["test_cases"]),
            "status": run["status"]
        }
        for run_id, run in OPRO_RUNS.items()
    ]
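

if __name__ == "__main__":
    # Illustrative sketch of a full OPRO run lifecycle. The task description,
    # test case, candidate instructions, and scores are made up for this demo;
    # in the real application they would come from the optimizer and evaluator
    # that sit on top of this module.
    rid = create_opro_run(
        "Answer customer support questions politely and concisely.",
        test_cases=[("Where is my order?", "A short, polite status reply.")],
    )
    update_opro_iteration(
        rid,
        candidates=[
            "You are a helpful support agent.",
            "Reply in one short, polite sentence.",
        ],
        scores=[0.6, 0.8],
    )
    add_opro_evaluation(rid, "Be concise, polite, and state the order status.", 0.9)
    complete_opro_run(rid)

    best_instruction, best_score = get_opro_trajectory(rid)[0]
    print(f"best score {best_score:.2f}: {best_instruction}")
    print(list_opro_runs())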