diff --git a/config_manager.py b/config_manager.py index 94cf8d9..04ed9aa 100644 --- a/config_manager.py +++ b/config_manager.py @@ -13,9 +13,9 @@ class ConfigManager: DEFAULT_CONFIG = { "llm_api": { - "base_url": "https://integrate.api.nvidia.com/v1", "api_key": "", - "model": "deepseek-ai/deepseek-r1", + "base_url": "https://open.bigmodel.cn/api/paas/v4", + "model": "glm-4.7-flash", "timeout": 120, "retry_times": 3 }, @@ -46,9 +46,17 @@ class ConfigManager: } def __init__(self, config_path: str = "config.json"): - self.config_path = Path(config_path) + import sys + # 确定配置文件的正确路径 + if getattr(sys, 'frozen', False): + # 打包后的环境 + current_dir = Path(sys.executable).parent + self.config_path = current_dir / config_path + else: + # 开发环境 + self.config_path = Path(config_path) self.config = self._load_config() - logger.info(f"配置管理器初始化完成,配置文件: {config_path}") + logger.info(f"配置管理器初始化完成,配置文件: {self.config_path}") def _load_config(self) -> Dict[str, Any]: """加载配置文件""" @@ -59,13 +67,14 @@ class ConfigManager: loaded_config = json.load(f) # 合并默认配置,确保所有键都存在 merged = self._merge_config(self.DEFAULT_CONFIG, loaded_config) - logger.info("配置加载成功") + logger.info(f"配置加载成功,目标URL: {merged.get('spider', {}).get('target_url', '未设置')}") return merged except (json.JSONDecodeError, IOError) as e: logger.error(f"配置文件加载失败,使用默认配置: {e}") return self.DEFAULT_CONFIG.copy() else: logger.warning(f"配置文件不存在: {self.config_path},使用默认配置") + logger.warning(f"默认配置目标URL: {self.DEFAULT_CONFIG.get('spider', {}).get('target_url', '未设置')}") return self.DEFAULT_CONFIG.copy() def _merge_config(self, default: Dict, loaded: Dict) -> Dict: @@ -114,13 +123,13 @@ class ConfigManager: current[keys[-1]] = value return self.save_config() - def update_llm_api(self, base_url: str = None, api_key: str = None, + def update_llm_api(self, api_key: str = None, base_url: str = None, model: str = None, timeout: int = None, retry_times: int = None): """更新LLM API配置""" - if base_url: - self.config["llm_api"]["base_url"] = base_url if api_key: self.config["llm_api"]["api_key"] = api_key + if base_url: + self.config["llm_api"]["base_url"] = base_url if model: self.config["llm_api"]["model"] = model if timeout: diff --git a/llm_analyzer.py b/llm_analyzer.py index b884369..f09099b 100644 --- a/llm_analyzer.py +++ b/llm_analyzer.py @@ -1,12 +1,12 @@ """ 大模型分析模块 - 调用LLM API分析评论情感 -支持 OpenAI 兼容 API,包括 NVIDIA API +支持智谱AI API """ import json import time import re from typing import Dict, Optional, Tuple, Any -from openai import OpenAI +from zai import ZhipuAiClient from loguru import logger @@ -36,8 +36,8 @@ class LLMAnalyzer: def __init__(self, config: Dict): self.config = config - self.base_url = config.get('base_url', '') self.api_key = config.get('api_key', '') + self.base_url = config.get('base_url', '') self.model = config.get('model', '') self.timeout = config.get('timeout', 120) self.retry_times = config.get('retry_times', 3) @@ -47,34 +47,33 @@ class LLMAnalyzer: logger.info(f"LLM分析器配置 - base_url: {self.base_url}, model: {self.model}, timeout: {self.timeout}s, retry: {self.retry_times}次") - if self.base_url and self.api_key: + if self.api_key: self._init_client() else: - logger.warning("LLM API 未配置,base_url 或 api_key 为空") + logger.warning("LLM API 未配置,api_key 为空") def _init_client(self): - """初始化OpenAI客户端""" + """初始化智谱AI客户端""" try: - logger.info(f"初始化LLM客户端: {self.base_url}") - self.client = OpenAI( + logger.info(f"初始化智谱AI客户端: {self.base_url}") + self.client = ZhipuAiClient( api_key=self.api_key, - base_url=self.base_url, - timeout=self.timeout + base_url=self.base_url ) - logger.info("LLM客户端初始化成功") + logger.info("智谱AI客户端初始化成功") except Exception as e: - logger.error(f"初始化LLM客户端失败: {e}") + logger.error(f"初始化智谱AI客户端失败: {e}") def update_config(self, config: Dict): """更新配置""" self.config.update(config) - self.base_url = config.get('base_url', self.base_url) self.api_key = config.get('api_key', self.api_key) + self.base_url = config.get('base_url', self.base_url) self.model = config.get('model', self.model) self.timeout = config.get('timeout', self.timeout) self.retry_times = config.get('retry_times', self.retry_times) - if self.base_url and self.api_key: + if self.api_key: self._init_client() def analyze(self, comment: str) -> Tuple[Optional[int], Optional[str]]: @@ -97,16 +96,18 @@ class LLMAnalyzer: try: logger.info(f"API调用尝试 {attempt + 1}/{self.retry_times}") - logger.debug(f"发送请求到 {self.base_url}") + logger.debug("发送请求到智谱AI API") response = self.client.chat.completions.create( - model=self.model, + model="glm-4.7-flash", messages=[ {"role": "system", "content": self.SYSTEM_PROMPT}, {"role": "user", "content": f"请分析以下评论的情感倾向:\n\n{comment}"} ], + thinking={ + "type": "disabled", # 禁用深度思考模式 + }, temperature=0.3, - max_tokens=500, - timeout=self.timeout + max_tokens=500 ) # 处理 deepseek-r1 的特殊结构(可能有 reasoning_content) diff --git a/main.exe b/main.exe deleted file mode 100644 index 537a664..0000000 Binary files a/main.exe and /dev/null differ diff --git a/screenshot_manager.py b/screenshot_manager.py index cff90d0..4831c5b 100644 --- a/screenshot_manager.py +++ b/screenshot_manager.py @@ -2,22 +2,35 @@ 截图管理器 - 用于在非交易时间截取上海证券交易所网站图表 """ import os +import sys from datetime import datetime from loguru import logger -try: - from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError -except ImportError: - logger.warning("playwright未安装,截图功能将不可用") +# 使用Selenium替代Playwright +from selenium import webdriver +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.common.by import By +from selenium.webdriver.support.ui import WebDriverWait +from selenium.webdriver.support import expected_conditions as EC class ScreenshotManager: """截图管理器""" def __init__(self, screenshot_dir: str = "screenshots"): - """初始化截图管理器""" - self.screenshot_dir = screenshot_dir + """ + 初始化截图管理器 + """ + # 确定截图目录的正确路径 + if getattr(sys, 'frozen', False): + # 打包后的环境 + current_dir = os.path.dirname(sys.executable) + self.screenshot_dir = os.path.join(current_dir, screenshot_dir) + else: + # 开发环境 + self.screenshot_dir = screenshot_dir + self.target_url = "https://www.sse.com.cn/" self.chart_xpath_pattern = "//*[@id=\"hq_area\"]" @@ -32,102 +45,102 @@ class ScreenshotManager: 返回截图文件路径,失败时返回空字符串 """ try: - # 检查playwright是否可用 - if 'sync_playwright' not in globals(): - logger.error("playwright未安装,无法使用截图功能") - return "" + # 配置Chrome选项 + chrome_options = Options() + chrome_options.add_argument('--headless') + chrome_options.add_argument('--no-sandbox') + chrome_options.add_argument('--disable-dev-shm-usage') + chrome_options.add_argument('--disable-gpu') + chrome_options.add_argument('--window-size=1920,1080') + chrome_options.add_argument('user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36') - with sync_playwright() as p: - # 启动浏览器(无头模式,后台运行) - browser = p.chromium.launch(headless=True) - page = browser.new_page() - - # 设置页面超时 - page.set_default_timeout(60000) # 60秒超时 - - # 访问目标网页 - logger.info(f"访问上海证券交易所网站: {self.target_url}") - page.goto(self.target_url, wait_until="domcontentloaded") - - # 等待页面加载完成 - page.wait_for_load_state("networkidle") - - # 等待页面完全加载 - page.wait_for_timeout(5000) # 额外等待5秒 - - # 等待图表元素出现 - logger.info("等待图表元素加载...") - - # 使用动态XPath模式查找图表元素 - chart_element = None - selectors = [ - self.chart_xpath_pattern, # 主要选择器:highcharts-xxxxxxx-0格式 - "//*[contains(@class, 'highcharts')]", - "//*[contains(@id, 'highcharts')]", - "//svg", - "//canvas", - "//div[contains(@class, 'chart')]", - "//div[contains(@class, 'graph')]" - ] - - for selector in selectors: - try: - # 等待选择器出现 - page.wait_for_selector(selector, timeout=10000) - elements = page.query_selector_all(selector) - if elements: - # 选择第一个可见的元素 - for element in elements: - if element.is_visible(): - chart_element = element - logger.info(f"找到图表元素: {selector}") - break - if chart_element: - break - except PlaywrightTimeoutError: - logger.debug(f"选择器超时: {selector}") - continue - except Exception as e: - logger.debug(f"选择器错误 {selector}: {e}") - continue - - if not chart_element: - logger.warning("未找到任何图表元素,尝试截取整个页面") - # 如果找不到图表元素,截取整个页面 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png") - page.screenshot(path=screenshot_path) - logger.info(f"截取整个页面: {screenshot_path}") - browser.close() - return screenshot_path - - # 检查元素是否可见 - if not chart_element.is_visible(): - logger.warning("图表元素不可见,尝试滚动到元素位置") - chart_element.scroll_into_view_if_needed() - - # 生成截图文件名 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png") - - # 截取图表元素 - logger.info("开始截取图表元素") - - # 直接使用元素进行截图 - chart_element.screenshot(path=screenshot_path) - - logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}") - - # 关闭浏览器 - browser.close() - + # 创建WebDriver + logger.info("初始化Chrome WebDriver...") + driver = webdriver.Chrome(options=chrome_options) + driver.set_page_load_timeout(60) + driver.implicitly_wait(30) + + # 访问目标网页 + logger.info(f"访问上海证券交易所网站: {self.target_url}") + driver.get(self.target_url) + + # 等待页面加载完成 + logger.info("等待页面加载完成...") + WebDriverWait(driver, 30).until( + EC.presence_of_element_located((By.XPATH, "//body")) + ) + # 额外等待确保数据加载完成 + import time + time.sleep(5) + + # 等待图表元素出现 + logger.info("等待图表元素加载...") + + # 使用动态XPath模式查找图表元素 + chart_element = None + selectors = [ + self.chart_xpath_pattern, # 主要选择器:highcharts-xxxxxxx-0格式 + "//*[contains(@class, 'highcharts')]", + "//*[contains(@id, 'highcharts')]", + "//svg", + "//canvas", + "//div[contains(@class, 'chart')]", + "//div[contains(@class, 'graph')]" + ] + + for selector in selectors: + try: + # 等待选择器出现 + element = WebDriverWait(driver, 10).until( + EC.presence_of_element_located((By.XPATH, selector)) + ) + # 检查元素是否可见 + if element.is_displayed(): + chart_element = element + logger.info(f"找到图表元素: {selector}") + break + except Exception as e: + logger.debug(f"选择器错误 {selector}: {e}") + continue + + # 生成截图文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if not chart_element: + logger.warning("未找到任何图表元素,尝试截取整个页面") + # 如果找不到图表元素,截取整个页面 + screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png") + driver.save_screenshot(screenshot_path) + logger.info(f"截取整个页面: {screenshot_path}") + driver.quit() return screenshot_path - - except PlaywrightTimeoutError as e: - logger.error(f"页面加载超时: {e}") - return "" + + # 检查元素是否可见 + if not chart_element.is_displayed(): + logger.warning("图表元素不可见,尝试滚动到元素位置") + driver.execute_script("arguments[0].scrollIntoView({behavior: 'smooth', block: 'center', inline: 'center'});", chart_element) + time.sleep(2) + + # 生成截图文件名 + screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png") + + # 截取图表元素 + logger.info("开始截取图表元素") + + # 直接使用元素进行截图 + chart_element.screenshot(screenshot_path) + + logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}") + + # 关闭浏览器 + driver.quit() + + return screenshot_path + except Exception as e: logger.error(f"截图过程中发生错误: {e}") + import traceback + traceback.print_exc() return "" def get_latest_screenshot(self) -> str: diff --git a/sse_screenshot.png b/sse_screenshot.png index 2183f16..ddd8a60 100644 Binary files a/sse_screenshot.png and b/sse_screenshot.png differ