Files
tophux_scrape/product/difficulty_scorer.py

250 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
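产品难度评分脚本
读取 product_analysis 表，增加难度评分字段，并使用 Ollama API 进行智能评分。
(Product difficulty scorer: reads the product_analysis table, adds a
difficulty_score column, and scores each product via the Ollama API.)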
"""
import json
import os
import re
import sqlite3
import time
from typing import List, Tuple, Optional

import requests
from loguru import logger
class DifficultyScorer:
    """Score product development difficulty with an Ollama model.

    Reads ``(id, ai_response)`` rows from the ``product_analysis`` table of a
    SQLite database, asks the model for a 0-100 difficulty rating of each
    ``ai_response`` text, and stores it in the ``difficulty_score`` column
    (the column is added if it does not exist yet).
    """

    def __init__(
        self,
        db_path: str = "products.db",
        api_url: str = "http://localhost:11434/api/generate",
        model: str = "qwen3:8b",
        request_timeout: float = 60,
        request_delay: float = 2.0,
    ):
        """
        Initialize the scorer.

        Args:
            db_path: Path to the SQLite database file. If it does not exist
                as given, the same name is looked up next to this script.
            api_url: Ollama ``/api/generate`` endpoint.
            model: Ollama model name used for scoring.
            request_timeout: Timeout in seconds for each API request.
            request_delay: Pause in seconds between consecutive API calls.

        Raises:
            FileNotFoundError: If the database exists in neither location.
        """
        self.db_path = db_path
        self.api_url = api_url
        self.model = model
        self.request_timeout = request_timeout
        self.request_delay = request_delay
        # Fall back to the script's own directory so the scorer also works
        # when it is launched from another working directory.
        if not os.path.exists(db_path):
            current_dir_db = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), db_path
            )
            if os.path.exists(current_dir_db):
                self.db_path = current_dir_db
                logger.info(f"使用当前目录下的数据库文件: {current_dir_db}")
            else:
                # Name both candidate paths that were checked.
                raise FileNotFoundError(
                    f"数据库文件不存在: {db_path} 或 {current_dir_db}"
                )
        logger.info(f"初始化产品难度评分器,数据库: {self.db_path}")

    def connect_to_database(self) -> sqlite3.Connection:
        """Open and return a connection to the SQLite database at ``self.db_path``."""
        try:
            conn = sqlite3.connect(self.db_path)
            logger.success(f"成功连接到数据库: {self.db_path}")
            return conn
        except Exception as e:
            logger.error(f"连接数据库失败: {e}")
            raise

    def add_difficulty_score_column(self, conn: sqlite3.Connection):
        """Add an INTEGER ``difficulty_score`` column to ``product_analysis`` if missing."""
        try:
            cursor = conn.cursor()
            # PRAGMA table_info returns one row per column; index 1 is the name.
            cursor.execute("PRAGMA table_info(product_analysis)")
            columns = [row[1] for row in cursor.fetchall()]
            if 'difficulty_score' not in columns:
                cursor.execute("ALTER TABLE product_analysis ADD COLUMN difficulty_score INTEGER")
                conn.commit()
                logger.success("成功添加difficulty_score字段")
            else:
                logger.info("difficulty_score字段已存在")
        except Exception as e:
            logger.error(f"添加难度评分字段失败: {e}")
            raise

    def get_unscored_products(self, conn: sqlite3.Connection) -> List[Tuple]:
        """
        Fetch products that have an AI response but no difficulty score yet.

        Args:
            conn: Open database connection.

        Returns:
            List of ``(id, ai_response)`` tuples.
        """
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT id, ai_response
                FROM product_analysis
                WHERE difficulty_score IS NULL
                AND ai_response IS NOT NULL
                AND ai_response != ''
            """)
            products = cursor.fetchall()
            logger.info(f"找到 {len(products)} 个未评分的产品")
            return products
        except Exception as e:
            logger.error(f"获取未评分产品数据失败: {e}")
            raise

    @staticmethod
    def _parse_score(reply: Optional[str]) -> Optional[int]:
        """
        Extract a 0-100 score from a raw model reply.

        Reasoning models (e.g. qwen3) may prefix the answer with a
        ``<think>...</think>`` section, and models often add words around the
        number, so a bare ``int()`` is too strict.

        Args:
            reply: Raw ``response`` text returned by the model (may be None).

        Returns:
            The first integer after any thinking section, clamped to 0-100.
            Returns None if the thinking section is unterminated or no
            number is present.
        """
        text = reply or ""
        if "</think>" in text:
            # Only the text after the last closing tag is the actual answer.
            text = text.rsplit("</think>", 1)[1]
        elif "<think>" in text:
            # Reasoning was cut off before an answer was produced.
            return None
        match = re.search(r"-?\d+", text)
        if match is None:
            return None
        return max(0, min(100, int(match.group())))

    def call_ollama_for_scoring(self, ai_response: str) -> Optional[int]:
        """
        Ask the Ollama model to rate the difficulty described in ``ai_response``.

        Args:
            ai_response: Stored AI analysis text describing the product.

        Returns:
            Score in the range 0-100, or None if the request fails or the
            reply contains no usable number.
        """
        prompt = f"""
请根据以下产品开发难度描述，给出一个0-100分的难度评分。
难度描述:{ai_response}
评分标准:
- 90-100分：个人开发极其困难，需要大量专业知识和团队协作
- 70-89分：相对困难，需要较强的技术能力和较多时间
- 50-69分：中等难度，需要一定的技术基础
- 30-49分：相对简单，有基础即可开发
- 10-29分：非常简单，入门级别
- 0-9分：极其简单，几乎无难度
请只返回一个数字,不要有任何其他文字。
"""
        payload = {
            "model": self.model,
            "prompt": prompt.strip(),
            "stream": False,
        }
        logger.info("调用Ollama API进行难度评分")
        try:
            # json= serializes to UTF-8 bytes and sets Content-Type; passing a
            # non-ASCII str via data= can fail to encode (latin-1) or be sent
            # with a wrong Content-Length, depending on the urllib3 version.
            response = requests.post(
                self.api_url,
                json=payload,
                timeout=self.request_timeout,
            )
        except requests.RequestException as e:
            logger.error(f"调用Ollama API时出错: {e}")
            return None

        if response.status_code != 200:
            logger.error(f"API调用失败: {response.status_code}, {response.text}")
            return None

        try:
            result = response.json()
        except ValueError as e:
            logger.error(f"调用Ollama API时出错: {e}")
            return None
        raw = result.get("response", "") if isinstance(result, dict) else ""
        score_text = str(raw or "").strip()

        score = self._parse_score(score_text)
        if score is None:
            logger.error(f"无法解析评分: {score_text}")
            return None
        logger.success(f"获得评分: {score}")
        return score

    def update_difficulty_score(self, conn: sqlite3.Connection, product_id: int, score: int):
        """Persist ``score`` as the difficulty score of product ``product_id`` and commit."""
        try:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE product_analysis
                SET difficulty_score = ?
                WHERE id = ?
            """, (score, product_id))
            conn.commit()
            if cursor.rowcount == 0:
                # No row matched: do not report a successful update.
                logger.warning(f"未找到产品ID {product_id}, 评分未写入")
            else:
                logger.success(f"更新产品ID {product_id} 的难度评分为: {score}")
        except Exception as e:
            logger.error(f"更新难度评分失败: {e}")
            raise

    def score_products(self):
        """
        Score every unscored product and store the results.

        Products whose scoring fails are skipped and stay unscored, so a
        later run retries them. Any other error aborts the run; it is
        logged with its traceback, not raised. The connection is always
        closed.
        """
        logger.info("开始产品难度评分")
        conn = None
        try:
            conn = self.connect_to_database()
            self.add_difficulty_score_column(conn)
            products = self.get_unscored_products(conn)
            if not products:
                logger.info("没有需要评分的产品")
                return

            total = len(products)
            logger.info(f"准备评分 {total} 个产品")
            success_count = 0
            for i, (product_id, ai_response) in enumerate(products, 1):
                logger.info(f"\n评分进度: {i}/{total} - 产品ID: {product_id}")
                score = self.call_ollama_for_scoring(ai_response)
                if score is not None:
                    self.update_difficulty_score(conn, product_id, score)
                    success_count += 1
                    logger.success(f"评分完成: {score}")
                else:
                    logger.error(f"评分失败: 产品ID {product_id}")
                # Throttle API calls; there is nothing to wait for after the last one.
                if i < total and self.request_delay > 0:
                    logger.info(f"等待{self.request_delay:g}秒后继续...")
                    time.sleep(self.request_delay)

            logger.success(f"评分完成! 成功评分 {success_count} 个产品")
        except Exception as e:
            # Top-level boundary: keep the traceback in the log.
            logger.exception(f"评分过程中出错: {e}")
        finally:
            if conn is not None:
                conn.close()
                logger.info("数据库连接已关闭")
def main():
    """Script entry point: enable file logging, then score all unscored products."""
    # Persist INFO-level logs to a file that rotates once it reaches 10 MB.
    logger.add("difficulty_scorer.log", level="INFO", rotation="10 MB")
    DifficultyScorer().score_products()


if __name__ == "__main__":
    main()