原始代码
This commit is contained in:
52
user_prompt_optimizer.py
Normal file
52
user_prompt_optimizer.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import re
|
||||
import numpy as np
|
||||
from sklearn.cluster import AgglomerativeClustering
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
from opro.ollama_client import call_qwen
|
||||
from opro.xinference_client import embed_texts
|
||||
from opro.prompt_utils import refine_instruction, refine_instruction_with_history
|
||||
|
||||
def parse_candidates(raw: str) -> list:
|
||||
lines = [l.strip() for l in re.split(r'\r?\n', raw) if l.strip()]
|
||||
cleaned = []
|
||||
for l in lines:
|
||||
l = re.sub(r'^[\-\*\d\.\)\s]+', '', l).strip()
|
||||
if len(l) >= 6:
|
||||
cleaned.append(l)
|
||||
return list(dict.fromkeys(cleaned))
|
||||
|
||||
def cluster_and_select(candidates: list, top_k=5, distance_threshold=0.15):
|
||||
if not candidates:
|
||||
return []
|
||||
vecs = embed_texts(candidates)
|
||||
X = np.array(vecs)
|
||||
if len(candidates) <= top_k:
|
||||
return candidates
|
||||
|
||||
clustering = AgglomerativeClustering(n_clusters=None,
|
||||
distance_threshold=distance_threshold,
|
||||
metric="cosine",
|
||||
linkage="average")
|
||||
labels = clustering.fit_predict(X)
|
||||
|
||||
selected_idx = []
|
||||
for label in sorted(set(labels)):
|
||||
idxs = [i for i,l in enumerate(labels) if l == label]
|
||||
sims = cosine_similarity(X[idxs]).mean(axis=1)
|
||||
rep = idxs[int(np.argmax(sims))]
|
||||
selected_idx.append(rep)
|
||||
|
||||
selected = [candidates[i] for i in sorted(selected_idx)]
|
||||
return selected[:top_k]
|
||||
|
||||
def generate_candidates(query: str, rejected=None, top_k=5):
|
||||
rejected = rejected or []
|
||||
if rejected:
|
||||
prompt = refine_instruction_with_history(query, rejected)
|
||||
else:
|
||||
prompt = refine_instruction(query)
|
||||
|
||||
raw = call_qwen(prompt, temperature=0.9, max_tokens=512)
|
||||
all_candidates = parse_candidates(raw)
|
||||
return cluster_and_select(all_candidates, top_k=top_k)
|
||||
Reference in New Issue
Block a user