原始代码
This commit is contained in:
29
_qwen_xinference_demo/opro/xinference_client.py
Normal file
29
_qwen_xinference_demo/opro/xinference_client.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import requests
|
||||
from typing import List
|
||||
import config
|
||||
|
||||
# Primary embedding endpoint, taken verbatim from the project configuration.
XINFERENCE_EMBED_URL = config.XINFERENCE_EMBED_URL
# Fallback embedding endpoint on the Ollama host. A trailing slash on the
# configured host is stripped so the joined URL never contains "//api/...".
OLLAMA_EMBED_URL = config.OLLAMA_HOST.rstrip("/") + "/api/embeddings"
|
||||
|
||||
def _post_for_json(url, payload, timeout):
    """POST *payload* as JSON to *url* and return the decoded JSON reply.

    Raises whatever ``requests`` raises for transport errors, non-2xx
    statuses (via ``raise_for_status``) or an undecodable body.
    """
    reply = requests.post(url, json=payload, timeout=timeout)
    reply.raise_for_status()
    return reply.json()


def _extract_batch_embeddings(body):
    """Return the vectors found in a batch-style embedding reply.

    Two shapes are recognised: ``{"data": [{"embedding": [...]}, ...]}`` and
    ``{"embeddings": [[...], ...]}``. Any other reply yields ``[]``.
    """
    if not isinstance(body, dict):
        return []
    if "data" in body:
        return [item.get("embedding", []) for item in body["data"]]
    return body.get("embeddings", []) or []


def _embed_via_ollama(texts):
    """Fetch embeddings for *texts* from ``OLLAMA_EMBED_URL``.

    A single batch request (``"input": texts``) is tried first. If its reply
    carries no vectors, the texts are re-sent one at a time using the
    ``"prompt"`` field and the singular ``"embedding"`` reply key, which is
    the request/response shape Ollama documents for its legacy
    ``/api/embeddings`` endpoint.
    """
    model = config.DEFAULT_EMBED_MODEL
    batch_reply = _post_for_json(
        OLLAMA_EMBED_URL, {"model": model, "input": texts}, timeout=15
    )
    vectors = _extract_batch_embeddings(batch_reply)
    if vectors:
        return vectors

    vectors = []
    for text in texts:
        single_reply = _post_for_json(
            OLLAMA_EMBED_URL, {"model": model, "prompt": text}, timeout=15
        )
        vector = (
            single_reply.get("embedding")
            if isinstance(single_reply, dict)
            else None
        )
        if not vector:
            # A partial result could not be lined up with *texts*, so give up.
            return []
        vectors.append(vector)
    return vectors


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Return one embedding vector per entry of *texts*.

    The request is first sent to ``XINFERENCE_EMBED_URL`` (payload
    ``{"inputs": texts}``, 10 s timeout). If that call fails or its reply
    has no non-empty ``"embeddings"`` entry, the Ollama endpoint is tried
    (15 s timeout per request).

    Args:
        texts: Strings to embed. An empty list returns ``[]`` without any
            network traffic.

    Returns:
        The list of embedding vectors, or ``[]`` when neither backend
        produced a result. Failures are logged as warnings and never raised
        to the caller.
    """
    import logging  # function-scope import keeps this block self-contained

    log = logging.getLogger(__name__)

    if not texts:
        return []

    try:
        primary_reply = _post_for_json(
            XINFERENCE_EMBED_URL, {"inputs": texts}, timeout=10
        )
        # Guard against non-dict replies instead of relying on the blanket
        # except below to hide an AttributeError.
        vectors = (
            primary_reply.get("embeddings", [])
            if isinstance(primary_reply, dict)
            else []
        )
        if vectors:
            return vectors
        log.warning("Xinference reply had no embeddings; falling back to Ollama")
    except Exception as exc:  # best-effort: any failure triggers the fallback
        log.warning(
            "Xinference embedding request failed; falling back to Ollama: %s", exc
        )

    try:
        return _embed_via_ollama(texts)
    except Exception as exc:  # best-effort: callers expect [] rather than a raise
        log.warning("Ollama embedding request failed: %s", exc)
        return []
|
||||
Reference in New Issue
Block a user