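"""LLM client for Agnes (or any OpenAI-compatible endpoint).

Design:
- Holds two semaphores internally, one for chat and one for image requests
  (concurrency of 1 each), so the two kinds of call never block each other.
- After each call, ``await asyncio.sleep(interval_sec)`` throttles the rate.
- A failed request is retried once; if it fails again the exception
  propagates so the caller can mark ``status=failed``.
- Uses ``httpx.AsyncClient`` with a 60 s default timeout (image generation
  passes 120 s).
"""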
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from app.config import settings as app_settings
|
|
|
|
# Module-level logger; the name is fixed to "news.llm.client" rather than
# derived from __name__.
logger = logging.getLogger("news.llm.client")
|
|
|
|
|
|
class LlmClient:
    """Single client that every LLM call goes through.

    Chat and image requests are serialised independently: each kind has its
    own ``asyncio.Semaphore(1)``, so at most one chat request and one image
    request are in flight at a time and the two kinds do not block each
    other. After a successful request the semaphore is held for a further
    ``interval_sec`` seconds, which throttles the call rate.
    """

    # Non-5xx statuses that are still worth a retry: 408 (request timeout)
    # and 429 (rate limited) can succeed when the same request is resent.
    # Every other 4xx (bad key, bad payload, ...) would fail identically.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(
        self,
        base_url: str | None = None,
        api_key: str | None = None,
        chat_model: str | None = None,
        image_model: str | None = None,
        interval_sec: float | None = None,
    ):
        """Store connection settings, falling back to ``app_settings``.

        Args:
            base_url: API root; trailing slashes are stripped. Falls back to
                ``app_settings.agnes_base_url`` when ``None`` or empty.
            api_key: Bearer token. Falls back to
                ``app_settings.agnes_api_key`` when ``None`` or empty.
            chat_model: Default model for :meth:`chat`. Falls back to
                ``app_settings.agnes_chat_model`` when ``None`` or empty.
            image_model: Default model for :meth:`generate_image`. Falls back
                to ``app_settings.agnes_image_model`` when ``None`` or empty.
            interval_sec: Pause after each successful request. Only ``None``
                triggers the ``app_settings.llm_interval_sec`` fallback, so
                an explicit ``0`` is honoured.
        """
        self.base_url = (base_url or app_settings.agnes_base_url).rstrip("/")
        self.api_key = api_key or app_settings.agnes_api_key
        self.chat_model = chat_model or app_settings.agnes_chat_model
        self.image_model = image_model or app_settings.agnes_image_model
        self.interval_sec = (
            interval_sec if interval_sec is not None else app_settings.llm_interval_sec
        )
        # One serialising semaphore per request kind (chat / image).
        # NOTE(review): on Python < 3.10 asyncio primitives bind to the event
        # loop that is current at construction time; confirm the deployed
        # Python version if instances are built outside the running loop.
        self._chat_sem = asyncio.Semaphore(1)
        self._image_sem = asyncio.Semaphore(1)

    def is_configured(self) -> bool:
        """Return True when a non-empty API key is available."""
        return bool(self.api_key)

    async def chat(
        self,
        system: str,
        user: str,
        *,
        temperature: float = 0.4,
        max_tokens: int = 1500,
        model: str | None = None,
    ) -> str:
        """Call ``chat/completions`` and return the assistant text, stripped.

        Args:
            system: Content of the system message.
            user: Content of the user message.
            temperature: Sampling temperature sent in the payload.
            max_tokens: ``max_tokens`` value sent in the payload.
            model: Model name; defaults to ``self.chat_model``.

        Returns:
            The first choice's message content with surrounding whitespace
            removed.

        Raises:
            RuntimeError: If no API key is configured, if the HTTP request
                fails (see :meth:`_post_with_retry`), or if the response
                carries no text content.
            KeyError, IndexError, TypeError: If the response body does not
                have the ``choices[0].message.content`` shape.
        """
        if not self.is_configured():
            raise RuntimeError("AGNES_API_KEY 未配置")
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": model or self.chat_model,
            "messages": [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        async with self._chat_sem:
            res = await self._post_with_retry(url, payload)
            # Throttle while still holding the semaphore so the next chat
            # call cannot start before the interval has elapsed.
            await asyncio.sleep(self.interval_sec)
        content = res["choices"][0]["message"]["content"]
        if not isinstance(content, str):
            # A null / non-text content would otherwise surface as an opaque
            # AttributeError from ``.strip()``.
            raise RuntimeError(f"LLM 响应缺少文本内容: {content!r}")
        return content.strip()

    async def classify_json(
        self,
        system: str,
        user: str,
        *,
        max_tokens: int = 200,
    ) -> dict[str, Any]:
        """Call :meth:`chat` and parse the reply as a JSON object.

        The reply may be wrapped in a Markdown code fence
        (```` ```json ... ``` ````); fence lines are dropped before parsing.

        Args:
            system: Content of the system message.
            user: Content of the user message.
            max_tokens: ``max_tokens`` value forwarded to :meth:`chat`.

        Returns:
            The parsed JSON object. An empty dict is returned (and a warning
            logged) when the reply is not valid JSON or is valid JSON that is
            not an object.

        Raises:
            Whatever :meth:`chat` raises; only parsing problems are absorbed.
        """
        import json

        text = await self.chat(system, user, temperature=0.2, max_tokens=max_tokens)
        text = text.strip()
        if text.startswith("```"):
            # Drop every fence line (opening, with or without a language tag,
            # and closing) and keep what is in between.
            kept = [
                line for line in text.split("\n") if not line.strip().startswith("```")
            ]
            text = "\n".join(kept).strip()
        try:
            parsed = json.loads(text)
        except (ValueError, RecursionError) as e:
            logger.warning("classify_json 解析失败: %s; raw=%r", e, text[:200])
            return {}
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (list, string, number, ...):
            # callers are promised a dict, so apply the same empty fallback.
            logger.warning("classify_json 返回的不是 JSON 对象: raw=%r", text[:200])
            return {}
        return parsed

    async def generate_image(
        self,
        prompt: str,
        *,
        size: str = "1024x768",
        model: str | None = None,
    ) -> str:
        """Call ``images/generations`` and return the image URL.

        Args:
            prompt: Text prompt sent in the payload.
            size: ``size`` value sent in the payload.
            model: Model name; defaults to ``self.image_model``.

        Returns:
            The ``url`` field of the first entry in the response's ``data``.

        Raises:
            RuntimeError: If no API key is configured or the HTTP request
                fails (see :meth:`_post_with_retry`).
            KeyError, IndexError, TypeError: If the response body does not
                have the ``data[0].url`` shape.
        """
        if not self.is_configured():
            raise RuntimeError("AGNES_API_KEY 未配置")
        url = f"{self.base_url}/images/generations"
        payload = {
            "model": model or self.image_model,
            "prompt": prompt,
            "size": size,
        }
        async with self._image_sem:
            # Image requests get a longer timeout than the 60 s default.
            res = await self._post_with_retry(url, payload, timeout=120)
            await asyncio.sleep(self.interval_sec)
        return res["data"][0]["url"]

    async def _post_with_retry(
        self, url: str, payload: dict, *, timeout: float = 60.0, retries: int = 1
    ) -> dict:
        """POST ``payload`` as JSON and return the decoded JSON response.

        A failed attempt is retried after ``2 ** attempt`` seconds when the
        failure may be transient: an exception raised while sending the
        request or decoding the body, a 5xx status, or a 408/429 status.
        Other non-200 statuses fail immediately, because resending the same
        request would produce the same answer.

        Args:
            url: Full endpoint URL.
            payload: JSON-serialisable request body.
            timeout: Timeout passed to ``httpx.AsyncClient``.
            retries: Number of extra attempts after the first one. Negative
                values are treated as ``0`` so a request is always made.

        Returns:
            The response body decoded with ``Response.json()``.

        Raises:
            RuntimeError: For a non-200 status on the final attempt (or on a
                non-retryable status).
            Exception: The last exception raised by httpx or JSON decoding
                when every attempt failed that way.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # At least one attempt always runs, so ``last_exc`` is guaranteed to
        # be set whenever the loop finishes without returning.
        attempts = max(retries, 0) + 1
        last_exc: Exception | None = None
        for attempt in range(attempts):
            retryable = True
            try:
                async with httpx.AsyncClient(timeout=timeout) as http:
                    r = await http.post(url, json=payload, headers=headers)
                    if r.status_code == 200:
                        return r.json()
                    if r.status_code >= 500:
                        last_exc = RuntimeError(
                            f"LLM 5xx: {r.status_code} {r.text[:200]}"
                        )
                    else:
                        last_exc = RuntimeError(f"LLM {r.status_code}: {r.text[:300]}")
                        retryable = r.status_code in self._RETRYABLE_CLIENT_STATUSES
            except Exception as e:
                # Transport errors, timeouts and undecodable bodies: keep the
                # best-effort behaviour of retrying them.
                last_exc = e
            if not retryable or attempt + 1 >= attempts:
                break
            wait = 2 ** attempt
            logger.warning("LLM 调用失败,%.1fs 后重试: %s", wait, last_exc)
            await asyncio.sleep(wait)
        raise last_exc
|
|
|
|
|
|
# Global singleton (reads its configuration from the environment; initialised at startup)
|
|
# Built with no arguments, so every setting falls back to app_settings
# (see LlmClient.__init__); this runs when the module is imported.
client = LlmClient()
|