From 48f7a34ae3298a98d09d65827c07f5c23735badb Mon Sep 17 00:00:00 2001 From: xiaji Date: Mon, 15 Jun 2026 20:48:32 +0800 Subject: [PATCH] =?UTF-8?q?fix(me):=20=E6=8A=8A=20/reads/by-category=20?= =?UTF-8?q?=E8=B7=AF=E7=94=B1=E5=A3=B0=E6=98=8E=E5=9C=A8=20/reads/{article?= =?UTF-8?q?=5Fid}=20=E4=B9=8B=E5=89=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/me.py | 227 +++++++++++++++++++++--------------------- 1 file changed, 116 insertions(+), 111 deletions(-) diff --git a/backend/app/api/me.py b/backend/app/api/me.py index 2978e37..c0c7c99 100644 --- a/backend/app/api/me.py +++ b/backend/app/api/me.py @@ -86,6 +86,122 @@ async def usage( # === 已读文章(per-user) === +# ⚠️ 路由顺序重要:`/reads/by-category` 和 `/reads/category-count` 必须在 +# `/reads/{article_id}` **之前**声明,否则 FastAPI 会把 "by-category" / "category-count" +# 当成 article_id 匹配到 mark_read,然后 int 转换失败返 422。 +# FastAPI 路由匹配是"先注册先匹配"(按装饰器执行顺序),不是按特异度。 +@router.post("/reads/by-category", response_model=CategoryReadResponse) +async def mark_category_read( + body: CategoryReadRequest, + user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +): + """按分类批量已读。 + + - 默认 scope=filtered_unread:只在当前 Feed 过滤范围内标 + - 时间窗默认 24h:防"老新闻刷不完" + - 幂等:已读的不重复处理 + """ + cat = body.category.strip() + if not cat: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "category 不能为空") + + # scope=all_unread 时,无视 sources/q(传了也忽略,语义更清晰) + if body.scope == "all_unread": + eff_sources: list[str] | None = None + eff_q: str | None = None + else: + eff_sources = body.sources + eff_q = body.q + + filters = _build_category_filter( + category=cat, + window_hours=body.window_hours, + user_id=user.id, + sources=eff_sources, + q=eff_q, + ) + + # 1. 查命中 id 列表(返回给前端) + id_stmt = select(Article.id) + # sources 过滤需要 join + if eff_sources: + id_stmt = id_stmt.join(Source, Source.id == Article.source_id) + for f in filters: + id_stmt = id_stmt.where(f) + article_ids = [r[0] for r in (await session.execute(id_stmt)).all()] + + if not article_ids: + return CategoryReadResponse(category=cat, matched=0, marked=0, article_ids=[]) + + # 2. 批量插入 article_reads(ON CONFLICT DO NOTHING,幂等) + # 分批:防御性,避免单次 VALUES 太多;500 是经验值,PG 完全能吃更大 + BATCH = 500 + marked_total = 0 + for i in range(0, len(article_ids), BATCH): + chunk = article_ids[i : i + BATCH] + rows = [{"user_id": user.id, "article_id": aid} for aid in chunk] + stmt = pg_insert(ArticleRead).values(rows).on_conflict_do_nothing( + index_elements=["user_id", "article_id"] + ) + # PG ON CONFLICT 不直接告诉哪些是"真插入的",通过 result() 或者信任总数; + # 这里简单信任 chunk 大小,marked 返回 attempt 数(去重由 PK 保证) + await session.execute(stmt) + marked_total += len(chunk) + await session.commit() + + return CategoryReadResponse( + category=cat, + matched=len(article_ids), + marked=marked_total, + article_ids=article_ids, + ) + + +@router.get("/reads/category-count", response_model=list[CategoryCountItem]) +async def count_unread_by_categories( + categories: str = Query(..., description="逗号分隔的分类列表,例如 '美国,社会'"), + window_hours: int = Query(default=24, ge=1, le=168), + sources: str | None = Query(default=None, description="逗号分隔源 slug"), + q: str | None = Query(default=None, description="关键词过滤"), + user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +): + """批量查多个分类的未读数(给前端提示条用,一次拿全,省 round-trip)。 + + - categories 必填,逗号分隔 + - 每个分类独立 count(单次请求内是 N 次查询,数据量小,够用) + - 返回的 unread_count = 0 的分类前端不展示提示行 + """ + cat_list = [c.strip() for c in categories.split(",") if c.strip()] + if not cat_list: + return [] + + src_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None + eff_q = q.strip() if q and q.strip() else None + + results: list[CategoryCountItem] = [] + for cat in cat_list[:10]: # 防御性:一次最多 10 个分类 + filters = _build_category_filter( + category=cat, + window_hours=window_hours, + user_id=user.id, + sources=src_list, + q=eff_q, + ) + stmt = select(func.count(Article.id)) + if src_list: + stmt = stmt.join(Source, Source.id == Article.source_id) + for f in filters: + stmt = stmt.where(f) + n = (await session.execute(stmt)).scalar_one() + if n > 0: + results.append( + CategoryCountItem(category=cat, unread_count=int(n), window_hours=window_hours) + ) + return results + + @router.post("/reads/{article_id}", response_model=ReadToggleResponse) async def mark_read( article_id: int, @@ -238,114 +354,3 @@ def _build_category_filter( ) return filters - -@router.post("/reads/by-category", response_model=CategoryReadResponse) -async def mark_category_read( - body: CategoryReadRequest, - user: User = Depends(get_current_user), - session: AsyncSession = Depends(get_session), -): - """按分类批量已读。 - - - 默认 scope=filtered_unread:只在当前 Feed 过滤范围内标 - - 时间窗默认 24h:防"老新闻刷不完" - - 幂等:已读的不重复处理 - """ - cat = body.category.strip() - if not cat: - raise HTTPException(status.HTTP_400_BAD_REQUEST, "category 不能为空") - - # scope=all_unread 时,无视 sources/q(传了也忽略,语义更清晰) - if body.scope == "all_unread": - eff_sources: list[str] | None = None - eff_q: str | None = None - else: - eff_sources = body.sources - eff_q = body.q - - filters = _build_category_filter( - category=cat, - window_hours=body.window_hours, - user_id=user.id, - sources=eff_sources, - q=eff_q, - ) - - # 1. 查命中 id 列表(返回给前端) - id_stmt = select(Article.id) - # sources 过滤需要 join - if eff_sources: - id_stmt = id_stmt.join(Source, Source.id == Article.source_id) - for f in filters: - id_stmt = id_stmt.where(f) - article_ids = [r[0] for r in (await session.execute(id_stmt)).all()] - - if not article_ids: - return CategoryReadResponse(category=cat, matched=0, marked=0, article_ids=[]) - - # 2. 批量插入 article_reads(ON CONFLICT DO NOTHING,幂等) - # 分批:防御性,避免单次 VALUES 太多;500 是经验值,PG 完全能吃更大 - BATCH = 500 - marked_total = 0 - for i in range(0, len(article_ids), BATCH): - chunk = article_ids[i : i + BATCH] - rows = [{"user_id": user.id, "article_id": aid} for aid in chunk] - stmt = pg_insert(ArticleRead).values(rows).on_conflict_do_nothing( - index_elements=["user_id", "article_id"] - ) - # PG ON CONFLICT 不直接告诉哪些是"真插入的",通过 result() 或者信任总数; - # 这里简单信任 chunk 大小,marked 返回 attempt 数(去重由 PK 保证) - await session.execute(stmt) - marked_total += len(chunk) - await session.commit() - - return CategoryReadResponse( - category=cat, - matched=len(article_ids), - marked=marked_total, - article_ids=article_ids, - ) - - -@router.get("/reads/category-count", response_model=list[CategoryCountItem]) -async def count_unread_by_categories( - categories: str = Query(..., description="逗号分隔的分类列表,例如 '美国,社会'"), - window_hours: int = Query(default=24, ge=1, le=168), - sources: str | None = Query(default=None, description="逗号分隔源 slug"), - q: str | None = Query(default=None, description="关键词过滤"), - user: User = Depends(get_current_user), - session: AsyncSession = Depends(get_session), -): - """批量查多个分类的未读数(给前端提示条用,一次拿全,省 round-trip)。 - - - categories 必填,逗号分隔 - - 每个分类独立 count(单次请求内是 N 次查询,数据量小,够用) - - 返回的 unread_count = 0 的分类前端不展示提示行 - """ - cat_list = [c.strip() for c in categories.split(",") if c.strip()] - if not cat_list: - return [] - - src_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None - eff_q = q.strip() if q and q.strip() else None - - results: list[CategoryCountItem] = [] - for cat in cat_list[:10]: # 防御性:一次最多 10 个分类 - filters = _build_category_filter( - category=cat, - window_hours=window_hours, - user_id=user.id, - sources=src_list, - q=eff_q, - ) - stmt = select(func.count(Article.id)) - if src_list: - stmt = stmt.join(Source, Source.id == Article.source_id) - for f in filters: - stmt = stmt.where(f) - n = (await session.execute(stmt)).scalar_one() - if n > 0: - results.append( - CategoryCountItem(category=cat, unread_count=int(n), window_hours=window_hours) - ) - return results