Files
guba-indicator/database.py

220 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
数据库模块 - SQLite存储评论和分析结果
"""
import hashlib
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple
class DatabaseManager:
    """SQLite-backed store for scraped comments and their sentiment analysis.

    Each public method opens its own short-lived connection to ``db_path`` and
    closes it before returning, including when a statement raises.
    """

    # Shared by add_comment and add_comments_batch. ``OR IGNORE`` lets the
    # UNIQUE(content_hash) constraint perform de-duplication atomically on the
    # same connection, instead of a separate SELECT-then-INSERT. That older
    # approach could race with other writers and, inside a batch, could not
    # see the batch's own uncommitted rows, so it raised IntegrityError.
    _INSERT_COMMENT_SQL = '''
        INSERT OR IGNORE INTO comments (content, content_hash, url, created_at)
        VALUES (?, ?, ?, ?)
    '''

    # Number of content characters kept in get_recent_comments previews.
    _PREVIEW_CHARS = 50

    def __init__(self, db_path: str = "guba.db"):
        """Remember the database location and create missing tables.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = Path(db_path)
        self._init_db()

    def _init_db(self) -> None:
        """Create the comments, analysis_history and config tables if absent."""
        with self._transaction() as conn:
            # Comments table. content_hash (see hash_content) is the dedup key.
            # NOTE(review): fetched_at defaults to SQLite CURRENT_TIMESTAMP
            # (UTC, "YYYY-MM-DD HH:MM:SS"), while created_at / analyzed_at are
            # written by Python as local-time ISO-8601 strings, so values from
            # different columns are not directly comparable.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS comments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT NOT NULL,
                    content_hash TEXT UNIQUE NOT NULL,
                    url TEXT,
                    created_at TEXT,
                    fetched_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    analyzed INTEGER DEFAULT 0,
                    sentiment_score REAL,
                    analyzed_at TEXT
                )
            ''')
            # Analysis history: one row per mark_analyzed call.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS analysis_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    comment_id INTEGER,
                    sentiment_score REAL NOT NULL,
                    analysis_text TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (comment_id) REFERENCES comments(id)
                )
            ''')
            # Key/value configuration table.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS config (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    updated_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            ''')

    def _get_connection(self) -> sqlite3.Connection:
        """Open a new connection to ``db_path``; the caller must close it."""
        return sqlite3.connect(str(self.db_path))

    @contextmanager
    def _transaction(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection that commits on success and always closes.

        Using ``sqlite3.Connection`` as a context manager only commits (or
        rolls back on error); it does not close the connection. The
        ``finally`` clause is what prevents leaked connections when a
        statement raises.
        """
        conn = self._get_connection()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _fetch_scalar(self, sql: str, params: Tuple = ()):
        """Run *sql* and return the first column of its first row."""
        with self._transaction() as conn:
            return conn.execute(sql, params).fetchone()[0]

    @staticmethod
    def hash_content(content: str) -> str:
        """Return the MD5 hex digest of *content* (UTF-8), used for dedup.

        The digest is stored in ``comments.content_hash``. Changing the
        algorithm would make existing rows stop matching new ones.
        """
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def is_comment_exists(self, content_hash: str) -> bool:
        """Return True if a comment with *content_hash* is already stored."""
        with self._transaction() as conn:
            row = conn.execute(
                'SELECT 1 FROM comments WHERE content_hash = ?', (content_hash,)
            ).fetchone()
        return row is not None

    def add_comment(self, content: str, url: Optional[str] = None) -> Optional[int]:
        """Store one comment unless identical content already exists.

        Args:
            content: Comment text.
            url: Optional source URL.

        Returns:
            The new row id, or None if the content was a duplicate.
        """
        with self._transaction() as conn:
            cursor = conn.execute(
                self._INSERT_COMMENT_SQL,
                (content, self.hash_content(content), url, datetime.now().isoformat()),
            )
            # rowcount is 0 when OR IGNORE skipped the row as a duplicate.
            return cursor.lastrowid if cursor.rowcount == 1 else None

    def add_comments_batch(self, comments: List[Dict]) -> List[int]:
        """Store many comments in one transaction, skipping duplicates.

        Duplicates are detected both against stored rows and within the batch
        itself. If any insert raises, the whole batch is rolled back.

        Args:
            comments: Dicts read via the keys 'content' (default '') and 'url'.

        Returns:
            Row ids of the newly inserted comments, in input order.
        """
        new_ids: List[int] = []
        with self._transaction() as conn:
            for comment in comments:
                content = comment.get('content', '')
                cursor = conn.execute(
                    self._INSERT_COMMENT_SQL,
                    (content, self.hash_content(content), comment.get('url'),
                     datetime.now().isoformat()),
                )
                if cursor.rowcount == 1:
                    new_ids.append(cursor.lastrowid)
        return new_ids

    def get_unanalyzed_comments(self, limit: int = 50) -> List[Dict]:
        """Return up to *limit* unanalyzed comments, oldest-fetched first.

        Returns:
            Dicts with keys 'id', 'content' and 'url'.
        """
        with self._transaction() as conn:
            # fetched_at has one-second resolution; id breaks ties so that
            # rows fetched in the same second keep insertion (FIFO) order.
            rows = conn.execute('''
                SELECT id, content, url FROM comments
                WHERE analyzed = 0
                ORDER BY fetched_at ASC, id ASC
                LIMIT ?
            ''', (limit,)).fetchall()
        return [{'id': cid, 'content': content, 'url': url} for cid, content, url in rows]

    def mark_analyzed(self, comment_id: int, sentiment_score: float, analysis_text: str):
        """Record a sentiment result for a comment.

        Updates the comment's analyzed flag, score and timestamp, and appends
        an analysis_history row, in a single transaction.
        NOTE(review): the history row is inserted even if *comment_id*
        matches no comment (foreign keys are not enabled on the connection).
        """
        with self._transaction() as conn:
            conn.execute('''
                UPDATE comments
                SET analyzed = 1, sentiment_score = ?, analyzed_at = ?
                WHERE id = ?
            ''', (sentiment_score, datetime.now().isoformat(), comment_id))
            conn.execute('''
                INSERT INTO analysis_history (comment_id, sentiment_score, analysis_text)
                VALUES (?, ?, ?)
            ''', (comment_id, sentiment_score, analysis_text))

    def get_latest_sentiment_score(self) -> Optional[float]:
        """Return the score of the most recently analyzed comment, or None."""
        with self._transaction() as conn:
            row = conn.execute('''
                SELECT sentiment_score FROM comments
                WHERE analyzed = 1 AND sentiment_score IS NOT NULL
                ORDER BY analyzed_at DESC, id DESC
                LIMIT 1
            ''').fetchone()
        return row[0] if row else None

    def get_all_scores(self) -> List[float]:
        """Return all non-null scores of analyzed comments, newest first."""
        with self._transaction() as conn:
            rows = conn.execute('''
                SELECT sentiment_score FROM comments
                WHERE analyzed = 1 AND sentiment_score IS NOT NULL
                ORDER BY analyzed_at DESC, id DESC
            ''').fetchall()
        # The SQL already excludes NULL scores, so no Python-side filter is needed.
        return [score for (score,) in rows]

    def get_comment_count(self) -> int:
        """Return the total number of stored comments."""
        return self._fetch_scalar('SELECT COUNT(*) FROM comments')

    def get_analyzed_count(self) -> int:
        """Return the number of comments that have been analyzed."""
        return self._fetch_scalar('SELECT COUNT(*) FROM comments WHERE analyzed = 1')

    def get_recent_comments(self, limit: int = 10) -> List[Dict]:
        """Return the *limit* most recently fetched comments for display.

        Returns:
            Dicts with keys 'id', 'content' (truncated to 50 characters plus
            '...' when longer), 'score' and 'analyzed_at' (None if unanalyzed).
        """
        with self._transaction() as conn:
            rows = conn.execute('''
                SELECT id, content, sentiment_score, analyzed_at
                FROM comments
                ORDER BY fetched_at DESC, id DESC
                LIMIT ?
            ''', (limit,)).fetchall()
        width = self._PREVIEW_CHARS
        return [
            {
                'id': cid,
                'content': content[:width] + '...' if len(content) > width else content,
                'score': score,
                'analyzed_at': analyzed_at,
            }
            for cid, content, score, analyzed_at in rows
        ]