221 lines
7.9 KiB
Python
221 lines
7.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
打开 tophub_data.db 数据库,读取 articles 数据表,提取所有唯一的类别(category)。
调用本地 Ollama API,将类别名称简化为 3-6 个汉字,并去除其中的空格、特殊字符等无关字符。
|
||
|
||
"""
|
||
|
||
import requests
|
||
import sqlite3
|
||
import re
|
||
import time
|
||
from loguru import logger
|
||
|
||
# Also send INFO-and-above records to db_modify.log, rotating the file at 10 MB.
logger.add("db_modify.log", level="INFO", rotation="10 MB")
|
||
|
||
class CategoryModifier:
    """Shorten and normalize the ``category`` values in the ``articles`` table.

    Every distinct category is sent to a local Ollama model with a prompt
    asking for a concise 3-6 character name. The reply is reduced to CJK
    ideographs, ASCII letters and digits, and every row using the old name is
    rewritten to the new one. When the model call fails, the original name is
    only cleaned (special characters removed) instead.
    """

    # Anything that is not a CJK Unified Ideograph (U+4E00-U+9FFF), an ASCII
    # letter or a digit -- whitespace included -- is removed from names.
    _DISALLOWED_CHARS = re.compile(r"[^\u4e00-\u9fffa-zA-Z0-9]")

    # Reasoning models such as Qwen3 may emit a <think>...</think> section
    # before the answer; it must never end up inside a category name.
    _THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)

    def __init__(self, db_path="tophub_data.db",
                 ollama_url="http://localhost:11434/api/generate",
                 model="qwen3:8b", request_timeout=30):
        """
        Initialize the modifier.

        Args:
            db_path (str): path to the SQLite database file.
            ollama_url (str): Ollama ``/api/generate`` endpoint.
            model (str): name of the Ollama model to query.
            request_timeout (float): per-request timeout for Ollama, in seconds.
        """
        self.db_path = db_path
        self.ollama_url = ollama_url
        self.model = model
        self.request_timeout = request_timeout

    def get_all_categories(self):
        """
        Return every distinct, non-empty category stored in ``articles``.

        Returns:
            list: the distinct category values, or an empty list if the query
            fails (the error is logged).
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                rows = conn.execute("SELECT DISTINCT category FROM articles").fetchall()
            finally:
                # Close even when the query raises (e.g. the table is missing).
                conn.close()
        except Exception as e:
            logger.error(f"获取类别时出错: {e}")
            return []

        categories = [row[0] for row in rows if row[0]]
        logger.info(f"成功获取 {len(categories)} 个唯一类别")
        return categories

    def clean_category_name(self, category):
        """
        Remove everything except CJK ideographs, ASCII letters and digits.

        Args:
            category (str | None): raw category name.

        Returns:
            str: the cleaned name; "" for ``None``/empty input or a name made
            only of removed characters.
        """
        if not category:
            return ""
        return self._DISALLOWED_CHARS.sub("", category)

    def _extract_answer(self, text):
        """Return the first non-empty line of a model reply, minus any <think> block."""
        text = self._THINK_BLOCK.sub("", text or "")
        for line in text.splitlines():
            line = line.strip()
            if line:
                return line
        return ""

    def optimize_category_with_ollama(self, category):
        """
        Ask the Ollama model for a shorter, clearer version of ``category``.

        Args:
            category (str): original category name.

        Returns:
            str: the cleaned model answer. Falls back to the cleaned original
            name when the request fails or the answer is empty after cleaning.
        """
        fallback = self.clean_category_name(category)

        prompt = (
            f"请将以下类别名称简化为3-6个汉字,去除空格和特殊符号,更容易理解,并保持原意:'{category}'。"
            "例子一:'新科科技',优化为'新质生产力'。例子二:'产设',优化为'产品设计'。例子三:'史人',优化为'历史人物'。"
            # Without this the model tends to add explanations, which would
            # otherwise be concatenated into the category name.
            "只输出优化后的类别名称,不要输出任何解释或其他内容。"
        )
        data = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
        }

        try:
            response = requests.post(self.ollama_url, json=data, timeout=self.request_timeout)
            response.raise_for_status()
            raw_answer = response.json().get("response", "")
        except Exception as e:
            # Best effort: an API failure must not abort the whole batch.
            logger.error(f"优化类别 '{category}' 时出错: {e}")
            return fallback

        optimized = self.clean_category_name(self._extract_answer(raw_answer))
        if not optimized:
            logger.warning(f"模型未返回有效名称,类别 '{category}' 使用清理后的原名 '{fallback}'")
            return fallback

        logger.info(f"类别 '{category}' 优化为 '{optimized}'")
        return optimized

    def update_category_in_db(self, old_category, new_category):
        """
        Rename ``old_category`` to ``new_category`` on every matching row.

        Args:
            old_category (str): current category name.
            new_category (str): replacement category name.

        Returns:
            int | None: number of rows updated, or None if the update failed
            (the error is logged and the transaction is not committed).
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.execute(
                    "UPDATE articles SET category = ? WHERE category = ?",
                    (new_category, old_category),
                )
                count = cursor.rowcount
                conn.commit()
            finally:
                # Closing without a commit discards a half-done update.
                conn.close()
        except Exception as e:
            logger.error(f"更新类别 '{old_category}' 时出错: {e}")
            return None

        logger.info(f"成功更新类别 '{old_category}' 为 '{new_category}',影响 {count} 条记录")
        return count

    def process_all_categories(self):
        """
        Optimize every category and write changed names back to the database.

        Progress, per-category timing and an ETA are logged after each
        category, followed by a summary at the end.
        """
        logger.info("开始处理所有类别...")

        categories = self.get_all_categories()
        if not categories:
            logger.warning("未找到任何类别")
            return

        total_categories = len(categories)
        unchanged_count = 0
        updated_count = 0
        failed_count = 0
        # Monotonic clock: durations are unaffected by wall-clock adjustments.
        start_time = time.monotonic()

        logger.info(f"总共需要处理 {total_categories} 个类别")

        for processed_count, category in enumerate(categories, 1):
            category_start_time = time.monotonic()
            logger.info(f"处理进度: {processed_count}/{total_categories} "
                        f"({processed_count/total_categories*100:.1f}%) - 类别: {category}")

            optimized_category = self.optimize_category_with_ollama(category)

            if not optimized_category:
                # An all-symbol name cleans to "": never blank out a category.
                failed_count += 1
                logger.warning(f"类别 '{category}' 优化结果为空,已跳过")
            elif optimized_category == category:
                unchanged_count += 1
                logger.info(f"类别 '{category}' 无需更改")
            elif self.update_category_in_db(category, optimized_category) is None:
                # The DB error has already been logged by update_category_in_db.
                failed_count += 1
            else:
                updated_count += 1
                logger.info(f"类别 '{category}' 已更新为 '{optimized_category}'")

            category_duration = time.monotonic() - category_start_time
            elapsed_time = time.monotonic() - start_time
            avg_time_per_category = elapsed_time / processed_count
            estimated_remaining = avg_time_per_category * (total_categories - processed_count)

            logger.info(f"类别 '{category}' 处理完成,耗时: {category_duration:.2f}秒")
            logger.info(f"累计处理: {processed_count}/{total_categories} | "
                        f"已更新: {updated_count} | 未更改: {unchanged_count} | "
                        f"失败: {failed_count} | "
                        f"平均耗时: {avg_time_per_category:.2f}秒/类别 | "
                        f"预计剩余时间: {estimated_remaining:.2f}秒")

        total_duration = time.monotonic() - start_time
        logger.info("=" * 60)
        logger.info("所有类别处理完成!")
        logger.info(f"总计处理类别数: {total_categories}")
        logger.info(f"更新类别数: {updated_count}")
        logger.info(f"未更改类别数: {unchanged_count}")
        logger.info(f"失败/跳过类别数: {failed_count}")
        logger.info(f"总耗时: {total_duration:.2f}秒")
        logger.info(f"平均每类别处理时间: {total_duration/total_categories:.2f}秒")
        logger.info("=" * 60)
|
||
|
||
def main():
    """Verify the local Ollama service, then normalize all categories.

    Processing is skipped when the service cannot be reached or the
    configured model is not installed. Otherwise every generate request would
    fail, and categories would only get the regex cleanup fallback.
    """
    modifier = CategoryModifier()

    # Health-check the same server the modifier will call, instead of a second
    # hard-coded URL that could drift from ``modifier.ollama_url``.
    base_url = modifier.ollama_url.rsplit("/api/", 1)[0]

    try:
        response = requests.get(f"{base_url}/api/tags", timeout=5)
    except Exception as e:
        logger.warning(f"无法连接到Ollama服务: {e}")
        logger.info("请确保Ollama服务已在本地运行")
        return

    if response.status_code != 200:
        logger.warning("Ollama服务不可用,请确保服务已启动")
        return
    logger.info("Ollama服务可用")

    # /api/tags lists locally installed models; check the configured one is there.
    try:
        installed = {item.get("name") for item in response.json().get("models", [])}
    except (ValueError, AttributeError, TypeError):
        installed = None  # unexpected payload shape: skip the model check
    if installed is not None and not ({modifier.model, f"{modifier.model}:latest"} & installed):
        logger.warning(f"本地未找到模型 {modifier.model},请先运行: ollama pull {modifier.model}")
        return

    modifier.process_all_categories()


if __name__ == "__main__":
    main()
|