221 lines
7.9 KiB
Python
221 lines
7.9 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
打开tophub_data.db数据库,读取articles表,提取所有唯一的类别(category)
|
|||
|
|
访问本地Ollama的API,将类别名称优化为3-6个汉字,并去掉其中的空格、特殊符号等字符
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import requests
|
|||
|
|
import sqlite3
|
|||
|
|
import re
|
|||
|
|
import time
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
# Send INFO-and-above log records to db_modify.log as well, starting a new
# file whenever the current one reaches 10 MB.
logger.add(sink="db_modify.log", level="INFO", rotation="10 MB")
|
|||
|
|
|
|||
|
|
class CategoryModifier:
    """Normalize the category names stored in the ``articles`` table.

    Every distinct ``category`` value is sent to a local Ollama model, which
    proposes a shorter, cleaner name; matching rows are then rewritten in place.
    """

    # "Thinking" models (e.g. qwen3) may wrap their reasoning in
    # <think>...</think> before the answer. An unterminated block (truncated
    # output) is removed up to the end of the text.
    _THINK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", re.DOTALL)

    def __init__(self, db_path="tophub_data.db",
                 ollama_url="http://localhost:11434/api/generate",
                 model="qwen3:8b", timeout=30, max_name_length=6):
        """Initialize the modifier.

        Args:
            db_path (str): Path to the SQLite database file.
            ollama_url (str): URL of the Ollama ``/api/generate`` endpoint.
            model (str): Name of the Ollama model to query.
            timeout (float): Per-request timeout in seconds for the Ollama call.
            max_name_length (int): Longest model answer (after cleaning)
                accepted as a new name; longer answers are treated as
                explanations rather than names and rejected.
        """
        self.db_path = db_path
        self.ollama_url = ollama_url
        self.model = model
        self.timeout = timeout
        self.max_name_length = max_name_length

    def get_all_categories(self):
        """Fetch every distinct, non-empty category from the ``articles`` table.

        Returns:
            list: The distinct category values, or an empty list if the query
            fails (the error is logged).
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.execute("SELECT DISTINCT category FROM articles")
                categories = [row[0] for row in cursor.fetchall() if row[0]]
            finally:
                # Close even when the query raises (e.g. the table is missing).
                conn.close()
        except sqlite3.Error as e:
            logger.error(f"获取类别时出错: {e}")
            return []

        logger.info(f"成功获取 {len(categories)} 个唯一类别")
        return categories

    def clean_category_name(self, category):
        """Strip everything except CJK ideographs, ASCII letters and digits.

        Args:
            category (str): Raw category name; ``None`` is treated as ``""``.

        Returns:
            str: The cleaned name. Whitespace is removed as well, since it is
            outside the kept character set.
        """
        return re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9]', '', category or "")

    def _parse_model_output(self, text):
        """Turn raw model output into a usable name, or ``""`` if unusable.

        Reasoning blocks are dropped, the rest is cleaned, and the result is
        rejected when empty or longer than ``self.max_name_length``.
        """
        answer = self._THINK_RE.sub("", text or "")
        name = self.clean_category_name(answer)
        if not name or len(name) > self.max_name_length:
            return ""
        return name

    def optimize_category_with_ollama(self, category):
        """Ask the Ollama model for a shorter, clearer version of a category name.

        Args:
            category (str): The original category name.

        Returns:
            str: The optimized name. If the request fails or the model's answer
            is unusable, the cleaned original name is returned instead (or the
            original itself when cleaning would leave nothing), so the result
            is never empty.
        """
        # Never return "" - that would blank out every row of this category.
        fallback = self.clean_category_name(category) or category

        prompt = (
            f"请将以下类别名称简化为3-6个汉字,去除空格和特殊符号,更容易理解,并保持原意:'{category}'。"
            "例子一:'新科科技',优化为'新质生产力'。例子二:'产设',优化为'产品设计'。例子三:'史人',优化为'历史人物'。"
            "只输出优化后的名称,不要输出任何解释。"
        )
        data = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
        }

        try:
            response = requests.post(self.ollama_url, json=data, timeout=self.timeout)
            response.raise_for_status()
            raw = response.json().get("response", "")
        except Exception as e:
            # Best effort per category: a failed call must not abort the run.
            logger.error(f"优化类别 '{category}' 时出错: {e}")
            return fallback

        optimized = self._parse_model_output(raw)
        if not optimized:
            logger.warning(f"类别 '{category}' 的模型输出不可用: {raw[:100]!r},保留清理后的原名")
            return fallback

        logger.info(f"类别 '{category}' 优化为 '{optimized}'")
        return optimized

    def update_category_in_db(self, old_category, new_category):
        """Rename a category on every matching row of ``articles``.

        Args:
            old_category (str): The current category name.
            new_category (str): The replacement category name.

        Returns:
            int | None: Number of rows updated, or ``None`` if the update failed
            (the error is logged and nothing is committed).
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.execute(
                    "UPDATE articles SET category = ? WHERE category = ?",
                    (new_category, old_category),
                )
                count = cursor.rowcount
                conn.commit()
            finally:
                # Closing without commit discards a partially applied update.
                conn.close()
        except sqlite3.Error as e:
            logger.error(f"更新类别 '{old_category}' 时出错: {e}")
            return None

        logger.info(f"成功更新类别 '{old_category}' 为 '{new_category}',影响 {count} 条记录")
        return count

    def process_all_categories(self):
        """Optimize every category and write changed names back to the database.

        Progress, running averages and an estimated remaining time are logged
        after each category, followed by a summary at the end.
        """
        logger.info("开始处理所有类别...")

        categories = self.get_all_categories()
        if not categories:
            logger.warning("未找到任何类别")
            return

        total_categories = len(categories)
        updated_count = 0
        unchanged_count = 0
        failed_count = 0
        start_time = time.time()

        logger.info(f"总共需要处理 {total_categories} 个类别")

        for i, category in enumerate(categories, 1):
            category_start_time = time.time()
            logger.info(f"处理进度: {i}/{total_categories} ({i/total_categories*100:.1f}%) - 类别: {category}")

            optimized_category = self.optimize_category_with_ollama(category)

            if optimized_category == category:
                unchanged_count += 1
                logger.info(f"类别 '{category}' 无需更改")
            elif self.update_category_in_db(category, optimized_category) is None:
                # Only count a category as updated once the write succeeded.
                failed_count += 1
                logger.warning(f"类别 '{category}' 更新失败")
            else:
                updated_count += 1
                logger.info(f"类别 '{category}' 已更新为 '{optimized_category}'")

            category_duration = time.time() - category_start_time
            elapsed_time = time.time() - start_time
            avg_time_per_category = elapsed_time / i
            estimated_remaining = avg_time_per_category * (total_categories - i)

            logger.info(f"类别 '{category}' 处理完成,耗时: {category_duration:.2f}秒")
            logger.info(f"累计处理: {i}/{total_categories} | "
                        f"已更新: {updated_count} | 未更改: {unchanged_count} | "
                        f"失败: {failed_count} | "
                        f"平均耗时: {avg_time_per_category:.2f}秒/类别 | "
                        f"预计剩余时间: {estimated_remaining:.2f}秒")

        total_duration = time.time() - start_time
        logger.info("="*60)
        logger.info("所有类别处理完成!")
        logger.info(f"总计处理类别数: {total_categories}")
        logger.info(f"更新类别数: {updated_count}")
        logger.info(f"未更改类别数: {unchanged_count}")
        logger.info(f"更新失败类别数: {failed_count}")
        logger.info(f"总耗时: {total_duration:.2f}秒")
        logger.info(f"平均每类别处理时间: {total_duration/total_categories:.2f}秒")
        logger.info("="*60)
|
|||
|
|
|
|||
|
|
def main():
    """Entry point: verify the Ollama server is reachable, then process all categories."""
    modifier = CategoryModifier()

    # Derive the health-check endpoint from the configured generate URL so the
    # check always targets the same Ollama server the modifier will call.
    base_url = modifier.ollama_url.split("/api/", 1)[0]
    tags_url = f"{base_url}/api/tags"

    try:
        response = requests.get(tags_url, timeout=5)
    except requests.RequestException as e:
        logger.warning(f"无法连接到Ollama服务: {e}")
        logger.info("请确保Ollama服务已在本地运行")
        return

    if response.status_code != 200:
        logger.warning("Ollama服务不可用,请确保服务已启动")
        return
    logger.info("Ollama服务可用")

    modifier.process_all_categories()


if __name__ == "__main__":
    main()
|