#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
产品难度评分脚本

读取 product_analysis 表，添加 difficulty_score 难度评分字段，并调用 Ollama API 对每个产品进行智能评分。
"""

import json
import os
import re
import sqlite3
import time
from typing import List, Optional, Tuple

import requests
from loguru import logger

class DifficultyScorer:
    """Rate the development difficulty of products stored in SQLite.

    Reads rows from the ``product_analysis`` table, asks an Ollama model to
    rate each row's ``ai_response`` text on a 0-100 scale, and writes the
    result into a ``difficulty_score`` column (added on demand).
    """

    # Reasoning models (e.g. qwen3) may wrap their chain of thought in
    # <think>...</think>; an unterminated block (truncated reply) is removed
    # up to the end of the text.
    _THINK_RE = re.compile(r"<think>.*?(?:</think>|$)", re.DOTALL)
    # First (optionally negative) integer in the remaining answer text.
    _INT_RE = re.compile(r"-?\d+")

    def __init__(
        self,
        db_path: str = "products.db",
        *,
        api_url: str = "http://localhost:11434/api/generate",
        model: str = "qwen3:8b",
        timeout: float = 60,
        request_delay: float = 2,
    ):
        """Initialise the scorer.

        Args:
            db_path: Path to the SQLite database file. If it does not exist
                relative to the working directory, the same name next to this
                script is tried instead.
            api_url: Ollama ``/api/generate`` endpoint.
            model: Name of the Ollama model used for scoring.
            timeout: Per-request HTTP timeout in seconds.
            request_delay: Seconds to wait between consecutive API calls.

        Raises:
            FileNotFoundError: If neither candidate database path exists.
        """
        self.db_path = db_path
        self.api_url = api_url
        self.model = model
        self.timeout = timeout
        self.request_delay = request_delay

        # Fall back to a database located next to this script.
        if not os.path.exists(db_path):
            script_dir = os.path.dirname(os.path.abspath(__file__))
            current_dir_db = os.path.join(script_dir, db_path)
            if os.path.exists(current_dir_db):
                self.db_path = current_dir_db
                logger.info(f"使用当前目录下的数据库文件: {current_dir_db}")
            else:
                raise FileNotFoundError(f"数据库文件不存在: {db_path} 和 {current_dir_db}")

        logger.info(f"初始化产品难度评分器,数据库: {self.db_path}")

    def connect_to_database(self) -> sqlite3.Connection:
        """Open and return a connection to ``self.db_path`` (errors are logged and re-raised)."""
        try:
            conn = sqlite3.connect(self.db_path)
            logger.success(f"成功连接到数据库: {self.db_path}")
            return conn
        except Exception as e:
            logger.error(f"连接数据库失败: {e}")
            raise

    def add_difficulty_score_column(self, conn: sqlite3.Connection):
        """Add an INTEGER ``difficulty_score`` column to ``product_analysis`` if missing."""
        try:
            cursor = conn.cursor()

            # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the column name.
            cursor.execute("PRAGMA table_info(product_analysis)")
            columns = {row[1] for row in cursor.fetchall()}

            if "difficulty_score" not in columns:
                cursor.execute("ALTER TABLE product_analysis ADD COLUMN difficulty_score INTEGER")
                conn.commit()
                logger.success("成功添加difficulty_score字段")
            else:
                logger.info("difficulty_score字段已存在")
        except Exception as e:
            logger.error(f"添加难度评分字段失败: {e}")
            raise

    def get_unscored_products(self, conn: sqlite3.Connection) -> List[Tuple]:
        """Return products that still need a difficulty score.

        Args:
            conn: Open database connection.

        Returns:
            List of ``(id, ai_response)`` tuples for rows whose
            ``difficulty_score`` is NULL and whose ``ai_response`` contains
            non-whitespace text.
        """
        try:
            cursor = conn.cursor()
            # TRIM also skips whitespace-only responses, which give the model
            # nothing to rate.
            cursor.execute("""
                SELECT id, ai_response
                FROM product_analysis
                WHERE difficulty_score IS NULL
                  AND ai_response IS NOT NULL
                  AND TRIM(ai_response) != ''
            """)
            products = cursor.fetchall()
            logger.info(f"找到 {len(products)} 个未评分的产品")
            return products
        except Exception as e:
            logger.error(f"获取未评分产品数据失败: {e}")
            raise

    @classmethod
    def _parse_score(cls, text: str) -> Optional[int]:
        """Extract a score clamped to 0-100 from a raw model reply.

        Any ``<think>...</think>`` section is removed first, then the first
        integer in the remaining text is used, so replies such as ``"75"``,
        ``"75分"`` or ``"评分:75"`` all parse. Returns None if no integer is
        found.
        """
        answer = cls._THINK_RE.sub("", text or "")
        match = cls._INT_RE.search(answer)
        if match is None:
            return None
        return max(0, min(100, int(match.group())))

    def call_ollama_for_scoring(self, ai_response: str) -> Optional[int]:
        """Ask the Ollama model to rate one difficulty description.

        Args:
            ai_response: Difficulty description text for one product.

        Returns:
            Score in the range 0-100, or None if the request fails or the
            reply contains no parsable number (failures are logged, never raised).
        """
        prompt = "\n".join([
            "请根据以下产品开发难度描述,给出一个0-100分的难度评分:",
            "",
            f"难度描述:{ai_response}",
            "",
            "评分标准:",
            "- 90-100分:个人开发极其困难,需要大量专业知识和团队协作",
            "- 70-89分:相对困难,需要较强的技术能力和较多时间",
            "- 50-69分:中等难度,需要一定的技术基础",
            "- 30-49分:相对简单,有基础即可开发",
            "- 10-29分:非常简单,入门级别",
            "- 0-9分:极其简单,几乎无难度",
            "",
            "请只返回一个数字,不要有任何其他文字。",
        ])
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
        }

        try:
            logger.info("调用Ollama API进行难度评分")
            # json= makes requests serialise the payload, encode it as UTF-8
            # bytes and set Content-Type. A pre-serialised str body can be
            # encoded as Latin-1 by some urllib3/http.client versions, which
            # fails on the Chinese prompt.
            response = requests.post(self.api_url, json=payload, timeout=self.timeout)

            if response.status_code != 200:
                logger.error(f"API调用失败: {response.status_code}, {response.text}")
                return None

            score_text = str(response.json().get("response", "")).strip()
            score = self._parse_score(score_text)
            if score is None:
                logger.error(f"无法解析评分: {score_text}")
                return None

            logger.success(f"获得评分: {score}")
            return score
        except Exception as e:
            # Best effort: one failed product must not abort the whole run.
            logger.error(f"调用Ollama API时出错: {e}")
            return None

    def update_difficulty_score(self, conn: sqlite3.Connection, product_id: int, score: int):
        """Store ``score`` as the difficulty_score of row ``product_id`` and commit."""
        try:
            cursor = conn.cursor()
            cursor.execute("""
                UPDATE product_analysis
                SET difficulty_score = ?
                WHERE id = ?
            """, (score, product_id))
            # Commit per row so progress survives an interrupted run.
            conn.commit()
            logger.success(f"更新产品ID {product_id} 的难度评分为: {score}")
        except Exception as e:
            logger.error(f"更新难度评分失败: {e}")
            raise

    def score_products(self) -> None:
        """Score every unscored product, one API call at a time.

        Errors are logged with a traceback and not re-raised; the database
        connection is always closed.
        """
        logger.info("开始产品难度评分")

        conn = None
        try:
            conn = self.connect_to_database()
            self.add_difficulty_score_column(conn)

            products = self.get_unscored_products(conn)
            if not products:
                logger.info("没有需要评分的产品")
                return

            total = len(products)
            logger.info(f"准备评分 {total} 个产品")

            success_count = 0
            for i, (product_id, ai_response) in enumerate(products, 1):
                logger.info(f"\n评分进度: {i}/{total} - 产品ID: {product_id}")

                score = self.call_ollama_for_scoring(ai_response)
                if score is not None:
                    self.update_difficulty_score(conn, product_id, score)
                    success_count += 1
                    logger.success(f"评分完成: {score}分")
                else:
                    logger.error(f"评分失败: 产品ID {product_id}")

                # Throttle to avoid overloading the API; there is nothing to
                # wait for after the final product.
                if i < total:
                    logger.info(f"等待{self.request_delay}秒后继续...")
                    time.sleep(self.request_delay)

            logger.success(f"评分完成! 成功评分 {success_count} 个产品")

        except Exception as e:
            # Top-level boundary: keep the traceback in the log.
            logger.exception(f"评分过程中出错: {e}")
        finally:
            if conn is not None:
                conn.close()
                logger.info("数据库连接已关闭")

def main():
    """Script entry point: add a size-rotated log file, then score all unscored products."""
    # Also write log records to a file that rolls over every 10 MB.
    logger.add("difficulty_scorer.log", rotation="10 MB", level="INFO")

    DifficultyScorer().score_products()


if __name__ == "__main__":
    main()