"""/me 当前用户信息 + 翻译配额 + 已读文章 + 按分类批量已读。""" from __future__ import annotations from datetime import datetime, timedelta, timezone from typing import Literal from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from sqlalchemy import and_, delete, func, not_, or_, select, text from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.core.deps import get_current_user from app.database import get_session from app.models.article import Article from app.models.article_read import ArticleRead from app.models.source import Source from app.models.user import User from app.redis_client import get_redis router = APIRouter(prefix="/me", tags=["me"]) class MeOut(BaseModel): id: int username: str email: str | None role: str display_name: str | None created_at: datetime class UsageOut(BaseModel): month: str used_chars: int quota_chars: int remaining_chars: int buffered_quota: int pct_used: float class ReadToggleResponse(BaseModel): article_id: int is_read: bool class ReadListResponse(BaseModel): """已读文章 ID 列表(给前端用,Feed 列表展示).""" article_ids: list[int] total: int @router.get("", response_model=MeOut) async def me(user: User = Depends(get_current_user)): return MeOut( id=user.id, username=user.username, email=user.email, role=user.role.value, display_name=user.display_name, created_at=user.created_at, ) @router.get("/usage", response_model=UsageOut) async def usage( user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), # noqa: ARG001 ): r = get_redis() now = datetime.now(timezone.utc) key = f"translation:month:{now:%Y%m}" used = int(await r.get(key) or 0) quota = settings.tencent_tmt_quota_month buffered = int(quota * (1 - settings.tencent_tmt_quota_buffer)) remaining = max(0, quota - used) return UsageOut( month=f"{now:%Y%m}", used_chars=used, quota_chars=quota, remaining_chars=remaining, buffered_quota=buffered, pct_used=round(used / quota * 100, 2) if quota else 0.0, ) # === 已读文章(per-user) === @router.post("/reads/{article_id}", response_model=ReadToggleResponse) async def mark_read( article_id: int, user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): """标记文章为已读(幂等,重复调用不报错)。""" # 先确认文章存在(不存在返 404,避免 FK 报错) exists = (await session.execute(select(Article.id).where(Article.id == article_id))).scalar_one_or_none() if not exists: raise HTTPException(status.HTTP_404_NOT_FOUND, "Article not found") # 用 INSERT ... ON CONFLICT DO NOTHING (PG 原生)— 等价 upsert 跳过 stmt = pg_insert(ArticleRead).values(user_id=user.id, article_id=article_id).on_conflict_do_nothing( index_elements=["user_id", "article_id"] ) await session.execute(stmt) await session.commit() return ReadToggleResponse(article_id=article_id, is_read=True) @router.delete("/reads/{article_id}", response_model=ReadToggleResponse) async def unmark_read( article_id: int, user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): """取消已读(幂等)。""" stmt = delete(ArticleRead).where( and_(ArticleRead.user_id == user.id, ArticleRead.article_id == article_id) ) res = await session.execute(stmt) await session.commit() return ReadToggleResponse(article_id=article_id, is_read=False) @router.get("/reads", response_model=ReadListResponse) async def list_reads( since: datetime | None = None, limit: int = 500, user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): """列出当前用户已读文章 ID(给 Feed 过滤 / 已读时间线用)。 - since: 只返回 read_at >= since 的(默认 7 天前,避免数据爆炸) - limit: 最多返回多少(默认 500) """ if since is None: since = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=7) stmt = ( select(ArticleRead.article_id) .where(ArticleRead.user_id == user.id, ArticleRead.read_at >= since) .order_by(ArticleRead.read_at.desc()) .limit(min(limit, 2000)) ) rows = (await session.execute(stmt)).all() ids = [r[0] for r in rows] return ReadListResponse(article_ids=ids, total=len(ids)) # === 按分类批量已读(per-user) === # 用例:用户在 Feed 标了一条《xxx》为已读,该文章 category 含 "美国" "社会" 两个 tag, # 前端会同时查这两个分类下"24 小时未读数"展示提示条,点头部的"全部已读"调本接口。 # 设计要点: # - category 用 ilike 模糊匹配(Article.category 是逗号分隔字符串 '美国,社会,紧急') # - 默认遵守当前 Feed 过滤(关键词/源)— scope=filtered_unread,避免误杀 # - 时间窗默认 24h(防"老新闻刷不完"的提示) # - 批量插入用 ON CONFLICT DO NOTHING,幂等 # - 返回 article_ids 给前端做乐观滑出动画 def _unread_subquery(user_id: int): """未读文章 id 子查询(per-user)— 复用 articles.py 的 NOT EXISTS 模式。""" return ( select(ArticleRead.article_id) .where(ArticleRead.user_id == user_id, ArticleRead.article_id == Article.id) .exists() ) def _escape_like(term: str) -> str: """转义 ilike 的元字符,避免 category 含 % / _ 时被当通配符。""" return term.replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_") class CategoryReadRequest(BaseModel): category: str = Field(..., min_length=1, max_length=32, description="要批量已读的分类 tag") scope: Literal["all_unread", "filtered_unread"] = Field( default="filtered_unread", description="过滤范围:filtered_unread=遵守 sources/q;all_unread=忽略", ) window_hours: int = Field( default=24, ge=1, le=168, description="只标 published_at 在最近 N 小时内的未读" ) sources: list[str] | None = Field(default=None, description="源 slug 列表,过滤范围") q: str | None = Field(default=None, description="关键词过滤(标题/正文模糊)") class CategoryReadResponse(BaseModel): category: str matched: int # 命中的未读数 marked: int # 实际新标已读数(去重后) article_ids: list[int] # 给前端做滑出动画 class CategoryCountItem(BaseModel): category: str unread_count: int window_hours: int def _build_category_filter( category: str, window_hours: int, user_id: int, sources: list[str] | None = None, q: str | None = None, ): """构造"分类 + 时间窗 + 未读"三合一 filter。 - category 用 ilike 模糊匹配(逗号分隔串中含此 tag 即可) - window_hours 通过 published_at 过滤 - 用 NOT EXISTS 排除当前用户已读 - sources / q 可选,跟 articles.py 列表查询口径一致 """ pattern = f"%{_escape_like(category)}%" # 用 text() 拼 interval,避免 make_interval 的 PG 函数参数绑定在某些 # SQLAlchemy 版本下的兼容性问题。window_hours 是 int,只来自我们自己的 # endpoint,不是用户原始字符串,安全。 interval_sql = text(f"interval '{int(window_hours)} hours'") filters = [ Article.category.ilike(pattern, escape="\\"), Article.published_at >= func.now() - interval_sql, not_(_unread_subquery(user_id)), ] if sources: slugs = [s.strip() for s in sources if s.strip()] if slugs: filters.append(Source.slug.in_(slugs)) if q and q.strip(): like = f"%{q.strip()}%" filters.append( or_( Article.title.ilike(like), Article.body_text.ilike(like), Article.title_zh.ilike(like), Article.body_zh_text.ilike(like), Article.summary_zh.ilike(like), ) ) return filters @router.post("/reads/by-category", response_model=CategoryReadResponse) async def mark_category_read( body: CategoryReadRequest, user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): """按分类批量已读。 - 默认 scope=filtered_unread:只在当前 Feed 过滤范围内标 - 时间窗默认 24h:防"老新闻刷不完" - 幂等:已读的不重复处理 """ cat = body.category.strip() if not cat: raise HTTPException(status.HTTP_400_BAD_REQUEST, "category 不能为空") # scope=all_unread 时,无视 sources/q(传了也忽略,语义更清晰) if body.scope == "all_unread": eff_sources: list[str] | None = None eff_q: str | None = None else: eff_sources = body.sources eff_q = body.q filters = _build_category_filter( category=cat, window_hours=body.window_hours, user_id=user.id, sources=eff_sources, q=eff_q, ) # 1. 查命中 id 列表(返回给前端) id_stmt = select(Article.id) # sources 过滤需要 join if eff_sources: id_stmt = id_stmt.join(Source, Source.id == Article.source_id) for f in filters: id_stmt = id_stmt.where(f) article_ids = [r[0] for r in (await session.execute(id_stmt)).all()] if not article_ids: return CategoryReadResponse(category=cat, matched=0, marked=0, article_ids=[]) # 2. 批量插入 article_reads(ON CONFLICT DO NOTHING,幂等) # 分批:防御性,避免单次 VALUES 太多;500 是经验值,PG 完全能吃更大 BATCH = 500 marked_total = 0 for i in range(0, len(article_ids), BATCH): chunk = article_ids[i : i + BATCH] rows = [{"user_id": user.id, "article_id": aid} for aid in chunk] stmt = pg_insert(ArticleRead).values(rows).on_conflict_do_nothing( index_elements=["user_id", "article_id"] ) # PG ON CONFLICT 不直接告诉哪些是"真插入的",通过 result() 或者信任总数; # 这里简单信任 chunk 大小,marked 返回 attempt 数(去重由 PK 保证) await session.execute(stmt) marked_total += len(chunk) await session.commit() return CategoryReadResponse( category=cat, matched=len(article_ids), marked=marked_total, article_ids=article_ids, ) @router.get("/reads/category-count", response_model=list[CategoryCountItem]) async def count_unread_by_categories( categories: str = Query(..., description="逗号分隔的分类列表,例如 '美国,社会'"), window_hours: int = Query(default=24, ge=1, le=168), sources: str | None = Query(default=None, description="逗号分隔源 slug"), q: str | None = Query(default=None, description="关键词过滤"), user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): """批量查多个分类的未读数(给前端提示条用,一次拿全,省 round-trip)。 - categories 必填,逗号分隔 - 每个分类独立 count(单次请求内是 N 次查询,数据量小,够用) - 返回的 unread_count = 0 的分类前端不展示提示行 """ cat_list = [c.strip() for c in categories.split(",") if c.strip()] if not cat_list: return [] src_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None eff_q = q.strip() if q and q.strip() else None results: list[CategoryCountItem] = [] for cat in cat_list[:10]: # 防御性:一次最多 10 个分类 filters = _build_category_filter( category=cat, window_hours=window_hours, user_id=user.id, sources=src_list, q=eff_q, ) stmt = select(func.count(Article.id)) if src_list: stmt = stmt.join(Source, Source.id == Article.source_id) for f in filters: stmt = stmt.where(f) n = (await session.execute(stmt)).scalar_one() if n > 0: results.append( CategoryCountItem(category=cat, unread_count=int(n), window_hours=window_hours) ) return results