56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
|
|
import re
|
||
|
|
import numpy as np
|
||
|
|
from sklearn.cluster import AgglomerativeClustering
|
||
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
||
|
|
import config
|
||
|
|
|
||
|
|
from .ollama_client import call_qwen
|
||
|
|
from .xinference_client import embed_texts
|
||
|
|
from .prompt_utils import refine_instruction, refine_instruction_with_history
|
||
|
|
|
||
|
|
# Leading list markers an LLM typically puts in front of each candidate:
# bullets ("-", "*", "•") and enumerators ("1.", "2)", "(3)").  An enumerator
# must not be immediately followed by another digit, so decimals such as
# "3.5 mm" are left alone, and bare leading numbers ("10 tips") are never
# treated as markers.
_LIST_MARKER_RE = re.compile(r"^(?:\s*(?:[-*•]|\(?\d+[.)](?!\d)))+\s*")


def parse_candidates(raw: str, min_length: int = 6) -> list:
    """Split raw LLM output into a de-duplicated list of candidate lines.

    Each line is stripped of surrounding whitespace and of any leading list
    markers (bullets and enumerators such as "1." or "2)").  Lines shorter
    than ``min_length`` characters after cleaning are dropped, and duplicates
    are removed while keeping first-occurrence order.

    Args:
        raw: Text returned by the model.  ``None`` or ``""`` yields ``[]``.
        min_length: Minimum length of a kept candidate (default 6).

    Returns:
        Cleaned, unique candidate strings in their original order.
    """
    if not raw:
        return []

    cleaned = []
    for line in re.split(r"\r?\n", raw):
        # Only real list markers are removed; content that merely starts with
        # a digit ("3D printing", "10 tips") is kept intact.
        text = _LIST_MARKER_RE.sub("", line.strip(), count=1).strip()
        if text and len(text) >= min_length:
            cleaned.append(text)

    # dict keeps insertion order, so this de-duplicates stably.
    return list(dict.fromkeys(cleaned))
|
||
|
|
|
||
|
|
def cluster_and_select(
    candidates: list,
    top_k=config.TOP_K,
    distance_threshold=config.CLUSTER_DISTANCE_THRESHOLD,
):
    """Pick up to ``top_k`` mutually distinct candidates via embedding clustering.

    Candidates are embedded with ``embed_texts`` and grouped by agglomerative
    clustering (cosine distance, average linkage) with no fixed cluster
    count; ``distance_threshold`` is the linkage distance above which
    clusters are not merged.  From every cluster, the member with the highest
    mean cosine similarity to that cluster's members is kept as its
    representative.  Representatives are returned in their original input
    order and truncated to ``top_k``, so fewer than ``top_k`` items come back
    when there are fewer clusters than that.

    Args:
        candidates: Candidate strings to choose from.
        top_k: Maximum number of candidates to return.
        distance_threshold: Cosine-distance threshold passed to
            ``AgglomerativeClustering``.

    Returns:
        A new list of selected candidates:

        - ``[]`` if ``candidates`` is empty.
        - A copy of ``candidates`` if it already has at most ``top_k`` items.
        - The first ``top_k`` candidates if the embeddings are missing or do
          not line up one-per-candidate.
    """
    if not candidates:
        return []
    if len(candidates) <= top_k:
        return list(candidates)

    vecs = embed_texts(candidates)
    # Check explicitly instead of ``not vecs``: the truthiness of a
    # multi-element NumPy array raises ValueError, and the concrete return
    # type of embed_texts is not visible here.
    if vecs is None or len(vecs) != len(candidates):
        return candidates[:top_k]
    X = np.asarray(vecs)
    if X.ndim != 2:
        # Not one flat embedding vector per candidate; clustering would fail.
        return candidates[:top_k]

    labels = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=distance_threshold,
        metric="cosine",
        linkage="average",
    ).fit_predict(X)

    representatives = []
    # np.unique returns labels sorted, matching the original iteration order.
    for label in np.unique(labels):
        members = np.flatnonzero(labels == label)
        # Medoid-like pick: the member with the highest average similarity to
        # its cluster (the diagonal self-similarity is included for everyone).
        centrality = cosine_similarity(X[members]).mean(axis=1)
        representatives.append(int(members[np.argmax(centrality)]))

    selected = [candidates[i] for i in sorted(representatives)]
    return selected[:top_k]
|
||
|
|
|
||
|
|
def generate_candidates(
    query: str,
    rejected=None,
    top_k=config.TOP_K,
    model_name=None,
    *,
    distance_threshold=config.CLUSTER_DISTANCE_THRESHOLD,
    temperature=0.9,
    max_tokens=1024,
):
    """Ask the LLM for candidate texts for ``query`` and keep a diverse subset.

    If ``rejected`` is non-empty the prompt is built with
    ``refine_instruction_with_history(query, rejected)``; otherwise it is
    built with ``refine_instruction(query)``.  The model reply is split into
    candidate lines by ``parse_candidates`` and reduced to at most ``top_k``
    items by ``cluster_and_select``.

    Args:
        query: The user's query/instruction to build the prompt from.
        rejected: Previously rejected candidates, or ``None``.  The expected
            element type is whatever ``refine_instruction_with_history``
            accepts (not visible here).
        top_k: Maximum number of candidates to return.
        model_name: Forwarded unchanged to ``call_qwen``.
        distance_threshold: Clustering threshold forwarded to
            ``cluster_and_select``.  Defaults to the same config value it used
            implicitly before.
        temperature: Sampling temperature for ``call_qwen`` (default 0.9).
        max_tokens: ``max_tokens`` for ``call_qwen`` (default 1024).

    Returns:
        The list of selected candidate strings.
    """
    if rejected:
        prompt = refine_instruction_with_history(query, rejected)
    else:
        prompt = refine_instruction(query)

    raw = call_qwen(
        prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        model_name=model_name,
    )
    all_candidates = parse_candidates(raw)
    return cluster_and_select(
        all_candidates,
        top_k=top_k,
        distance_threshold=distance_threshold,
    )
|