opro_demo/test_opro_api.py

#!/usr/bin/env python3
"""
Test script for TRUE OPRO API endpoints.
This script tests the complete OPRO workflow:
1. Create OPRO run
2. Generate initial candidates
3. Evaluate candidates
4. Generate optimized candidates
5. View results
Usage:
python test_opro_api.py
"""
import requests
import json
import time
BASE_URL = "http://127.0.0.1:8010"
def print_section(title):
"""Print a section header."""
print("\n" + "=" * 60)
print(f" {title}")
print("=" * 60)

def test_opro_workflow():
    """Test the complete OPRO workflow."""
    print_section("1. Create OPRO Run")

    # Create a new OPRO run
    # Task description (Chinese): translate the user's Chinese input into English, accurately and naturally
    create_req = {
        "task_description": "将用户输入的中文翻译成英文,要求准确自然",
        "test_cases": [
            {"input": "你好", "expected_output": "Hello"},
            {"input": "谢谢", "expected_output": "Thank you"},
            {"input": "早上好", "expected_output": "Good morning"},
            {"input": "晚安", "expected_output": "Good night"},
            {"input": "再见", "expected_output": "Goodbye"}
        ]
    }
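    # Assumed response shape for POST /opro/create (inferred from the keys read below; not an authoritative schema):
    #   {"success": true, "data": {"run_id": "...", "task_description": "...", "num_test_cases": 5}}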
    response = requests.post(f"{BASE_URL}/opro/create", json=create_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to create OPRO run: {result}")
        return

    run_id = result["data"]["run_id"]
    print(f"✅ Created OPRO run: {run_id}")
    print(f" Task: {result['data']['task_description']}")
    print(f" Test cases: {result['data']['num_test_cases']}")

    # ========================================================================
    print_section("2. Generate Initial Candidates")

    iterate_req = {"run_id": run_id, "top_k": 5}
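    # Assumed response shape for POST /opro/iterate (inferred from the keys read below; not an authoritative schema):
    #   {"success": true, "data": {"candidates": ["<instruction text>", ...]}}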
    response = requests.post(f"{BASE_URL}/opro/iterate", json=iterate_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to generate candidates: {result}")
        return

    candidates = result["data"]["candidates"]
    print(f"✅ Generated {len(candidates)} initial candidates:")
    for i, candidate in enumerate(candidates, 1):
        print(f"\n [{i}] {candidate[:100]}...")

    # ========================================================================
    print_section("3. Evaluate Candidates")

    scores = []
    for i, candidate in enumerate(candidates, 1):
        print(f"\n Evaluating candidate {i}/{len(candidates)}...")

        eval_req = {
            "run_id": run_id,
            "instruction": candidate
        }
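        # Assumed response shape for POST /opro/evaluate (inferred from the keys read below; not an authoritative schema):
        #   {"success": true, "data": {"score": 0.8, "is_new_best": false}}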
        response = requests.post(f"{BASE_URL}/opro/evaluate", json=eval_req)
        result = response.json()

        if result.get("success"):
            score = result["data"]["score"]
            scores.append(score)
            is_best = "🏆" if result["data"]["is_new_best"] else ""
            print(f" ✅ Score: {score:.4f} {is_best}")
        else:
            print(f" ❌ Evaluation failed: {result}")

        time.sleep(0.5)  # Small delay to avoid overwhelming the API

    if scores:  # Guard against division by zero when every evaluation failed
        print(f"\n Average score: {sum(scores)/len(scores):.4f}")
        print(f" Best score: {max(scores):.4f}")
    else:
        print("\n No candidates were evaluated successfully.")
    # ========================================================================
    print_section("4. Generate Optimized Candidates (Iteration 2)")

    print(" Generating candidates based on performance trajectory...")
    iterate_req = {"run_id": run_id, "top_k": 5}
    response = requests.post(f"{BASE_URL}/opro/iterate", json=iterate_req)
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to generate optimized candidates: {result}")
        return

    optimized_candidates = result["data"]["candidates"]
    print(f"✅ Generated {len(optimized_candidates)} optimized candidates:")
    for i, candidate in enumerate(optimized_candidates, 1):
        print(f"\n [{i}] {candidate[:100]}...")

    # ========================================================================
    print_section("5. View Run Details")
    response = requests.get(f"{BASE_URL}/opro/run/{run_id}")
    result = response.json()

    if not result.get("success"):
        print(f"❌ Failed to get run details: {result}")
        return

    data = result["data"]
    print("✅ OPRO Run Details:")
    print(f" Run ID: {data['run_id']}")
    print(f" Task: {data['task_description']}")
    print(f" Iteration: {data['iteration']}")
    print(f" Status: {data['status']}")
    print(f" Best Score: {data['best_score']:.4f}")
    print("\n Best Instruction:")
    print(f" {data['best_instruction'][:200]}...")
    print("\n Top 5 Trajectory:")
    for i, item in enumerate(data['trajectory'][:5], 1):
        print(f" [{i}] Score: {item['score']:.4f}")
        print(f" {item['instruction'][:80]}...")

    # ========================================================================
    print_section("6. List All Runs")
    response = requests.get(f"{BASE_URL}/opro/runs")
    result = response.json()

    if result.get("success"):
        runs = result["data"]["runs"]
        print(f"✅ Total OPRO runs: {result['data']['total']}")
        for run in runs:
            print(f"\n Run: {run['run_id']}")
            print(f" Task: {run['task_description'][:50]}...")
            print(f" Iteration: {run['iteration']}, Best Score: {run['best_score']:.4f}")

    print_section("✅ OPRO Workflow Test Complete!")
    print(f"\nRun ID: {run_id}")
    print("You can view details at:")
    print(f" {BASE_URL}/opro/run/{run_id}")

if __name__ == "__main__":
    print("=" * 60)
    print(" TRUE OPRO API Test")
    print("=" * 60)
    print(f"\nBase URL: {BASE_URL}")
    print("\nMake sure the API server is running:")
    print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
    print("\nStarting test in 3 seconds...")
    time.sleep(3)

    try:
        test_opro_workflow()
    except requests.exceptions.ConnectionError:
        print("\n❌ ERROR: Could not connect to API server")
        print("Please start the server first:")
        print(" uvicorn _qwen_xinference_demo.api:app --host 127.0.0.1 --port 8010")
    except Exception as e:
        print(f"\n❌ ERROR: {e}")
        import traceback
        traceback.print_exc()