refactor: 重构配置管理和截图功能,切换LLM提供商至智谱AI
重构配置管理器以支持打包环境路径处理;将LLM分析器从OpenAI迁移至智谱AI API;将Playwright截图功能替换为Selenium实现;更新默认配置中的API端点和模型。
This commit is contained in:
@@ -13,9 +13,9 @@ class ConfigManager:
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"llm_api": {
|
||||
"base_url": "https://integrate.api.nvidia.com/v1",
|
||||
"api_key": "",
|
||||
"model": "deepseek-ai/deepseek-r1",
|
||||
"base_url": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"model": "glm-4.7-flash",
|
||||
"timeout": 120,
|
||||
"retry_times": 3
|
||||
},
|
||||
@@ -46,9 +46,17 @@ class ConfigManager:
|
||||
}
|
||||
|
||||
def __init__(self, config_path: str = "config.json"):
|
||||
self.config_path = Path(config_path)
|
||||
import sys
|
||||
# 确定配置文件的正确路径
|
||||
if getattr(sys, 'frozen', False):
|
||||
# 打包后的环境
|
||||
current_dir = Path(sys.executable).parent
|
||||
self.config_path = current_dir / config_path
|
||||
else:
|
||||
# 开发环境
|
||||
self.config_path = Path(config_path)
|
||||
self.config = self._load_config()
|
||||
logger.info(f"配置管理器初始化完成,配置文件: {config_path}")
|
||||
logger.info(f"配置管理器初始化完成,配置文件: {self.config_path}")
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
@@ -59,13 +67,14 @@ class ConfigManager:
|
||||
loaded_config = json.load(f)
|
||||
# 合并默认配置,确保所有键都存在
|
||||
merged = self._merge_config(self.DEFAULT_CONFIG, loaded_config)
|
||||
logger.info("配置加载成功")
|
||||
logger.info(f"配置加载成功,目标URL: {merged.get('spider', {}).get('target_url', '未设置')}")
|
||||
return merged
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.error(f"配置文件加载失败,使用默认配置: {e}")
|
||||
return self.DEFAULT_CONFIG.copy()
|
||||
else:
|
||||
logger.warning(f"配置文件不存在: {self.config_path},使用默认配置")
|
||||
logger.warning(f"默认配置目标URL: {self.DEFAULT_CONFIG.get('spider', {}).get('target_url', '未设置')}")
|
||||
return self.DEFAULT_CONFIG.copy()
|
||||
|
||||
def _merge_config(self, default: Dict, loaded: Dict) -> Dict:
|
||||
@@ -114,13 +123,13 @@ class ConfigManager:
|
||||
current[keys[-1]] = value
|
||||
return self.save_config()
|
||||
|
||||
def update_llm_api(self, base_url: str = None, api_key: str = None,
|
||||
def update_llm_api(self, api_key: str = None, base_url: str = None,
|
||||
model: str = None, timeout: int = None, retry_times: int = None):
|
||||
"""更新LLM API配置"""
|
||||
if base_url:
|
||||
self.config["llm_api"]["base_url"] = base_url
|
||||
if api_key:
|
||||
self.config["llm_api"]["api_key"] = api_key
|
||||
if base_url:
|
||||
self.config["llm_api"]["base_url"] = base_url
|
||||
if model:
|
||||
self.config["llm_api"]["model"] = model
|
||||
if timeout:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
大模型分析模块 - 调用LLM API分析评论情感
|
||||
支持 OpenAI 兼容 API,包括 NVIDIA API
|
||||
支持智谱AI API
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
from openai import OpenAI
|
||||
from zai import ZhipuAiClient
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@@ -36,8 +36,8 @@ class LLMAnalyzer:
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.base_url = config.get('base_url', '')
|
||||
self.api_key = config.get('api_key', '')
|
||||
self.base_url = config.get('base_url', '')
|
||||
self.model = config.get('model', '')
|
||||
self.timeout = config.get('timeout', 120)
|
||||
self.retry_times = config.get('retry_times', 3)
|
||||
@@ -47,34 +47,33 @@ class LLMAnalyzer:
|
||||
|
||||
logger.info(f"LLM分析器配置 - base_url: {self.base_url}, model: {self.model}, timeout: {self.timeout}s, retry: {self.retry_times}次")
|
||||
|
||||
if self.base_url and self.api_key:
|
||||
if self.api_key:
|
||||
self._init_client()
|
||||
else:
|
||||
logger.warning("LLM API 未配置,base_url 或 api_key 为空")
|
||||
logger.warning("LLM API 未配置,api_key 为空")
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化OpenAI客户端"""
|
||||
"""初始化智谱AI客户端"""
|
||||
try:
|
||||
logger.info(f"初始化LLM客户端: {self.base_url}")
|
||||
self.client = OpenAI(
|
||||
logger.info(f"初始化智谱AI客户端: {self.base_url}")
|
||||
self.client = ZhipuAiClient(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout
|
||||
base_url=self.base_url
|
||||
)
|
||||
logger.info("LLM客户端初始化成功")
|
||||
logger.info("智谱AI客户端初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化LLM客户端失败: {e}")
|
||||
logger.error(f"初始化智谱AI客户端失败: {e}")
|
||||
|
||||
def update_config(self, config: Dict):
|
||||
"""更新配置"""
|
||||
self.config.update(config)
|
||||
self.base_url = config.get('base_url', self.base_url)
|
||||
self.api_key = config.get('api_key', self.api_key)
|
||||
self.base_url = config.get('base_url', self.base_url)
|
||||
self.model = config.get('model', self.model)
|
||||
self.timeout = config.get('timeout', self.timeout)
|
||||
self.retry_times = config.get('retry_times', self.retry_times)
|
||||
|
||||
if self.base_url and self.api_key:
|
||||
if self.api_key:
|
||||
self._init_client()
|
||||
|
||||
def analyze(self, comment: str) -> Tuple[Optional[int], Optional[str]]:
|
||||
@@ -97,16 +96,18 @@ class LLMAnalyzer:
|
||||
try:
|
||||
logger.info(f"API调用尝试 {attempt + 1}/{self.retry_times}")
|
||||
|
||||
logger.debug(f"发送请求到 {self.base_url}")
|
||||
logger.debug("发送请求到智谱AI API")
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
model="glm-4.7-flash",
|
||||
messages=[
|
||||
{"role": "system", "content": self.SYSTEM_PROMPT},
|
||||
{"role": "user", "content": f"请分析以下评论的情感倾向:\n\n{comment}"}
|
||||
],
|
||||
thinking={
|
||||
"type": "disabled", # 禁用深度思考模式
|
||||
},
|
||||
temperature=0.3,
|
||||
max_tokens=500,
|
||||
timeout=self.timeout
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
# 处理 deepseek-r1 的特殊结构(可能有 reasoning_content)
|
||||
|
||||
@@ -2,22 +2,35 @@
|
||||
截图管理器 - 用于在非交易时间截取上海证券交易所网站图表
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError
|
||||
except ImportError:
|
||||
logger.warning("playwright未安装,截图功能将不可用")
|
||||
# 使用Selenium替代Playwright
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
|
||||
|
||||
class ScreenshotManager:
|
||||
"""截图管理器"""
|
||||
|
||||
def __init__(self, screenshot_dir: str = "screenshots"):
|
||||
"""初始化截图管理器"""
|
||||
self.screenshot_dir = screenshot_dir
|
||||
"""
|
||||
初始化截图管理器
|
||||
"""
|
||||
# 确定截图目录的正确路径
|
||||
if getattr(sys, 'frozen', False):
|
||||
# 打包后的环境
|
||||
current_dir = os.path.dirname(sys.executable)
|
||||
self.screenshot_dir = os.path.join(current_dir, screenshot_dir)
|
||||
else:
|
||||
# 开发环境
|
||||
self.screenshot_dir = screenshot_dir
|
||||
|
||||
self.target_url = "https://www.sse.com.cn/"
|
||||
self.chart_xpath_pattern = "//*[@id=\"hq_area\"]"
|
||||
|
||||
@@ -32,102 +45,102 @@ class ScreenshotManager:
|
||||
返回截图文件路径,失败时返回空字符串
|
||||
"""
|
||||
try:
|
||||
# 检查playwright是否可用
|
||||
if 'sync_playwright' not in globals():
|
||||
logger.error("playwright未安装,无法使用截图功能")
|
||||
return ""
|
||||
# 配置Chrome选项
|
||||
chrome_options = Options()
|
||||
chrome_options.add_argument('--headless')
|
||||
chrome_options.add_argument('--no-sandbox')
|
||||
chrome_options.add_argument('--disable-dev-shm-usage')
|
||||
chrome_options.add_argument('--disable-gpu')
|
||||
chrome_options.add_argument('--window-size=1920,1080')
|
||||
chrome_options.add_argument('user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36')
|
||||
|
||||
with sync_playwright() as p:
|
||||
# 启动浏览器(无头模式,后台运行)
|
||||
browser = p.chromium.launch(headless=True)
|
||||
page = browser.new_page()
|
||||
|
||||
# 设置页面超时
|
||||
page.set_default_timeout(60000) # 60秒超时
|
||||
|
||||
# 访问目标网页
|
||||
logger.info(f"访问上海证券交易所网站: {self.target_url}")
|
||||
page.goto(self.target_url, wait_until="domcontentloaded")
|
||||
|
||||
# 等待页面加载完成
|
||||
page.wait_for_load_state("networkidle")
|
||||
|
||||
# 等待页面完全加载
|
||||
page.wait_for_timeout(5000) # 额外等待5秒
|
||||
|
||||
# 等待图表元素出现
|
||||
logger.info("等待图表元素加载...")
|
||||
|
||||
# 使用动态XPath模式查找图表元素
|
||||
chart_element = None
|
||||
selectors = [
|
||||
self.chart_xpath_pattern, # 主要选择器:highcharts-xxxxxxx-0格式
|
||||
"//*[contains(@class, 'highcharts')]",
|
||||
"//*[contains(@id, 'highcharts')]",
|
||||
"//svg",
|
||||
"//canvas",
|
||||
"//div[contains(@class, 'chart')]",
|
||||
"//div[contains(@class, 'graph')]"
|
||||
]
|
||||
|
||||
for selector in selectors:
|
||||
try:
|
||||
# 等待选择器出现
|
||||
page.wait_for_selector(selector, timeout=10000)
|
||||
elements = page.query_selector_all(selector)
|
||||
if elements:
|
||||
# 选择第一个可见的元素
|
||||
for element in elements:
|
||||
if element.is_visible():
|
||||
chart_element = element
|
||||
logger.info(f"找到图表元素: {selector}")
|
||||
break
|
||||
if chart_element:
|
||||
break
|
||||
except PlaywrightTimeoutError:
|
||||
logger.debug(f"选择器超时: {selector}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"选择器错误 {selector}: {e}")
|
||||
continue
|
||||
|
||||
if not chart_element:
|
||||
logger.warning("未找到任何图表元素,尝试截取整个页面")
|
||||
# 如果找不到图表元素,截取整个页面
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png")
|
||||
page.screenshot(path=screenshot_path)
|
||||
logger.info(f"截取整个页面: {screenshot_path}")
|
||||
browser.close()
|
||||
return screenshot_path
|
||||
|
||||
# 检查元素是否可见
|
||||
if not chart_element.is_visible():
|
||||
logger.warning("图表元素不可见,尝试滚动到元素位置")
|
||||
chart_element.scroll_into_view_if_needed()
|
||||
|
||||
# 生成截图文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png")
|
||||
|
||||
# 截取图表元素
|
||||
logger.info("开始截取图表元素")
|
||||
|
||||
# 直接使用元素进行截图
|
||||
chart_element.screenshot(path=screenshot_path)
|
||||
|
||||
logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}")
|
||||
|
||||
# 关闭浏览器
|
||||
browser.close()
|
||||
|
||||
# 创建WebDriver
|
||||
logger.info("初始化Chrome WebDriver...")
|
||||
driver = webdriver.Chrome(options=chrome_options)
|
||||
driver.set_page_load_timeout(60)
|
||||
driver.implicitly_wait(30)
|
||||
|
||||
# 访问目标网页
|
||||
logger.info(f"访问上海证券交易所网站: {self.target_url}")
|
||||
driver.get(self.target_url)
|
||||
|
||||
# 等待页面加载完成
|
||||
logger.info("等待页面加载完成...")
|
||||
WebDriverWait(driver, 30).until(
|
||||
EC.presence_of_element_located((By.XPATH, "//body"))
|
||||
)
|
||||
# 额外等待确保数据加载完成
|
||||
import time
|
||||
time.sleep(5)
|
||||
|
||||
# 等待图表元素出现
|
||||
logger.info("等待图表元素加载...")
|
||||
|
||||
# 使用动态XPath模式查找图表元素
|
||||
chart_element = None
|
||||
selectors = [
|
||||
self.chart_xpath_pattern, # 主要选择器:highcharts-xxxxxxx-0格式
|
||||
"//*[contains(@class, 'highcharts')]",
|
||||
"//*[contains(@id, 'highcharts')]",
|
||||
"//svg",
|
||||
"//canvas",
|
||||
"//div[contains(@class, 'chart')]",
|
||||
"//div[contains(@class, 'graph')]"
|
||||
]
|
||||
|
||||
for selector in selectors:
|
||||
try:
|
||||
# 等待选择器出现
|
||||
element = WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.XPATH, selector))
|
||||
)
|
||||
# 检查元素是否可见
|
||||
if element.is_displayed():
|
||||
chart_element = element
|
||||
logger.info(f"找到图表元素: {selector}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"选择器错误 {selector}: {e}")
|
||||
continue
|
||||
|
||||
# 生成截图文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if not chart_element:
|
||||
logger.warning("未找到任何图表元素,尝试截取整个页面")
|
||||
# 如果找不到图表元素,截取整个页面
|
||||
screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png")
|
||||
driver.save_screenshot(screenshot_path)
|
||||
logger.info(f"截取整个页面: {screenshot_path}")
|
||||
driver.quit()
|
||||
return screenshot_path
|
||||
|
||||
except PlaywrightTimeoutError as e:
|
||||
logger.error(f"页面加载超时: {e}")
|
||||
return ""
|
||||
|
||||
# 检查元素是否可见
|
||||
if not chart_element.is_displayed():
|
||||
logger.warning("图表元素不可见,尝试滚动到元素位置")
|
||||
driver.execute_script("arguments[0].scrollIntoView({behavior: 'smooth', block: 'center', inline: 'center'});", chart_element)
|
||||
time.sleep(2)
|
||||
|
||||
# 生成截图文件名
|
||||
screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png")
|
||||
|
||||
# 截取图表元素
|
||||
logger.info("开始截取图表元素")
|
||||
|
||||
# 直接使用元素进行截图
|
||||
chart_element.screenshot(screenshot_path)
|
||||
|
||||
logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}")
|
||||
|
||||
# 关闭浏览器
|
||||
driver.quit()
|
||||
|
||||
return screenshot_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"截图过程中发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return ""
|
||||
|
||||
def get_latest_screenshot(self) -> str:
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 61 KiB |
Reference in New Issue
Block a user