fix(me): 把 /reads/by-category 路由声明在 /reads/{article_id} 之前
This commit is contained in:
@@ -86,6 +86,122 @@ async def usage(
|
|||||||
|
|
||||||
|
|
||||||
# === 已读文章(per-user) ===
|
# === 已读文章(per-user) ===
|
||||||
|
# ⚠️ 路由顺序重要:`/reads/by-category` 和 `/reads/category-count` 必须在
|
||||||
|
# `/reads/{article_id}` **之前**声明,否则 FastAPI 会把 "by-category" / "category-count"
|
||||||
|
# 当成 article_id 匹配到 mark_read,然后 int 转换失败返 422。
|
||||||
|
# FastAPI 路由匹配是"先注册先匹配"(按装饰器执行顺序),不是按特异度。
|
||||||
|
@router.post("/reads/by-category", response_model=CategoryReadResponse)
|
||||||
|
async def mark_category_read(
|
||||||
|
body: CategoryReadRequest,
|
||||||
|
user: User = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""按分类批量已读。
|
||||||
|
|
||||||
|
- 默认 scope=filtered_unread:只在当前 Feed 过滤范围内标
|
||||||
|
- 时间窗默认 24h:防"老新闻刷不完"
|
||||||
|
- 幂等:已读的不重复处理
|
||||||
|
"""
|
||||||
|
cat = body.category.strip()
|
||||||
|
if not cat:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "category 不能为空")
|
||||||
|
|
||||||
|
# scope=all_unread 时,无视 sources/q(传了也忽略,语义更清晰)
|
||||||
|
if body.scope == "all_unread":
|
||||||
|
eff_sources: list[str] | None = None
|
||||||
|
eff_q: str | None = None
|
||||||
|
else:
|
||||||
|
eff_sources = body.sources
|
||||||
|
eff_q = body.q
|
||||||
|
|
||||||
|
filters = _build_category_filter(
|
||||||
|
category=cat,
|
||||||
|
window_hours=body.window_hours,
|
||||||
|
user_id=user.id,
|
||||||
|
sources=eff_sources,
|
||||||
|
q=eff_q,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. 查命中 id 列表(返回给前端)
|
||||||
|
id_stmt = select(Article.id)
|
||||||
|
# sources 过滤需要 join
|
||||||
|
if eff_sources:
|
||||||
|
id_stmt = id_stmt.join(Source, Source.id == Article.source_id)
|
||||||
|
for f in filters:
|
||||||
|
id_stmt = id_stmt.where(f)
|
||||||
|
article_ids = [r[0] for r in (await session.execute(id_stmt)).all()]
|
||||||
|
|
||||||
|
if not article_ids:
|
||||||
|
return CategoryReadResponse(category=cat, matched=0, marked=0, article_ids=[])
|
||||||
|
|
||||||
|
# 2. 批量插入 article_reads(ON CONFLICT DO NOTHING,幂等)
|
||||||
|
# 分批:防御性,避免单次 VALUES 太多;500 是经验值,PG 完全能吃更大
|
||||||
|
BATCH = 500
|
||||||
|
marked_total = 0
|
||||||
|
for i in range(0, len(article_ids), BATCH):
|
||||||
|
chunk = article_ids[i : i + BATCH]
|
||||||
|
rows = [{"user_id": user.id, "article_id": aid} for aid in chunk]
|
||||||
|
stmt = pg_insert(ArticleRead).values(rows).on_conflict_do_nothing(
|
||||||
|
index_elements=["user_id", "article_id"]
|
||||||
|
)
|
||||||
|
# PG ON CONFLICT 不直接告诉哪些是"真插入的",通过 result() 或者信任总数;
|
||||||
|
# 这里简单信任 chunk 大小,marked 返回 attempt 数(去重由 PK 保证)
|
||||||
|
await session.execute(stmt)
|
||||||
|
marked_total += len(chunk)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return CategoryReadResponse(
|
||||||
|
category=cat,
|
||||||
|
matched=len(article_ids),
|
||||||
|
marked=marked_total,
|
||||||
|
article_ids=article_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/reads/category-count", response_model=list[CategoryCountItem])
|
||||||
|
async def count_unread_by_categories(
|
||||||
|
categories: str = Query(..., description="逗号分隔的分类列表,例如 '美国,社会'"),
|
||||||
|
window_hours: int = Query(default=24, ge=1, le=168),
|
||||||
|
sources: str | None = Query(default=None, description="逗号分隔源 slug"),
|
||||||
|
q: str | None = Query(default=None, description="关键词过滤"),
|
||||||
|
user: User = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""批量查多个分类的未读数(给前端提示条用,一次拿全,省 round-trip)。
|
||||||
|
|
||||||
|
- categories 必填,逗号分隔
|
||||||
|
- 每个分类独立 count(单次请求内是 N 次查询,数据量小,够用)
|
||||||
|
- 返回的 unread_count = 0 的分类前端不展示提示行
|
||||||
|
"""
|
||||||
|
cat_list = [c.strip() for c in categories.split(",") if c.strip()]
|
||||||
|
if not cat_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
src_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None
|
||||||
|
eff_q = q.strip() if q and q.strip() else None
|
||||||
|
|
||||||
|
results: list[CategoryCountItem] = []
|
||||||
|
for cat in cat_list[:10]: # 防御性:一次最多 10 个分类
|
||||||
|
filters = _build_category_filter(
|
||||||
|
category=cat,
|
||||||
|
window_hours=window_hours,
|
||||||
|
user_id=user.id,
|
||||||
|
sources=src_list,
|
||||||
|
q=eff_q,
|
||||||
|
)
|
||||||
|
stmt = select(func.count(Article.id))
|
||||||
|
if src_list:
|
||||||
|
stmt = stmt.join(Source, Source.id == Article.source_id)
|
||||||
|
for f in filters:
|
||||||
|
stmt = stmt.where(f)
|
||||||
|
n = (await session.execute(stmt)).scalar_one()
|
||||||
|
if n > 0:
|
||||||
|
results.append(
|
||||||
|
CategoryCountItem(category=cat, unread_count=int(n), window_hours=window_hours)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reads/{article_id}", response_model=ReadToggleResponse)
|
@router.post("/reads/{article_id}", response_model=ReadToggleResponse)
|
||||||
async def mark_read(
|
async def mark_read(
|
||||||
article_id: int,
|
article_id: int,
|
||||||
@@ -238,114 +354,3 @@ def _build_category_filter(
|
|||||||
)
|
)
|
||||||
return filters
|
return filters
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reads/by-category", response_model=CategoryReadResponse)
|
|
||||||
async def mark_category_read(
|
|
||||||
body: CategoryReadRequest,
|
|
||||||
user: User = Depends(get_current_user),
|
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
|
||||||
"""按分类批量已读。
|
|
||||||
|
|
||||||
- 默认 scope=filtered_unread:只在当前 Feed 过滤范围内标
|
|
||||||
- 时间窗默认 24h:防"老新闻刷不完"
|
|
||||||
- 幂等:已读的不重复处理
|
|
||||||
"""
|
|
||||||
cat = body.category.strip()
|
|
||||||
if not cat:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "category 不能为空")
|
|
||||||
|
|
||||||
# scope=all_unread 时,无视 sources/q(传了也忽略,语义更清晰)
|
|
||||||
if body.scope == "all_unread":
|
|
||||||
eff_sources: list[str] | None = None
|
|
||||||
eff_q: str | None = None
|
|
||||||
else:
|
|
||||||
eff_sources = body.sources
|
|
||||||
eff_q = body.q
|
|
||||||
|
|
||||||
filters = _build_category_filter(
|
|
||||||
category=cat,
|
|
||||||
window_hours=body.window_hours,
|
|
||||||
user_id=user.id,
|
|
||||||
sources=eff_sources,
|
|
||||||
q=eff_q,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 查命中 id 列表(返回给前端)
|
|
||||||
id_stmt = select(Article.id)
|
|
||||||
# sources 过滤需要 join
|
|
||||||
if eff_sources:
|
|
||||||
id_stmt = id_stmt.join(Source, Source.id == Article.source_id)
|
|
||||||
for f in filters:
|
|
||||||
id_stmt = id_stmt.where(f)
|
|
||||||
article_ids = [r[0] for r in (await session.execute(id_stmt)).all()]
|
|
||||||
|
|
||||||
if not article_ids:
|
|
||||||
return CategoryReadResponse(category=cat, matched=0, marked=0, article_ids=[])
|
|
||||||
|
|
||||||
# 2. 批量插入 article_reads(ON CONFLICT DO NOTHING,幂等)
|
|
||||||
# 分批:防御性,避免单次 VALUES 太多;500 是经验值,PG 完全能吃更大
|
|
||||||
BATCH = 500
|
|
||||||
marked_total = 0
|
|
||||||
for i in range(0, len(article_ids), BATCH):
|
|
||||||
chunk = article_ids[i : i + BATCH]
|
|
||||||
rows = [{"user_id": user.id, "article_id": aid} for aid in chunk]
|
|
||||||
stmt = pg_insert(ArticleRead).values(rows).on_conflict_do_nothing(
|
|
||||||
index_elements=["user_id", "article_id"]
|
|
||||||
)
|
|
||||||
# PG ON CONFLICT 不直接告诉哪些是"真插入的",通过 result() 或者信任总数;
|
|
||||||
# 这里简单信任 chunk 大小,marked 返回 attempt 数(去重由 PK 保证)
|
|
||||||
await session.execute(stmt)
|
|
||||||
marked_total += len(chunk)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
return CategoryReadResponse(
|
|
||||||
category=cat,
|
|
||||||
matched=len(article_ids),
|
|
||||||
marked=marked_total,
|
|
||||||
article_ids=article_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/reads/category-count", response_model=list[CategoryCountItem])
|
|
||||||
async def count_unread_by_categories(
|
|
||||||
categories: str = Query(..., description="逗号分隔的分类列表,例如 '美国,社会'"),
|
|
||||||
window_hours: int = Query(default=24, ge=1, le=168),
|
|
||||||
sources: str | None = Query(default=None, description="逗号分隔源 slug"),
|
|
||||||
q: str | None = Query(default=None, description="关键词过滤"),
|
|
||||||
user: User = Depends(get_current_user),
|
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
|
||||||
"""批量查多个分类的未读数(给前端提示条用,一次拿全,省 round-trip)。
|
|
||||||
|
|
||||||
- categories 必填,逗号分隔
|
|
||||||
- 每个分类独立 count(单次请求内是 N 次查询,数据量小,够用)
|
|
||||||
- 返回的 unread_count = 0 的分类前端不展示提示行
|
|
||||||
"""
|
|
||||||
cat_list = [c.strip() for c in categories.split(",") if c.strip()]
|
|
||||||
if not cat_list:
|
|
||||||
return []
|
|
||||||
|
|
||||||
src_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None
|
|
||||||
eff_q = q.strip() if q and q.strip() else None
|
|
||||||
|
|
||||||
results: list[CategoryCountItem] = []
|
|
||||||
for cat in cat_list[:10]: # 防御性:一次最多 10 个分类
|
|
||||||
filters = _build_category_filter(
|
|
||||||
category=cat,
|
|
||||||
window_hours=window_hours,
|
|
||||||
user_id=user.id,
|
|
||||||
sources=src_list,
|
|
||||||
q=eff_q,
|
|
||||||
)
|
|
||||||
stmt = select(func.count(Article.id))
|
|
||||||
if src_list:
|
|
||||||
stmt = stmt.join(Source, Source.id == Article.source_id)
|
|
||||||
for f in filters:
|
|
||||||
stmt = stmt.where(f)
|
|
||||||
n = (await session.execute(stmt)).scalar_one()
|
|
||||||
if n > 0:
|
|
||||||
results.append(
|
|
||||||
CategoryCountItem(category=cat, unread_count=int(n), window_hours=window_hours)
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|||||||
Reference in New Issue
Block a user