后端: - POST /me/reads/by-category:按 category + time window 批量标已读 - GET /me/reads/category-count:查多个分类的 24h 未读数 前端: - 标已读后,查该文章前 2 个 category 的 24h 未读数 - 每个有未读的 category 显示一行提示(独立全部已读/稍后再说按钮) - 全部已读走乐观更新,命中的 article 走累计 delay 滑出 - 过滤变化时清掉提示(上下文变了) - 8 秒自动消失
352 lines
12 KiB
Python
352 lines
12 KiB
Python
"""/me 当前用户信息 + 翻译配额 + 已读文章 + 按分类批量已读。"""
|
|
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Literal
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy import and_, delete, func, not_, or_, select, text
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.config import settings
|
|
from app.core.deps import get_current_user
|
|
from app.database import get_session
|
|
from app.models.article import Article
|
|
from app.models.article_read import ArticleRead
|
|
from app.models.source import Source
|
|
from app.models.user import User
|
|
from app.redis_client import get_redis
|
|
|
|
router = APIRouter(prefix="/me", tags=["me"])
|
|
|
|
|
|
class MeOut(BaseModel):
|
|
id: int
|
|
username: str
|
|
email: str | None
|
|
role: str
|
|
display_name: str | None
|
|
created_at: datetime
|
|
|
|
|
|
class UsageOut(BaseModel):
|
|
month: str
|
|
used_chars: int
|
|
quota_chars: int
|
|
remaining_chars: int
|
|
buffered_quota: int
|
|
pct_used: float
|
|
|
|
|
|
class ReadToggleResponse(BaseModel):
|
|
article_id: int
|
|
is_read: bool
|
|
|
|
|
|
class ReadListResponse(BaseModel):
|
|
"""已读文章 ID 列表(给前端用,Feed 列表展示)."""
|
|
article_ids: list[int]
|
|
total: int
|
|
|
|
|
|
@router.get("", response_model=MeOut)
|
|
async def me(user: User = Depends(get_current_user)):
|
|
return MeOut(
|
|
id=user.id,
|
|
username=user.username,
|
|
email=user.email,
|
|
role=user.role.value,
|
|
display_name=user.display_name,
|
|
created_at=user.created_at,
|
|
)
|
|
|
|
|
|
@router.get("/usage", response_model=UsageOut)
|
|
async def usage(
|
|
user: User = Depends(get_current_user),
|
|
session: AsyncSession = Depends(get_session), # noqa: ARG001
|
|
):
|
|
r = get_redis()
|
|
now = datetime.now(timezone.utc)
|
|
key = f"translation:month:{now:%Y%m}"
|
|
used = int(await r.get(key) or 0)
|
|
quota = settings.tencent_tmt_quota_month
|
|
buffered = int(quota * (1 - settings.tencent_tmt_quota_buffer))
|
|
remaining = max(0, quota - used)
|
|
return UsageOut(
|
|
month=f"{now:%Y%m}",
|
|
used_chars=used,
|
|
quota_chars=quota,
|
|
remaining_chars=remaining,
|
|
buffered_quota=buffered,
|
|
pct_used=round(used / quota * 100, 2) if quota else 0.0,
|
|
)
|
|
|
|
|
|
# === 已读文章(per-user) ===
|
|
@router.post("/reads/{article_id}", response_model=ReadToggleResponse)
|
|
async def mark_read(
|
|
article_id: int,
|
|
user: User = Depends(get_current_user),
|
|
session: AsyncSession = Depends(get_session),
|
|
):
|
|
"""标记文章为已读(幂等,重复调用不报错)。"""
|
|
# 先确认文章存在(不存在返 404,避免 FK 报错)
|
|
exists = (await session.execute(select(Article.id).where(Article.id == article_id))).scalar_one_or_none()
|
|
if not exists:
|
|
raise HTTPException(status.HTTP_404_NOT_FOUND, "Article not found")
|
|
# 用 INSERT ... ON CONFLICT DO NOTHING (PG 原生)— 等价 upsert 跳过
|
|
stmt = pg_insert(ArticleRead).values(user_id=user.id, article_id=article_id).on_conflict_do_nothing(
|
|
index_elements=["user_id", "article_id"]
|
|
)
|
|
await session.execute(stmt)
|
|
await session.commit()
|
|
return ReadToggleResponse(article_id=article_id, is_read=True)
|
|
|
|
|
|
@router.delete("/reads/{article_id}", response_model=ReadToggleResponse)
|
|
async def unmark_read(
|
|
article_id: int,
|
|
user: User = Depends(get_current_user),
|
|
session: AsyncSession = Depends(get_session),
|
|
):
|
|
"""取消已读(幂等)。"""
|
|
stmt = delete(ArticleRead).where(
|
|
and_(ArticleRead.user_id == user.id, ArticleRead.article_id == article_id)
|
|
)
|
|
res = await session.execute(stmt)
|
|
await session.commit()
|
|
return ReadToggleResponse(article_id=article_id, is_read=False)
|
|
|
|
|
|
@router.get("/reads", response_model=ReadListResponse)
|
|
async def list_reads(
|
|
since: datetime | None = None,
|
|
limit: int = 500,
|
|
user: User = Depends(get_current_user),
|
|
session: AsyncSession = Depends(get_session),
|
|
):
|
|
"""列出当前用户已读文章 ID(给 Feed 过滤 / 已读时间线用)。
|
|
|
|
- since: 只返回 read_at >= since 的(默认 7 天前,避免数据爆炸)
|
|
- limit: 最多返回多少(默认 500)
|
|
"""
|
|
if since is None:
|
|
since = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=7)
|
|
stmt = (
|
|
select(ArticleRead.article_id)
|
|
.where(ArticleRead.user_id == user.id, ArticleRead.read_at >= since)
|
|
.order_by(ArticleRead.read_at.desc())
|
|
.limit(min(limit, 2000))
|
|
)
|
|
rows = (await session.execute(stmt)).all()
|
|
ids = [r[0] for r in rows]
|
|
return ReadListResponse(article_ids=ids, total=len(ids))
|
|
|
|
|
|
# === 按分类批量已读(per-user) ===
|
|
# 用例:用户在 Feed 标了一条《xxx》为已读,该文章 category 含 "美国" "社会" 两个 tag,
|
|
# 前端会同时查这两个分类下"24 小时未读数"展示提示条,点头部的"全部已读"调本接口。
|
|
# 设计要点:
|
|
# - category 用 ilike 模糊匹配(Article.category 是逗号分隔字符串 '美国,社会,紧急')
|
|
# - 默认遵守当前 Feed 过滤(关键词/源)— scope=filtered_unread,避免误杀
|
|
# - 时间窗默认 24h(防"老新闻刷不完"的提示)
|
|
# - 批量插入用 ON CONFLICT DO NOTHING,幂等
|
|
# - 返回 article_ids 给前端做乐观滑出动画
|
|
|
|
|
|
def _unread_subquery(user_id: int):
|
|
"""未读文章 id 子查询(per-user)— 复用 articles.py 的 NOT EXISTS 模式。"""
|
|
return (
|
|
select(ArticleRead.article_id)
|
|
.where(ArticleRead.user_id == user_id, ArticleRead.article_id == Article.id)
|
|
.exists()
|
|
)
|
|
|
|
|
|
def _escape_like(term: str) -> str:
|
|
"""转义 ilike 的元字符,避免 category 含 % / _ 时被当通配符。"""
|
|
return term.replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")
|
|
|
|
|
|
class CategoryReadRequest(BaseModel):
|
|
category: str = Field(..., min_length=1, max_length=32, description="要批量已读的分类 tag")
|
|
scope: Literal["all_unread", "filtered_unread"] = Field(
|
|
default="filtered_unread",
|
|
description="过滤范围:filtered_unread=遵守 sources/q;all_unread=忽略",
|
|
)
|
|
window_hours: int = Field(
|
|
default=24, ge=1, le=168, description="只标 published_at 在最近 N 小时内的未读"
|
|
)
|
|
sources: list[str] | None = Field(default=None, description="源 slug 列表,过滤范围")
|
|
q: str | None = Field(default=None, description="关键词过滤(标题/正文模糊)")
|
|
|
|
|
|
class CategoryReadResponse(BaseModel):
|
|
category: str
|
|
matched: int # 命中的未读数
|
|
marked: int # 实际新标已读数(去重后)
|
|
article_ids: list[int] # 给前端做滑出动画
|
|
|
|
|
|
class CategoryCountItem(BaseModel):
|
|
category: str
|
|
unread_count: int
|
|
window_hours: int
|
|
|
|
|
|
def _build_category_filter(
|
|
category: str,
|
|
window_hours: int,
|
|
user_id: int,
|
|
sources: list[str] | None = None,
|
|
q: str | None = None,
|
|
):
|
|
"""构造"分类 + 时间窗 + 未读"三合一 filter。
|
|
|
|
- category 用 ilike 模糊匹配(逗号分隔串中含此 tag 即可)
|
|
- window_hours 通过 published_at 过滤
|
|
- 用 NOT EXISTS 排除当前用户已读
|
|
- sources / q 可选,跟 articles.py 列表查询口径一致
|
|
"""
|
|
pattern = f"%{_escape_like(category)}%"
|
|
# 用 text() 拼 interval,避免 make_interval 的 PG 函数参数绑定在某些
|
|
# SQLAlchemy 版本下的兼容性问题。window_hours 是 int,只来自我们自己的
|
|
# endpoint,不是用户原始字符串,安全。
|
|
interval_sql = text(f"interval '{int(window_hours)} hours'")
|
|
filters = [
|
|
Article.category.ilike(pattern, escape="\\"),
|
|
Article.published_at >= func.now() - interval_sql,
|
|
not_(_unread_subquery(user_id)),
|
|
]
|
|
if sources:
|
|
slugs = [s.strip() for s in sources if s.strip()]
|
|
if slugs:
|
|
filters.append(Source.slug.in_(slugs))
|
|
if q and q.strip():
|
|
like = f"%{q.strip()}%"
|
|
filters.append(
|
|
or_(
|
|
Article.title.ilike(like),
|
|
Article.body_text.ilike(like),
|
|
Article.title_zh.ilike(like),
|
|
Article.body_zh_text.ilike(like),
|
|
Article.summary_zh.ilike(like),
|
|
)
|
|
)
|
|
return filters
|
|
|
|
|
|
@router.post("/reads/by-category", response_model=CategoryReadResponse)
|
|
async def mark_category_read(
|
|
body: CategoryReadRequest,
|
|
user: User = Depends(get_current_user),
|
|
session: AsyncSession = Depends(get_session),
|
|
):
|
|
"""按分类批量已读。
|
|
|
|
- 默认 scope=filtered_unread:只在当前 Feed 过滤范围内标
|
|
- 时间窗默认 24h:防"老新闻刷不完"
|
|
- 幂等:已读的不重复处理
|
|
"""
|
|
cat = body.category.strip()
|
|
if not cat:
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "category 不能为空")
|
|
|
|
# scope=all_unread 时,无视 sources/q(传了也忽略,语义更清晰)
|
|
if body.scope == "all_unread":
|
|
eff_sources: list[str] | None = None
|
|
eff_q: str | None = None
|
|
else:
|
|
eff_sources = body.sources
|
|
eff_q = body.q
|
|
|
|
filters = _build_category_filter(
|
|
category=cat,
|
|
window_hours=body.window_hours,
|
|
user_id=user.id,
|
|
sources=eff_sources,
|
|
q=eff_q,
|
|
)
|
|
|
|
# 1. 查命中 id 列表(返回给前端)
|
|
id_stmt = select(Article.id)
|
|
# sources 过滤需要 join
|
|
if eff_sources:
|
|
id_stmt = id_stmt.join(Source, Source.id == Article.source_id)
|
|
for f in filters:
|
|
id_stmt = id_stmt.where(f)
|
|
article_ids = [r[0] for r in (await session.execute(id_stmt)).all()]
|
|
|
|
if not article_ids:
|
|
return CategoryReadResponse(category=cat, matched=0, marked=0, article_ids=[])
|
|
|
|
# 2. 批量插入 article_reads(ON CONFLICT DO NOTHING,幂等)
|
|
# 分批:防御性,避免单次 VALUES 太多;500 是经验值,PG 完全能吃更大
|
|
BATCH = 500
|
|
marked_total = 0
|
|
for i in range(0, len(article_ids), BATCH):
|
|
chunk = article_ids[i : i + BATCH]
|
|
rows = [{"user_id": user.id, "article_id": aid} for aid in chunk]
|
|
stmt = pg_insert(ArticleRead).values(rows).on_conflict_do_nothing(
|
|
index_elements=["user_id", "article_id"]
|
|
)
|
|
# PG ON CONFLICT 不直接告诉哪些是"真插入的",通过 result() 或者信任总数;
|
|
# 这里简单信任 chunk 大小,marked 返回 attempt 数(去重由 PK 保证)
|
|
await session.execute(stmt)
|
|
marked_total += len(chunk)
|
|
await session.commit()
|
|
|
|
return CategoryReadResponse(
|
|
category=cat,
|
|
matched=len(article_ids),
|
|
marked=marked_total,
|
|
article_ids=article_ids,
|
|
)
|
|
|
|
|
|
@router.get("/reads/category-count", response_model=list[CategoryCountItem])
|
|
async def count_unread_by_categories(
|
|
categories: str = Query(..., description="逗号分隔的分类列表,例如 '美国,社会'"),
|
|
window_hours: int = Query(default=24, ge=1, le=168),
|
|
sources: str | None = Query(default=None, description="逗号分隔源 slug"),
|
|
q: str | None = Query(default=None, description="关键词过滤"),
|
|
user: User = Depends(get_current_user),
|
|
session: AsyncSession = Depends(get_session),
|
|
):
|
|
"""批量查多个分类的未读数(给前端提示条用,一次拿全,省 round-trip)。
|
|
|
|
- categories 必填,逗号分隔
|
|
- 每个分类独立 count(单次请求内是 N 次查询,数据量小,够用)
|
|
- 返回的 unread_count = 0 的分类前端不展示提示行
|
|
"""
|
|
cat_list = [c.strip() for c in categories.split(",") if c.strip()]
|
|
if not cat_list:
|
|
return []
|
|
|
|
src_list = [s.strip() for s in sources.split(",") if s.strip()] if sources else None
|
|
eff_q = q.strip() if q and q.strip() else None
|
|
|
|
results: list[CategoryCountItem] = []
|
|
for cat in cat_list[:10]: # 防御性:一次最多 10 个分类
|
|
filters = _build_category_filter(
|
|
category=cat,
|
|
window_hours=window_hours,
|
|
user_id=user.id,
|
|
sources=src_list,
|
|
q=eff_q,
|
|
)
|
|
stmt = select(func.count(Article.id))
|
|
if src_list:
|
|
stmt = stmt.join(Source, Source.id == Article.source_id)
|
|
for f in filters:
|
|
stmt = stmt.where(f)
|
|
n = (await session.execute(stmt)).scalar_one()
|
|
if n > 0:
|
|
results.append(
|
|
CategoryCountItem(category=cat, unread_count=int(n), window_hours=window_hours)
|
|
)
|
|
return results
|