# Source-viewer metadata (captured with the paste): 30 lines, 948 B, Python.
|
|
import logging
from typing import List

import requests

import config
|
||
|
|
|
||
|
|
# Primary embedding endpoint, taken verbatim from the project config.
XINFERENCE_EMBED_URL = config.XINFERENCE_EMBED_URL

# Fallback endpoint: the embeddings route appended to the configured Ollama
# host. NOTE(review): if config.OLLAMA_HOST ends with "/", this produces
# "...//api/embeddings" — confirm the expected config format.
_OLLAMA_EMBED_PATH = "/api/embeddings"
OLLAMA_EMBED_URL = config.OLLAMA_HOST + _OLLAMA_EMBED_PATH
|
||
|
|
|
||
|
|
_logger = logging.getLogger(__name__)


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Return one embedding vector per text, trying Xinference, then Ollama.

    The primary backend (``XINFERENCE_EMBED_URL``) is asked for the whole
    batch in one request. If that request fails, or does not yield exactly
    one non-empty vector per text, the Ollama endpoint (``OLLAMA_EMBED_URL``)
    is used as a fallback.

    Args:
        texts: Strings to embed.

    Returns:
        Embedding vectors aligned index-for-index with *texts*, or ``[]``
        when *texts* is empty or neither backend produced a complete result.
        Partial results are never returned, so a vector can never end up
        paired with the wrong text.
    """
    if not texts:
        # Nothing to embed; skip both network round-trips.
        return []

    vectors = _embed_via_xinference(texts)
    if vectors:
        return vectors
    return _embed_via_ollama(texts)


def _embed_via_xinference(texts: List[str]) -> List[List[float]]:
    """Batch-embed *texts* with the primary endpoint; ``[]`` on any failure."""
    try:
        resp = requests.post(
            XINFERENCE_EMBED_URL, json={"inputs": texts}, timeout=10
        )
        resp.raise_for_status()
        vectors = _extract_embeddings(resp.json())
    except (requests.RequestException, ValueError) as exc:
        # ValueError also covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException subclass.
        _logger.warning("Xinference embedding request failed: %s", exc)
        return []
    if not _is_complete(vectors, len(texts)):
        _logger.warning(
            "Xinference response lacked one non-empty embedding per text "
            "(got %d for %d texts); falling back",
            len(vectors),
            len(texts),
        )
        return []
    return vectors


def _embed_via_ollama(texts: List[str]) -> List[List[float]]:
    """Embed *texts* one at a time via Ollama's ``/api/embeddings`` route.

    Per the Ollama API docs, that route takes a single ``prompt`` string per
    request and replies with ``{"embedding": [...]}`` (it does not accept an
    ``input`` list), so the batch is sent as one request per text over a
    shared session. Returns ``[]`` on any failure.
    """
    vectors: List[List[float]] = []
    try:
        # One Session reuses the TCP connection across the per-text calls.
        with requests.Session() as session:
            for text in texts:
                resp = session.post(
                    OLLAMA_EMBED_URL,
                    json={"model": config.DEFAULT_EMBED_MODEL, "prompt": text},
                    timeout=15,
                )
                resp.raise_for_status()
                vectors.extend(_extract_embeddings(resp.json()))
    except (requests.RequestException, ValueError) as exc:
        _logger.warning("Ollama embedding request failed: %s", exc)
        return []
    if not _is_complete(vectors, len(texts)):
        _logger.warning(
            "Ollama responses lacked one non-empty embedding per text "
            "(got %d for %d texts)",
            len(vectors),
            len(texts),
        )
        return []
    return vectors


def _extract_embeddings(data: object) -> List[List[float]]:
    """Pull embedding vectors out of a decoded JSON body without raising.

    Recognised shapes: ``{"embeddings": [[...], ...]}``, OpenAI-style
    ``{"data": [{"embedding": [...]}, ...]}``, single-vector
    ``{"embedding": [...]}``, and a bare ``[[...], ...]`` list. Anything else
    yields ``[]``.
    """
    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        return []
    if isinstance(data.get("embeddings"), list):
        return data["embeddings"]
    if isinstance(data.get("data"), list):
        # Non-dict items become [] so the completeness check rejects them.
        return [
            item.get("embedding", []) if isinstance(item, dict) else []
            for item in data["data"]
        ]
    if isinstance(data.get("embedding"), list):
        return [data["embedding"]]
    return []


def _is_complete(vectors: List[List[float]], expected: int) -> bool:
    """Return True when there is exactly one non-empty list vector per text."""
    return len(vectors) == expected and all(
        isinstance(vec, list) and vec for vec in vectors
    )
|