diff --git a/api.txt b/api.txt new file mode 100644 index 0000000..e69de29 diff --git a/build.bat b/build.bat new file mode 100644 index 0000000..938a814 --- /dev/null +++ b/build.bat @@ -0,0 +1,25 @@ +@echo off +echo ======================================== +echo 股吧人气指示器 - 打包工具 +echo ======================================== + +REM 检查 pyinstaller 是否安装 +pip show pyinstaller >nul 2>&1 +if errorlevel 1 ( + echo 正在安装 pyinstaller... + pip install pyinstaller +) + +echo 开始打包... +pyinstaller build.spec + +if exist "dist\guba-indicator.exe" ( + echo ======================================== + echo 打包成功! + echo 可执行文件位置: dist\guba-indicator.exe + echo ======================================== +) else ( + echo 打包失败,请检查错误信息 +) + +pause diff --git a/build.spec b/build.spec new file mode 100644 index 0000000..f4dabed --- /dev/null +++ b/build.spec @@ -0,0 +1,58 @@ +# -*- mode: python ; coding: utf-8 -*- +import sys +import os + +block_cipher = None + +# 添加项目路径 +project_path = os.path.abspath('.') +if project_path not in sys.path: + sys.path.insert(0, project_path) + +a = Analysis( + ['main.py'], + pathex=['.', project_path], + binaries=[], + datas=[], + hiddenimports=[], + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=['tkinter', 'IPython', 'pytest', 'matplotlib', 'pandas', 'scipy', 'numpy'], + win_no_prefer_redirects=False, + win_private_assemblies=False, + cipher=block_cipher, + noarchive=False, +) + +pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='guba-indicator', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, + icon='indicator.ico', +) + +coll = COLLECT( + exe, + a.binaries, + a.zipfiles, + a.datas, + strip=False, + upx=True, + upx_exclude=[], + name='guba-indicator', +) \ No newline at end of file diff --git a/config_manager.py b/config_manager.py index 90fc92b..ecbe9bc 100644 --- a/config_manager.py +++ b/config_manager.py @@ -5,6 +5,7 @@ import json import os from typing import Any, Dict from pathlib import Path +from loguru import logger class ConfigManager: @@ -12,10 +13,10 @@ class ConfigManager: DEFAULT_CONFIG = { "llm_api": { - "base_url": "https://api.openai.com/v1", + "base_url": "https://integrate.api.nvidia.com/v1", "api_key": "", - "model": "gpt-3.5-turbo", - "timeout": 30, + "model": "deepseek-ai/deepseek-r1", + "timeout": 120, "retry_times": 3 }, "spider": { @@ -46,19 +47,24 @@ class ConfigManager: def __init__(self, config_path: str = "config.json"): self.config_path = Path(config_path) self.config = self._load_config() + logger.info(f"配置管理器初始化完成,配置文件: {config_path}") def _load_config(self) -> Dict[str, Any]: """加载配置文件""" if self.config_path.exists(): try: + logger.info(f"从文件加载配置: {self.config_path}") with open(self.config_path, 'r', encoding='utf-8') as f: loaded_config = json.load(f) # 合并默认配置,确保所有键都存在 - return self._merge_config(self.DEFAULT_CONFIG, loaded_config) + merged = self._merge_config(self.DEFAULT_CONFIG, loaded_config) + logger.info("配置加载成功") + return merged except (json.JSONDecodeError, IOError) as e: - print(f"配置文件加载失败,使用默认配置: {e}") + logger.error(f"配置文件加载失败,使用默认配置: {e}") return self.DEFAULT_CONFIG.copy() else: + logger.warning(f"配置文件不存在: {self.config_path},使用默认配置") return self.DEFAULT_CONFIG.copy() def _merge_config(self, default: Dict, loaded: Dict) -> Dict: @@ -74,11 +80,13 @@ class ConfigManager: def save_config(self) -> bool: """保存配置到文件""" try: + logger.debug(f"保存配置到文件: {self.config_path}") with open(self.config_path, 'w', encoding='utf-8') as f: json.dump(self.config, f, ensure_ascii=False, indent=4) + logger.info("配置保存成功") return True except IOError as e: - print(f"配置保存失败: {e}") + logger.error(f"配置保存失败: {e}") return False def get(self, *keys: str, default: Any = None) -> Any: @@ -118,6 +126,7 @@ class ConfigManager: self.config["llm_api"]["timeout"] = timeout if retry_times: self.config["llm_api"]["retry_times"] = retry_times + logger.info("LLM API配置已更新") self.save_config() def update_spider(self, target_url: str = None, xpath: str = None, @@ -136,6 +145,7 @@ class ConfigManager: self.config["spider"]["retry_times"] = retry_times if retry_interval: self.config["spider"]["retry_interval"] = retry_interval + logger.info("爬虫配置已更新") self.save_config() def update_ui(self, opacity: float = None, is_on_top: bool = None, @@ -149,6 +159,7 @@ class ConfigManager: self.config["ui"]["thresholds"]["cold"] = cold_threshold if warm_threshold is not None: self.config["ui"]["thresholds"]["warm"] = warm_threshold + logger.info("UI配置已更新") self.save_config() @property diff --git a/database.py b/database.py index ad7f54e..d329ddb 100644 --- a/database.py +++ b/database.py @@ -7,6 +7,7 @@ import json from datetime import datetime from typing import List, Dict, Optional, Tuple from pathlib import Path +from loguru import logger class DatabaseManager: @@ -15,9 +16,11 @@ class DatabaseManager: def __init__(self, db_path: str = "guba.db"): self.db_path = Path(db_path) self._init_db() + logger.info(f"数据库管理器初始化完成,数据库路径: {db_path}") def _init_db(self): """初始化数据库表""" + logger.debug("初始化数据库表") conn = self._get_connection() cursor = conn.cursor() @@ -59,6 +62,7 @@ class DatabaseManager: conn.commit() conn.close() + logger.debug("数据库表初始化完成") def _get_connection(self) -> sqlite3.Connection: """获取数据库连接""" @@ -83,6 +87,7 @@ class DatabaseManager: content_hash = self.hash_content(content) if self.is_comment_exists(content_hash): + logger.debug(f"评论已存在,跳过: {content[:30]}...") return None # 已存在 conn = self._get_connection() @@ -94,6 +99,7 @@ class DatabaseManager: comment_id = cursor.lastrowid conn.commit() conn.close() + logger.info(f"添加新评论,ID: {comment_id}") return comment_id def add_comments_batch(self, comments: List[Dict]) -> List[int]: @@ -108,16 +114,23 @@ class DatabaseManager: content_hash = self.hash_content(content) if self.is_comment_exists(content_hash): + logger.debug(f"评论已存在,跳过: {content[:30]}...") continue - cursor.execute(''' - INSERT INTO comments (content, content_hash, url, created_at) - VALUES (?, ?, ?, ?) - ''', (content, content_hash, url, datetime.now().isoformat())) - new_ids.append(cursor.lastrowid) + try: + cursor.execute(''' + INSERT OR IGNORE INTO comments (content, content_hash, url, created_at) + VALUES (?, ?, ?, ?) + ''', (content, content_hash, url, datetime.now().isoformat())) + if cursor.rowcount > 0: + new_ids.append(cursor.lastrowid) + except Exception as e: + logger.warning(f"插入评论失败(可能已存在): {e}") + continue conn.commit() conn.close() + logger.info(f"批量添加评论完成,新增 {len(new_ids)} 条") return new_ids def get_unanalyzed_comments(self, limit: int = 50) -> List[Dict]: @@ -132,7 +145,9 @@ class DatabaseManager: ''', (limit,)) rows = cursor.fetchall() conn.close() - return [{'id': row[0], 'content': row[1], 'url': row[2]} for row in rows] + result = [{'id': row[0], 'content': row[1], 'url': row[2]} for row in rows] + logger.debug(f"获取到 {len(result)} 条未分析评论") + return result def mark_analyzed(self, comment_id: int, sentiment_score: float, analysis_text: str): """标记评论已分析""" @@ -154,6 +169,7 @@ class DatabaseManager: conn.commit() conn.close() + logger.debug(f"标记评论 {comment_id} 已分析,分数: {sentiment_score}") def get_latest_sentiment_score(self) -> Optional[float]: """获取最新的情感分数""" @@ -167,20 +183,34 @@ class DatabaseManager: ''') row = cursor.fetchone() conn.close() - return row[0] if row else None + score = row[0] if row else None + logger.debug(f"最新情感分数: {score}") + return score - def get_all_scores(self) -> List[float]: - """获取所有已分析的分数""" + def get_all_scores(self, limit: int = None) -> List[float]: + """获取已分析的分数,可指定数量限制""" conn = self._get_connection() cursor = conn.cursor() - cursor.execute(''' - SELECT sentiment_score FROM comments - WHERE analyzed = 1 AND sentiment_score IS NOT NULL - ORDER BY analyzed_at DESC - ''') + + if limit: + cursor.execute(''' + SELECT sentiment_score FROM comments + WHERE analyzed = 1 AND sentiment_score IS NOT NULL + ORDER BY analyzed_at DESC + LIMIT ? + ''', (limit,)) + else: + cursor.execute(''' + SELECT sentiment_score FROM comments + WHERE analyzed = 1 AND sentiment_score IS NOT NULL + ORDER BY analyzed_at DESC + ''') + rows = cursor.fetchall() conn.close() - return [row[0] for row in rows if row[0] is not None] + scores = [row[0] for row in rows if row[0] is not None] + logger.debug(f"获取到 {len(scores)} 个分数") + return scores def get_comment_count(self) -> int: """获取评论总数""" @@ -189,6 +219,7 @@ class DatabaseManager: cursor.execute('SELECT COUNT(*) FROM comments') count = cursor.fetchone()[0] conn.close() + logger.debug(f"评论总数: {count}") return count def get_analyzed_count(self) -> int: @@ -198,6 +229,7 @@ class DatabaseManager: cursor.execute('SELECT COUNT(*) FROM comments WHERE analyzed = 1') count = cursor.fetchone()[0] conn.close() + logger.debug(f"已分析评论数: {count}") return count def get_recent_comments(self, limit: int = 10) -> List[Dict]: @@ -212,8 +244,10 @@ class DatabaseManager: ''', (limit,)) rows = cursor.fetchall() conn.close() - return [ + result = [ {'id': row[0], 'content': row[1][:50] + '...' if len(row[1]) > 50 else row[1], 'score': row[2], 'analyzed_at': row[3]} for row in rows ] + logger.debug(f"获取到 {len(result)} 条最近评论") + return result diff --git a/indicator.ico b/indicator.ico new file mode 100644 index 0000000..e79ca9e Binary files /dev/null and b/indicator.ico differ diff --git a/llm_analyzer.py b/llm_analyzer.py index 365e8e0..132f539 100644 --- a/llm_analyzer.py +++ b/llm_analyzer.py @@ -1,11 +1,13 @@ """ 大模型分析模块 - 调用LLM API分析评论情感 +支持 OpenAI 兼容 API,包括 NVIDIA API """ import json import time import re -from typing import Dict, Optional, Tuple -from openai import OpenAI, OpenAIError +from typing import Dict, Optional, Tuple, Any +from openai import OpenAI +from loguru import logger class LLMAnalyzer: @@ -31,26 +33,34 @@ class LLMAnalyzer: def __init__(self, config: Dict): self.config = config - self.base_url = config.get('base_url', 'https://api.openai.com/v1') + self.base_url = config.get('base_url', '') self.api_key = config.get('api_key', '') - self.model = config.get('model', 'gpt-3.5-turbo') - self.timeout = config.get('timeout', 30) + self.model = config.get('model', '') + self.timeout = config.get('timeout', 120) self.retry_times = config.get('retry_times', 3) self.client = None - if self.api_key: + self.last_result = None # 保存最后一次分析结果 + + logger.info(f"LLM分析器配置 - base_url: {self.base_url}, model: {self.model}, timeout: {self.timeout}s, retry: {self.retry_times}次") + + if self.base_url and self.api_key: self._init_client() + else: + logger.warning("LLM API 未配置,base_url 或 api_key 为空") def _init_client(self): """初始化OpenAI客户端""" try: + logger.info(f"初始化LLM客户端: {self.base_url}") self.client = OpenAI( api_key=self.api_key, base_url=self.base_url, timeout=self.timeout ) + logger.info("LLM客户端初始化成功") except Exception as e: - print(f"初始化LLM客户端失败: {e}") + logger.error(f"初始化LLM客户端失败: {e}") def update_config(self, config: Dict): """更新配置""" @@ -61,7 +71,7 @@ class LLMAnalyzer: self.timeout = config.get('timeout', self.timeout) self.retry_times = config.get('retry_times', self.retry_times) - if self.api_key: + if self.base_url and self.api_key: self._init_client() def analyze(self, comment: str) -> Tuple[Optional[int], Optional[str]]: @@ -70,13 +80,21 @@ class LLMAnalyzer: 返回 (score, label) """ if not self.client: + logger.error("LLM客户端未初始化,请检查API配置") return None, "LLM未配置" if not comment or not comment.strip(): + logger.warning("评论内容为空") return None, "评论为空" + logger.debug(f"开始分析评论: {comment[:50]}...") + logger.debug(f"使用模型: {self.model}, 超时设置: {self.timeout}秒") + for attempt in range(self.retry_times): try: + logger.info(f"API调用尝试 {attempt + 1}/{self.retry_times}") + + logger.debug(f"发送请求到 {self.base_url}") response = self.client.chat.completions.create( model=self.model, messages=[ @@ -84,23 +102,45 @@ class LLMAnalyzer: {"role": "user", "content": f"请分析以下评论的情感倾向:\n\n{comment}"} ], temperature=0.3, - max_tokens=200 + max_tokens=500, + timeout=self.timeout ) - result_text = response.choices[0].message.content.strip() + # 处理 deepseek-r1 的特殊结构(可能有 reasoning_content) + message = response.choices[0].message + + # 获取推理过程(如果有) + reasoning = getattr(message, 'reasoning_content', None) + if reasoning: + logger.debug(f"推理过程: {reasoning[:100]}...") + + # 获取最终回答 + result_text = message.content.strip() if message.content else "" + logger.debug(f"API返回原始内容: {result_text[:100]}...") + score, label = self._parse_response(result_text) + # 保存最后结果 + self.last_result = { + 'score': score, + 'label': label, + 'reasoning': reasoning, + 'raw_response': result_text + } + if score is not None: + logger.info(f"分析完成: {score}分 - {label}") return score, label - except OpenAIError as e: - print(f"API调用失败 (尝试 {attempt + 1}/{self.retry_times}): {e}") - if attempt < self.retry_times - 1: - time.sleep(2 ** attempt) # 指数退避 except Exception as e: - print(f"分析过程出错: {e}") - break + logger.warning(f"API调用失败 (尝试 {attempt + 1}/{self.retry_times}): {type(e).__name__}: {e}") + logger.debug(f"错误详情: {str(e)}") + if attempt < self.retry_times - 1: + wait_time = 2 ** attempt + logger.info(f"等待 {wait_time} 秒后重试...") + time.sleep(wait_time) # 指数退避 + logger.error(f"所有 {self.retry_times} 次重试均失败") return None, "分析失败" def _parse_response(self, response: str) -> Tuple[Optional[int], Optional[str]]: @@ -113,40 +153,57 @@ class LLMAnalyzer: # 验证分数范围 score = max(0, min(100, int(score))) + logger.debug(f"JSON解析成功: {score} - {label}") return score, label except json.JSONDecodeError: - # 尝试从文本中提取 - pass + logger.debug("JSON解析失败,尝试文本提取") - # 尝试从文本中提取数字 + # 尝试从文本中提取 numbers = re.findall(r'\b(\d{1,3})\b', response) if numbers: score = int(numbers[0]) score = max(0, min(100, score)) # 提取标签 - label_match = re.search(r'["']([^"']+)["']', response) + label_match = re.search(r'"([^"]+)"', response) if label_match: label = label_match.group(1) else: label = response.split('\n')[0][:20] if response else '无法判断' + logger.debug(f"文本提取成功: {score} - {label}") return score, label + logger.warning("无法解析响应") return None, "解析失败" + def get_last_result(self) -> Optional[Dict[str, Any]]: + """获取最后一次分析结果""" + return self.last_result + def analyze_batch(self, comments: list, delay: float = 1.0) -> list: """ 批量分析评论 delay: 每次调用之间的延迟(秒) """ + logger.info(f"开始批量分析 {len(comments)} 条评论,每次间隔 {delay} 秒") results = [] + success_count = 0 + fail_count = 0 for i, comment in enumerate(comments): - print(f"分析评论 {i + 1}/{len(comments)}...") + logger.info(f"正在分析第 {i + 1}/{len(comments)} 条评论") score, label = self.analyze(comment) + + if score is not None: + success_count += 1 + logger.debug(f"第 {i + 1} 条评论分析成功: {score}分 - {label}") + else: + fail_count += 1 + logger.warning(f"第 {i + 1} 条评论分析失败: {label}") + results.append({ 'content': comment, 'score': score, @@ -154,10 +211,12 @@ class LLMAnalyzer: }) if delay > 0 and i < len(comments) - 1: + logger.debug(f"等待 {delay} 秒后继续...") time.sleep(delay) + logger.info(f"批量分析完成,成功 {success_count} 条,失败 {fail_count} 条") return results def is_configured(self) -> bool: """检查是否已配置""" - return bool(self.client and self.api_key) \ No newline at end of file + return bool(self.client and self.api_key) diff --git a/main.py b/main.py index 6b26198..e18e7e5 100644 --- a/main.py +++ b/main.py @@ -2,11 +2,11 @@ 主程序入口 - 股吧人气指示器 """ import sys -import logging import time from datetime import datetime from PySide6.QtWidgets import QApplication -from PySide6.QtCore import QTimer, Signal, QObject +from PySide6.QtCore import QTimer, Signal, QObject, QThread +from loguru import logger from config_manager import ConfigManager from database import DatabaseManager @@ -20,8 +20,10 @@ class BackendWorker(QObject): fetch_finished = Signal(list) analysis_finished = Signal(float) + analysis_result = Signal(dict) # 传递分析结果详情 error_occurred = Signal(str) status_update = Signal(str) + stock_data_fetched = Signal(str, float) # 股票数据获取信号 def __init__(self, config_manager: ConfigManager, db_manager: DatabaseManager, spider: SpiderManager, analyzer: LLMAnalyzer): @@ -32,92 +34,136 @@ class BackendWorker(QObject): self.analyzer = analyzer self.running = False self.last_fetch_time = 0 - self.fetch_interval = 60 # 默认60秒 + self.fetch_interval = 15 # 默认15秒 self.no_new_content_count = 0 # 无新内容计数 + self.is_running_cycle = False # 防止并发执行 + self.fetch_count = 0 # 爬取次数统计 + self.analysis_count = 0 # API分析次数统计 + logger.info("BackendWorker 初始化完成") def start(self): """启动后台任务""" self.running = True - self._run_cycle() + logger.info("后台任务已启动") + # 启动时立即执行第一次任务 + QTimer.singleShot(1000, self._run_cycle) def stop(self): """停止后台任务""" self.running = False + logger.info("后台任务已停止") def _run_cycle(self): """运行一个周期""" + if self.is_running_cycle: + logger.debug("上一个周期仍在执行,跳过本次") + return + + self.is_running_cycle = True + if not self.running: + self.is_running_cycle = False return try: # 1. 爬取评论 + logger.info("开始爬取评论") self.status_update.emit("正在爬取评论...") + self.fetch_count += 1 comments = self.spider.fetch() if not comments: - self.no_new_content_count += 1 - interval = self.fetch_interval * (1 + min(self.no_new_content_count * 0.5, 2)) + # 爬取失败,使用固定间隔重试,不计入无新内容计数 + interval = 15 # 固定15秒重试 + logger.warning(f"未获取到新评论,{int(interval)}秒后重试") self.status_update.emit(f"无新内容,{int(interval)}秒后重试...") + self.is_running_cycle = False QTimer.singleShot(int(interval * 1000), self._run_cycle) return # 2. 写入数据库 + logger.info(f"获取到 {len(comments)} 条评论") self.status_update.emit(f"获取到 {len(comments)} 条评论...") new_ids = self.db.add_comments_batch(comments) if new_ids: self.no_new_content_count = 0 + logger.info(f"新增 {len(new_ids)} 条评论到数据库") self.status_update.emit(f"新增 {len(new_ids)} 条评论") # 3. 获取未分析评论并分析 unanalyzed = self.db.get_unanalyzed_comments(limit=10) + logger.info(f"获取到 {len(unanalyzed)} 条未分析评论") if unanalyzed: self.status_update.emit(f"开始分析 {len(unanalyzed)} 条评论...") self._analyze_comments(unanalyzed) else: self.no_new_content_count += 1 + logger.info("评论已存在,未新增") # 4. 更新指示器 self._update_indicator() except Exception as e: + logger.error(f"运行错误: {str(e)}") self.error_occurred.emit(f"运行错误: {str(e)}") # 安排下一次执行 if self.running: - interval = self.fetch_interval * (1 + min(self.no_new_content_count * 0.5, 2)) + interval = self.fetch_interval * (1 + min(self.no_new_content_count * 1.0, 4.0)) + logger.debug(f"下次执行将在 {int(interval)} 秒后") + self.is_running_cycle = False QTimer.singleShot(int(interval * 1000), self._run_cycle) + else: + self.is_running_cycle = False def _analyze_comments(self, comments): """分析评论""" + logger.info(f"开始分析 {len(comments)} 条评论") for i, comment in enumerate(comments): if not self.running: + logger.warning("分析被中断") break try: self.status_update.emit(f"分析 {i+1}/{len(comments)}...") + logger.debug(f"分析第 {i+1} 条评论: {comment['content'][:50]}...") + self.analysis_count += 1 score, label = self.analyzer.analyze(comment['content']) + # 获取分析结果详情 + last_result = self.analyzer.get_last_result() + if score is not None: self.db.mark_analyzed(comment['id'], score, label) + logger.info(f"评论 {comment['id']} 分析完成: {score}分 - {label}") + # 更新状态显示为简洁格式 + self.status_update.emit(f"分析 {i+1}/{len(comments)}...返回{score}分") + # 每条评论分析完成后立即更新指示器 + self._update_indicator() time.sleep(1.0) # 延迟,避免API限流 else: self.db.mark_analyzed(comment['id'], 50, "无法判断") + logger.warning(f"评论 {comment['id']} 无法判断") except Exception as e: + logger.error(f"分析评论 {comment.get('id', 'unknown')} 失败: {str(e)}") self.error_occurred.emit(f"分析失败: {str(e)}") self.db.mark_analyzed(comment['id'], 50, "分析异常") def _update_indicator(self): """更新指示器显示""" - scores = self.db.get_all_scores() + # 获取最新的100条评论的分数 + scores = self.db.get_all_scores(limit=100) if not scores: + logger.debug("暂无分析分数") return # 计算平均分 avg_score = sum(scores) / len(scores) + logger.info(f"当前平均分: {avg_score:.2f} (基于最新的 {len(scores)} 条评论)") # 根据阈值确定标签 thresholds = self.config.get('ui', 'thresholds', default={'cold': 30, 'warm': 70}) @@ -131,68 +177,161 @@ class BackendWorker(QObject): else: label = "中性" - self.analysis_finished.emit(avg_score) + logger.info(f"情感倾向: {label}") + # 发送整数分数 + self.analysis_finished.emit(int(avg_score)) + + def fetch_stock_data(self): + """爬取股票数据并添加到波形图""" + try: + logger.info("开始爬取股票数据") + stock_data = self.spider.fetch_sse_stock_data() + + if stock_data and 'time' in stock_data and 'value' in stock_data: + logger.info(f"成功获取股票数据: {stock_data}") + # 发送股票数据到主窗口 + self.stock_data_fetched.emit(stock_data['time'], stock_data['value']) + else: + logger.warning("未能获取有效的股票数据") + + except Exception as e: + logger.error(f"爬取股票数据失败: {str(e)}") def manual_refresh(self): """手动刷新""" + logger.info("用户手动刷新") self.no_new_content_count = 0 - self._run_cycle() + if not self.is_running_cycle: + self._run_cycle() + else: + logger.info("上一个周期仍在执行,跳过手动刷新") def update_fetch_interval(self, interval: int): """更新爬取间隔""" + logger.info(f"更新爬取间隔: {interval} 秒") self.fetch_interval = interval def setup_logging(log_path: str, level: str = "INFO"): """配置日志""" - logging.basicConfig( - level=getattr(logging, level.upper(), logging.INFO), - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[ - logging.FileHandler(log_path, encoding='utf-8'), - logging.StreamHandler() - ] + logger.remove() # 移除默认的处理器 + + logger.add( + log_path, + rotation="10 MB", + retention="7 days", + level=level, + encoding="utf-8", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}" + ) + + logger.add( + sys.stdout, + level=level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}" ) def main(): """主函数""" + logger.info("=== 股吧人气指示器启动 ===") + # 创建应用 app = QApplication(sys.argv) app.setQuitOnLastWindowClosed(False) # 允许最小化到托盘 # 加载配置 + logger.info("加载配置文件...") config = ConfigManager("config.json") # 配置日志 log_config = config.logging_config setup_logging(log_config.get('path', 'guba.log'), log_config.get('level', 'INFO')) + logger.info(f"日志配置完成: {log_config.get('path', 'guba.log')}, 级别: {log_config.get('level', 'INFO')}") # 初始化组件 + logger.info("初始化数据库...") db = DatabaseManager(config.database_config.get('path', 'guba.db')) + + logger.info("初始化爬虫...") spider = SpiderManager(config.spider_config) + + logger.info("初始化LLM分析器...") analyzer = LLMAnalyzer(config.llm_api_config) # 创建后台工作器 worker = BackendWorker(config, db, spider, analyzer) - worker.update_fetch_interval(config.spider_config.get('fetch_interval', 60)) + worker.update_fetch_interval(config.spider_config.get('fetch_interval', 15)) + + # 创建后台线程 + logger.info("创建后台线程...") + worker_thread = QThread() + worker.moveToThread(worker_thread) + + # 线程启动时开始工作 + worker_thread.started.connect(worker.start) + # 线程结束时停止工作 + worker_thread.finished.connect(worker.stop) # 创建主窗口 - window = MainWindow(config) + logger.info("创建主窗口...") + window = MainWindow(config, spider) window.show() # 连接信号 worker.status_update.connect(window.update_status) worker.analysis_finished.connect(window.update_indicator) worker.error_occurred.connect(lambda msg: window.show_message("错误", msg)) + worker.stock_data_fetched.connect(window.add_waveform_data) + + # 启动时从数据库初始化指示器显示 + worker._update_indicator() + logger.info("初始化指示器显示完成") # 设置按钮回调 window.set_refresh_callback(worker.manual_refresh) window.set_config_callback(window.show_config) - # 启动后台任务 - worker.start() + # 启动后台线程 + worker_thread.start() + logger.info("后台线程已启动") + + # 启动股票数据爬取定时器 + stock_timer = QTimer() + stock_timer.timeout.connect(worker.fetch_stock_data) + stock_timer.start(60000) # 每分钟爬取一次股票数据 + logger.info("股票数据爬取定时器已启动,间隔60秒") + # 确保应用退出时清理线程 + def cleanup(): + logger.info("清理资源,停止后台线程...") + worker.stop() + worker_thread.quit() + worker_thread.wait() + + # 显示统计信息 + logger.info("=== 程序运行统计 ===") + logger.info(f"爬取网站次数: {worker.fetch_count}") + logger.info(f"提交API分析次数: {worker.analysis_count}") + logger.info("=== 统计结束 ===") + + # 写入统计信息到文件 + try: + stats_file = "statistics.txt" + with open(stats_file, "a", encoding="utf-8") as f: + f.write(f"=== 程序运行统计 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n") + f.write(f"爬取网站次数: {worker.fetch_count}\n") + f.write(f"提交API分析次数: {worker.analysis_count}\n") + f.write("=== 统计结束 ===\n\n") + logger.info(f"统计信息已写入到文件: {stats_file}") + except Exception as e: + logger.error(f"写入统计文件失败: {str(e)}") + + logger.info("后台线程已停止") + + app.aboutToQuit.connect(cleanup) + + logger.info("应用启动完成,进入主循环") # 运行应用 sys.exit(app.exec()) diff --git a/main_window.py b/main_window.py index dd679a2..384cbfe 100644 --- a/main_window.py +++ b/main_window.py @@ -4,10 +4,13 @@ PySide6 GUI界面模块 from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QSlider, QDialog, QFormLayout, QLineEdit, QSpinBox, QMessageBox, QSystemTrayIcon, - QMenu, QTextEdit, QGroupBox, QDialogButtonBox) + QMenu, QTextEdit, QGroupBox, QDialogButtonBox, QCheckBox) from PySide6.QtCore import Qt, QTimer, Signal, QPoint from PySide6.QtGui import QFont, QColor, QPainter, QBrush, QPen, QIcon, QAction from typing import Callable, Optional +from loguru import logger + +from waveform_widget import WaveformWidget class SentimentIndicator(QWidget): @@ -57,22 +60,38 @@ class SentimentIndicator(QWidget): def _get_color(self, score: int) -> QColor: """根据分数获取颜色""" - if score < 30: - # 冷色系 - 蓝色/青色 - ratio = score / 30 - return QColor(int(0 + 100 * ratio), int(150 + 50 * ratio), 255) - elif score < 70: - # 中性 - 灰色/绿色 - if score < 50: - ratio = (score - 30) / 20 - return QColor(int(100 + 50 * ratio), int(200 + 20 * ratio), int(200 - 50 * ratio)) - else: - ratio = (score - 50) / 20 - return QColor(int(150 + 50 * ratio), int(220 - 20 * ratio), int(150 - 50 * ratio)) + # 45-55分区间:每1分一个颜色(从绿色到黄色渐变) + if 45 <= score <= 55: + # 在45-55分之间,每1分一个颜色 + ratio = (score - 45) / 10 # 0到1之间的比例 + # 从绿色(0, 200, 0)渐变到黄色(255, 255, 0) + r = int(0 + 255 * ratio) + g = 200 if ratio < 0.5 else int(200 + 55 * (ratio - 0.5) * 2) + b = int(0 + 0 * ratio) + return QColor(r, g, b) + + # 45分以下:每5分一个颜色(从深蓝到绿色渐变) + elif score < 45: + # 将分数映射到0-8的区间(0-40分,每5分一个级别) + level = score // 5 + # 从深蓝色(0, 100, 255)渐变到绿色(0, 200, 0) + ratio = level / 8 # 0-8对应0-40分 + r = int(0 + 0 * ratio) + g = int(100 + 100 * ratio) + b = int(255 - 255 * ratio) + return QColor(r, g, b) + + # 55分以上:每5分一个颜色(从黄色到红色渐变) else: - # 暖色系 - 橙色/红色 - ratio = (score - 70) / 30 - return QColor(255, int(200 - 100 * ratio), int(50 + 50 * ratio)) + # 将分数映射到0-8的区间(60-100分,每5分一个级别) + level = (score - 60) // 5 + level = max(0, min(8, level)) # 限制在0-8范围内 + # 从黄色(255, 255, 0)渐变到红色(255, 0, 0) + ratio = level / 8 + r = 255 + g = int(255 - 255 * ratio) + b = 0 + return QColor(r, g, b) def get_description(self, score: int) -> str: """获取描述文本""" @@ -125,7 +144,7 @@ class ConfigDialog(QDialog): self.user_agent_edit = QLineEdit(spider_config.get('user_agent', '')) self.interval_spin = QSpinBox() self.interval_spin.setRange(10, 3600) - self.interval_spin.setValue(spider_config.get('fetch_interval', 60)) + self.interval_spin.setValue(spider_config.get('fetch_interval', 15)) layout.addRow("目标URL:", self.url_edit) layout.addRow("XPath:", self.xpath_edit) @@ -197,10 +216,14 @@ class ConfigDialog(QDialog): class MainWindow(QWidget): """主窗口""" - def __init__(self, config_manager, parent=None): + def __init__(self, config_manager, spider_manager=None, parent=None): super().__init__(parent) self.config_manager = config_manager - self.setWindowTitle("股吧人气指示器") + self.spider_manager = spider_manager + + # 获取页面标题并设置窗口标题 + self._set_window_title() + self._init_ui() self._apply_config() @@ -217,7 +240,7 @@ class MainWindow(QWidget): layout.setContentsMargins(10, 10, 10, 10) # 标题 - self.title_label = QLabel("股吧人气") + self.title_label = QLabel("上证指数sh000001") self.title_label.setAlignment(Qt.AlignCenter) title_font = QFont() title_font.setPointSize(14) @@ -236,26 +259,64 @@ class MainWindow(QWidget): status_font.setPointSize(10) self.status_label.setFont(status_font) + # 波形图组件 + self.waveform_widget = WaveformWidget() + self.waveform_widget.setMinimumHeight(200) + # 按钮 btn_layout = QHBoxLayout() self.refresh_btn = QPushButton("刷新") self.config_btn = QPushButton("配置") + self.quit_btn = QPushButton("退出") + self.quit_btn.clicked.connect(self.quit_app) btn_layout.addWidget(self.refresh_btn) btn_layout.addWidget(self.config_btn) + btn_layout.addWidget(self.quit_btn) # 添加到主布局 layout.addWidget(self.title_label) layout.addWidget(self.indicator) layout.addWidget(self.score_label) layout.addWidget(self.status_label) + layout.addWidget(self.waveform_widget) layout.addLayout(btn_layout) # 设置窗口标志(无边框、可拖拽) self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint) self.setAttribute(Qt.WA_TranslucentBackground) + def _set_window_title(self): + """设置窗口标题""" + logger.debug("设置窗口标题") + + # 尝试从爬虫获取页面标题 + if hasattr(self, 'spider_manager') and self.spider_manager: + try: + page_title = self.spider_manager.get_page_title() + if page_title: + # 从页面标题中提取股票名称 + import re + match = re.search(r'(上证指数sh\d+)', page_title) + if match: + stock_name = match.group(1) + window_title = f"冷暖值 - {stock_name}" + self.setWindowTitle(window_title) + + # 同时更新主标题标签(如果已初始化) + if hasattr(self, 'title_label'): + self.title_label.setText("上证指数sh000001") + logger.info(f"设置窗口标题: {window_title}") + return + except Exception as e: + logger.error(f"获取页面标题失败: {e}") + + # 如果获取失败,使用默认标题 + self.setWindowTitle("冷暖值 - 股吧人气") + logger.info("使用默认窗口标题") + def _init_tray_icon(self): """初始化系统托盘""" + logger.debug("初始化系统托盘") self.tray_icon = QSystemTrayIcon(self) self.tray_icon.setToolTip("股吧人气指示器") @@ -275,9 +336,11 @@ class MainWindow(QWidget): self.tray_icon.setContextMenu(tray_menu) self.tray_icon.show() + logger.info("系统托盘初始化完成") def quit_app(self): """退出应用""" + logger.info("退出应用") self.close() import sys sys.exit(0) @@ -322,7 +385,7 @@ class MainWindow(QWidget): context_menu.addAction(config_action) context_menu.addAction(quit_action) - context_menu.exec(event.globalPosition().toPoint()) + context_menu.exec(event.globalPos()) def show_config(self): """显示配置对话框""" @@ -336,23 +399,33 @@ class MainWindow(QWidget): label = self.indicator.get_description(score) self.indicator.set_value(score, label) self.score_label.setText(f"{score} - {label}") + logger.debug(f"更新指示灯: {score}分 - {label}") def update_status(self, text: str): """更新状态""" self.status_label.setText(text) + logger.debug(f"更新状态: {text}") def set_refresh_callback(self, callback: Callable): """设置刷新按钮回调""" self.refresh_btn.clicked.connect(callback) + logger.debug("设置刷新按钮回调") def set_config_callback(self, callback: Callable): """设置配置按钮回调""" self.config_btn.clicked.connect(callback) + logger.debug("设置配置按钮回调") - def show_message(self, title: str, message: str, icon=QMessageBox.Information): + def show_message(self, title: str, message: str): """显示消息""" + logger.info(f"显示消息: {title} - {message}") QMessageBox.information(self, title, message) + def add_waveform_data(self, time_str: str, value: float): + """添加波形图数据点""" + self.waveform_widget.add_data_point(time_str, value) + logger.info(f"添加波形图数据点: 时间={time_str}, 值={value}") + class QCheckBox(QPushButton): """自定义复选框""" diff --git a/requirements.txt b/requirements.txt index 62917cf..bddce44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ requests>=2.31.0 beautifulsoup4>=4.12.0 lxml>=4.9.0 openai>=1.0.0 +playwright>=1.40.0 diff --git a/screenshot_manager.py b/screenshot_manager.py new file mode 100644 index 0000000..cff90d0 --- /dev/null +++ b/screenshot_manager.py @@ -0,0 +1,185 @@ +""" +截图管理器 - 用于在非交易时间截取上海证券交易所网站图表 +""" +import os +from datetime import datetime +from loguru import logger + + +try: + from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError +except ImportError: + logger.warning("playwright未安装,截图功能将不可用") + + +class ScreenshotManager: + """截图管理器""" + + def __init__(self, screenshot_dir: str = "screenshots"): + """初始化截图管理器""" + self.screenshot_dir = screenshot_dir + self.target_url = "https://www.sse.com.cn/" + self.chart_xpath_pattern = "//*[@id=\"hq_area\"]" + + # 创建截图目录 + os.makedirs(self.screenshot_dir, exist_ok=True) + + logger.info(f"截图管理器初始化完成,截图目录: {self.screenshot_dir}") + + def capture_chart_screenshot(self) -> str: + """ + 截取上海证券交易所网站图表 + 返回截图文件路径,失败时返回空字符串 + """ + try: + # 检查playwright是否可用 + if 'sync_playwright' not in globals(): + logger.error("playwright未安装,无法使用截图功能") + return "" + + with sync_playwright() as p: + # 启动浏览器(无头模式,后台运行) + browser = p.chromium.launch(headless=True) + page = browser.new_page() + + # 设置页面超时 + page.set_default_timeout(60000) # 60秒超时 + + # 访问目标网页 + logger.info(f"访问上海证券交易所网站: {self.target_url}") + page.goto(self.target_url, wait_until="domcontentloaded") + + # 等待页面加载完成 + page.wait_for_load_state("networkidle") + + # 等待页面完全加载 + page.wait_for_timeout(5000) # 额外等待5秒 + + # 等待图表元素出现 + logger.info("等待图表元素加载...") + + # 使用动态XPath模式查找图表元素 + chart_element = None + selectors = [ + self.chart_xpath_pattern, # 主要选择器:highcharts-xxxxxxx-0格式 + "//*[contains(@class, 'highcharts')]", + "//*[contains(@id, 'highcharts')]", + "//svg", + "//canvas", + "//div[contains(@class, 'chart')]", + "//div[contains(@class, 'graph')]" + ] + + for selector in selectors: + try: + # 等待选择器出现 + page.wait_for_selector(selector, timeout=10000) + elements = page.query_selector_all(selector) + if elements: + # 选择第一个可见的元素 + for element in elements: + if element.is_visible(): + chart_element = element + logger.info(f"找到图表元素: {selector}") + break + if chart_element: + break + except PlaywrightTimeoutError: + logger.debug(f"选择器超时: {selector}") + continue + except Exception as e: + logger.debug(f"选择器错误 {selector}: {e}") + continue + + if not chart_element: + logger.warning("未找到任何图表元素,尝试截取整个页面") + # 如果找不到图表元素,截取整个页面 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png") + page.screenshot(path=screenshot_path) + logger.info(f"截取整个页面: {screenshot_path}") + browser.close() + return screenshot_path + + # 检查元素是否可见 + if not chart_element.is_visible(): + logger.warning("图表元素不可见,尝试滚动到元素位置") + chart_element.scroll_into_view_if_needed() + + # 生成截图文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png") + + # 截取图表元素 + logger.info("开始截取图表元素") + + # 直接使用元素进行截图 + chart_element.screenshot(path=screenshot_path) + + logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}") + + # 关闭浏览器 + browser.close() + + return screenshot_path + + except PlaywrightTimeoutError as e: + logger.error(f"页面加载超时: {e}") + return "" + except Exception as e: + logger.error(f"截图过程中发生错误: {e}") + return "" + + def get_latest_screenshot(self) -> str: + """获取最新的截图文件路径""" + try: + if not os.path.exists(self.screenshot_dir): + return "" + + # 获取所有截图文件 + screenshot_files = [] + for file in os.listdir(self.screenshot_dir): + if file.startswith("sse_chart_") and file.endswith(".png"): + file_path = os.path.join(self.screenshot_dir, file) + screenshot_files.append((file_path, os.path.getmtime(file_path))) + + if not screenshot_files: + return "" + + # 按修改时间排序,获取最新的文件 + screenshot_files.sort(key=lambda x: x[1], reverse=True) + return screenshot_files[0][0] + + except Exception as e: + logger.error(f"获取最新截图失败: {e}") + return "" + + def cleanup_old_screenshots(self, keep_count: int = 10): + """清理旧的截图文件,只保留最新的几个""" + try: + if not os.path.exists(self.screenshot_dir): + return + + # 获取所有截图文件 + screenshot_files = [] + for file in os.listdir(self.screenshot_dir): + if file.startswith("sse_chart_") and file.endswith(".png"): + file_path = os.path.join(self.screenshot_dir, file) + screenshot_files.append((file_path, os.path.getmtime(file_path))) + + if len(screenshot_files) <= keep_count: + return + + # 按修改时间排序,删除旧的文件 + screenshot_files.sort(key=lambda x: x[1]) + files_to_delete = screenshot_files[:-keep_count] + + for file_path, _ in files_to_delete: + try: + os.remove(file_path) + logger.info(f"删除旧截图: {file_path}") + except Exception as e: + logger.error(f"删除截图文件失败 {file_path}: {e}") + + except Exception as e: + logger.error(f"清理旧截图失败: {e}") \ No newline at end of file diff --git a/screenshots/sse_page_20260109_173924.png b/screenshots/sse_page_20260109_173924.png new file mode 100644 index 0000000..177b6d4 Binary files /dev/null and b/screenshots/sse_page_20260109_173924.png differ diff --git a/screenshots/sse_page_20260109_174348.png b/screenshots/sse_page_20260109_174348.png new file mode 100644 index 0000000..177b6d4 Binary files /dev/null and b/screenshots/sse_page_20260109_174348.png differ diff --git a/spider.py b/spider.py index 03fb0d6..304bb01 100644 --- a/spider.py +++ b/spider.py @@ -8,6 +8,7 @@ from typing import List, Dict, Optional from urllib.parse import urljoin from bs4 import BeautifulSoup import random +from loguru import logger class SpiderManager: @@ -21,6 +22,7 @@ class SpiderManager: }) self.retry_times = config.get('retry_times', 3) self.retry_interval = config.get('retry_interval', 5) + logger.info(f"爬虫管理器初始化完成,目标URL: {config.get('target_url', '')}") def fetch(self, url: str = None, xpath: str = None) -> List[Dict]: """ @@ -31,13 +33,18 @@ class SpiderManager: target_xpath = xpath or self.config.get('xpath', '') if not target_url: + logger.warning("未设置目标URL") return [] + logger.info(f"开始抓取: {target_url}") html = self._fetch_with_retry(target_url) if not html: + logger.warning("网页获取失败") return [] - return self._parse_comments(html, target_xpath, target_url) + comments = self._parse_comments(html, target_xpath, target_url) + logger.info(f"解析完成,获取到 {len(comments)} 条评论") + return comments def _fetch_with_retry(self, url: str, max_retries: int = None) -> Optional[str]: """带重试的网页获取""" @@ -45,18 +52,49 @@ class SpiderManager: for attempt in range(max_retries): try: + logger.debug(f"尝试 {attempt + 1}/{max_retries} 获取网页") response = self.session.get(url, timeout=30) response.raise_for_status() response.encoding = response.apparent_encoding + logger.debug(f"网页获取成功,大小: {len(response.text)} 字节") return response.text except requests.RequestException as e: - print(f"请求失败 (尝试 {attempt + 1}/{max_retries}): {e}") + logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries}): {e}") if attempt < max_retries - 1: time.sleep(self.retry_interval + random.uniform(0, 2)) else: + logger.error(f"所有重试均失败: {url}") return None return None + def get_page_title(self, url: str = None) -> str: + """获取页面标题""" + target_url = url or self.config.get('target_url', '') + if not target_url: + logger.warning("未设置目标URL") + return "" + + logger.info(f"获取页面标题: {target_url}") + html = self._fetch_with_retry(target_url) + if not html: + logger.warning("网页获取失败") + return "" + + try: + # 使用 lxml 解析页面标题 + tree = etree.HTML(html) + title_elements = tree.xpath('//title/text()') + if title_elements: + title = title_elements[0].strip() + logger.info(f"获取到页面标题: {title}") + return title + else: + logger.warning("未找到页面标题") + return "" + except Exception as e: + logger.error(f"解析页面标题失败: {e}") + return "" + def _parse_comments(self, html: str, xpath: str, base_url: str) -> List[Dict]: """解析评论""" comments = [] @@ -65,10 +103,11 @@ class SpiderManager: # 使用 lxml 解析 tree = etree.HTML(html) elements = tree.xpath(xpath) + logger.debug(f"XPath 匹配到 {len(elements)} 个元素") for elem in elements: try: - text = elem.text_content().strip() + text = etree.tostring(elem, method='text', encoding='unicode').strip() if text: # 获取链接的 href(如果存在) href = elem.get('href') @@ -79,11 +118,11 @@ class SpiderManager: 'url': full_url }) except Exception as e: - print(f"解析元素失败: {e}") + logger.error(f"解析元素失败: {e}") continue except Exception as e: - print(f"XPath解析失败: {e}") + logger.error(f"XPath解析失败: {e}") # 备选解析方法 comments = self._fallback_parse(html, base_url) @@ -93,6 +132,7 @@ class SpiderManager: """备选解析方法 - 使用 BeautifulSoup""" comments = [] try: + logger.debug("使用备选解析方法") soup = BeautifulSoup(html, 'lxml') # 尝试查找常见的评论元素 @@ -106,11 +146,112 @@ class SpiderManager: 'content': text, 'url': base_url }) + logger.debug(f"备选解析获取到 {len(comments)} 条评论") + except Exception as e: - print(f"备选解析失败: {e}") + logger.error(f"备选解析失败: {e}") return comments + def is_trading_time(self) -> bool: + """判断当前是否为交易时间""" + from datetime import datetime, time + + current_time = datetime.now().time() + + # 上午交易时间: 9:30-11:30 + morning_start = time(9, 30) + morning_end = time(11, 30) + + # 下午交易时间: 13:00-15:00 + afternoon_start = time(13, 0) + afternoon_end = time(15, 0) + + # 判断是否在交易时间内 + is_trading = ((morning_start <= current_time <= morning_end) or + (afternoon_start <= current_time <= afternoon_end)) + + logger.debug(f"当前时间 {current_time.strftime('%H:%M')} 是否为交易时间: {is_trading}") + return is_trading + + def fetch_sse_stock_data(self) -> Dict[str, float]: + """ + 爬取上海证券交易所股票数据 + 返回包含时间和数值的字典 + """ + # 检查是否为交易时间 + if not self.is_trading_time(): + logger.info("当前为非交易时间,跳过股票数据爬取") + return {} + + sse_url = "https://www.sse.com.cn/" + xpath = "//*[@id=\"hq_area\"]" + + logger.info(f"开始爬取上海证券交易所数据: {sse_url}") + + html = self._fetch_with_retry(sse_url) + if not html: + logger.warning("上海证券交易所网页获取失败") + return {} + + try: + # 使用 lxml 解析 + tree = etree.HTML(html) + + # 尝试获取股票数值 + elements = tree.xpath(xpath) + if not elements: + logger.warning("未找到股票数据元素,尝试备用XPath") + # 尝试备用XPath + backup_xpaths = [ + "//*[@id='hq_controller']//td[contains(@class, 'price')]//text()", + "//*[contains(@class, 'stock-price')]//text()", + "//*[contains(@class, 'price')]//text()" + ] + + for backup_xpath in backup_xpaths: + elements = tree.xpath(backup_xpath) + if elements: + logger.info(f"使用备用XPath获取到数据") + break + + if elements: + # 获取文本内容并尝试转换为数值 + text_content = etree.tostring(elements[0], method='text', encoding='unicode').strip() + logger.info(f"获取到股票数据文本: {text_content}") + + # 提取数值(可能包含逗号、小数点等) + import re + # 匹配数字(包括小数点和逗号分隔符) + numbers = re.findall(r'[\d,]+(?:\.\d+)?', text_content) + if numbers: + # 去除逗号并转换为浮点数 + value_str = numbers[0].replace(',', '') + try: + stock_value = float(value_str) + + # 获取当前时间 + from datetime import datetime + current_time = datetime.now().strftime("%H:%M") + + logger.info(f"成功获取股票数据: 时间={current_time}, 值={stock_value}") + + return { + 'time': current_time, + 'value': stock_value + } + except ValueError as e: + logger.error(f"数值转换失败: {value_str}, 错误: {e}") + else: + logger.warning(f"未找到有效数值: {text_content}") + else: + logger.warning("未找到股票数据元素") + + except Exception as e: + logger.error(f"解析上海证券交易所数据失败: {e}") + + return {} + def set_user_agent(self, user_agent: str): """更新User-Agent""" self.session.headers.update({'User-Agent': user_agent}) diff --git a/test_api.py b/test_api.py new file mode 100644 index 0000000..d40b42b --- /dev/null +++ b/test_api.py @@ -0,0 +1,32 @@ +import os +from openai import OpenAI + +# 设置 API key 为环境变量 +os.environ["NVIDIA_API_KEY"] = "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" + +client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key=os.environ["NVIDIA_API_KEY"] +) + +print("正在测试 API 连接...") +print("发送消息: 天气") +print("-" * 50) + +completion = client.chat.completions.create( + model="deepseek-ai/deepseek-r1", + messages=[{"role": "user", "content": "天气"}], + temperature=0.6, + top_p=0.7, + max_tokens=4096, + stream=False +) + +reasoning = getattr(completion.choices[0].message, "reasoning_content", None) +if reasoning: + print("推理过程:") + print(reasoning) + print("-" * 50) + +print("回答:") +print(completion.choices[0].message.content) diff --git a/test_nvidia_api.py b/test_nvidia_api.py new file mode 100644 index 0000000..0eedb79 --- /dev/null +++ b/test_nvidia_api.py @@ -0,0 +1,35 @@ +""" +测试 NVIDIA API 连接 +""" +from openai import OpenAI + +client = OpenAI( + base_url="https://integrate.api.nvidia.com/v1", + api_key="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" +) + +try: + print("开始测试 API 调用...") + completion = client.chat.completions.create( + model="deepseek-ai/deepseek-r1", + messages=[{"role": "user", "content": "你好,请简单介绍一下你自己"}], + temperature=0.6, + top_p=0.7, + max_tokens=4096, + stream=False, + timeout=120 + ) + + reasoning = getattr(completion.choices[0].message, "reasoning_content", None) + if reasoning: + print("推理过程:") + print(reasoning) + print("\n" + "="*50 + "\n") + + print("回答内容:") + print(completion.choices[0].message.content) + print("\n" + "="*50) + print("API 调用成功!") + +except Exception as e: + print(f"API 调用失败: {type(e).__name__}: {e}") diff --git a/test_screenshot.py b/test_screenshot.py new file mode 100644 index 0000000..ac70687 --- /dev/null +++ b/test_screenshot.py @@ -0,0 +1,43 @@ +""" +测试截图功能 +""" +import os +from screenshot_manager import ScreenshotManager +from loguru import logger + + +def test_screenshot(): + """测试截图功能""" + logger.info("开始测试截图功能") + + # 创建截图管理器 + screenshot_manager = ScreenshotManager() + + # 测试截图 + screenshot_path = screenshot_manager.capture_chart_screenshot() + + if screenshot_path: + logger.info(f"✅ 截图成功: {screenshot_path}") + + # 检查文件是否存在 + if os.path.exists(screenshot_path): + file_size = os.path.getsize(screenshot_path) + logger.info(f"截图文件大小: {file_size} 字节") + + # 获取最新截图 + latest_screenshot = screenshot_manager.get_latest_screenshot() + logger.info(f"最新截图: {latest_screenshot}") + + # 清理旧截图 + screenshot_manager.cleanup_old_screenshots(keep_count=2) + logger.info("清理旧截图完成") + else: + logger.error("截图文件不存在") + else: + logger.error("截图失败") + + logger.info("截图功能测试完成") + + +if __name__ == "__main__": + test_screenshot() \ No newline at end of file diff --git a/waveform_widget.py b/waveform_widget.py new file mode 100644 index 0000000..d3cf188 --- /dev/null +++ b/waveform_widget.py @@ -0,0 +1,429 @@ +""" +波形图组件 - 用于绘制股票数据波形图 +""" +import math +import os +from datetime import datetime, time +from PySide6.QtWidgets import QWidget +from PySide6.QtGui import QPainter, QPen, QColor, QBrush, QPixmap +from PySide6.QtCore import QPointF, QTimer +from loguru import logger + +# 尝试导入截图管理器 +try: + from screenshot_manager import ScreenshotManager +except ImportError: + ScreenshotManager = None + logger.warning("截图管理器导入失败,截图功能将不可用") + + +class WaveformWidget(QWidget): + """波形图控件""" + + def __init__(self, parent=None): + super().__init__(parent) + self.data_points = [] # 存储数据点 [(time, value)] + self.base_value = 0 # 基准值 + self.screenshot_manager = None + self.latest_screenshot_path = "" + self.setMinimumSize(600, 300) + + # 设置背景色 + self.setStyleSheet("background-color: #1e1e1e;") + + # 初始化截图管理器 + if ScreenshotManager: + self.screenshot_manager = ScreenshotManager() + logger.info("截图管理器初始化完成") + + logger.info("WaveformWidget 初始化完成") + + def time_to_x_position(self, time_str: str, total_width: int) -> float: + """ + 将时间转换为X轴位置 + 时间折算关系: + - 9:30 -> 最左侧 (x=0) + - 11:30 -> 中间 (x=total_width/2) + - 13:00 -> 中间 (x=total_width/2) + - 15:00 -> 最右侧 (x=total_width) + - 10:30 -> 左侧四分之一 (x=total_width/4) + - 14:00 -> 右侧四分之一 (x=total_width*3/4) + """ + try: + # 解析时间字符串 + current_time = datetime.strptime(time_str, "%H:%M").time() + + # 定义关键时间点 + market_start = time(9, 30) # 9:30 + market_mid1 = time(11, 30) # 11:30 + market_mid2 = time(13, 0) # 13:00 + market_end = time(15, 0) # 15:00 + + # 计算总交易时间(分钟) + morning_duration = (market_mid1.hour - market_start.hour) * 60 + \ + (market_mid1.minute - market_start.minute) + afternoon_duration = (market_end.hour - market_mid2.hour) * 60 + \ + (market_end.minute - market_mid2.minute) + total_duration = morning_duration + afternoon_duration + + # 计算当前时间相对于开盘时间的分钟数 + if current_time <= market_mid1: + # 上午交易时段 + minutes_from_start = (current_time.hour - market_start.hour) * 60 + \ + (current_time.minute - market_start.minute) + # 上午时段占一半宽度 + x_ratio = minutes_from_start / morning_duration * 0.5 + elif current_time >= market_mid2: + # 下午交易时段 + minutes_from_start = (current_time.hour - market_mid2.hour) * 60 + \ + (current_time.minute - market_mid2.minute) + # 下午时段占一半宽度,从中间开始 + x_ratio = 0.5 + minutes_from_start / afternoon_duration * 0.5 + else: + # 午休时间,统一放在中间 + x_ratio = 0.5 + + return x_ratio * total_width + + except Exception as e: + logger.error(f"时间转换错误: {time_str}, 错误: {e}") + return total_width / 2 # 默认返回中间位置 + + def is_trading_time(self, time_str: str) -> bool: + """判断是否为交易时间""" + try: + current_time = datetime.strptime(time_str, "%H:%M").time() + + # 上午交易时间: 9:30-11:30 + morning_start = time(9, 30) + morning_end = time(11, 30) + + # 下午交易时间: 13:00-15:00 + afternoon_start = time(13, 0) + afternoon_end = time(15, 0) + + # 判断是否在交易时间内 + is_trading = ((morning_start <= current_time <= morning_end) or + (afternoon_start <= current_time <= afternoon_end)) + + logger.debug(f"时间 {time_str} 是否为交易时间: {is_trading}") + return is_trading + + except Exception as e: + logger.error(f"时间判断错误: {time_str}, 错误: {e}") + return False + + def add_data_point(self, time_str: str, value: float): + """添加数据点""" + # 检查是否为交易时间 + if not self.is_trading_time(time_str): + logger.info(f"非交易时间 {time_str},跳过数据点添加") + return + + # 如果是第一个数据点,设置基准值 + if not self.data_points: + self.base_value = value + logger.info(f"设置基准值: {self.base_value}") + + self.data_points.append((time_str, value)) + logger.info(f"添加数据点: 时间={time_str}, 值={value}") + + # 限制数据点数量,避免内存过大 + if len(self.data_points) > 100: + self.data_points = self.data_points[-100:] + + # 触发重绘 + self.update() + + def clear_data(self): + """清除所有数据""" + self.data_points.clear() + self.base_value = 0 + self.update() + logger.info("波形图数据已清除") + + def paintEvent(self, event): + """绘制波形图""" + painter = QPainter(self) + painter.setRenderHint(QPainter.Antialiasing) + + width = self.width() + height = self.height() + + # 如果没有数据点,显示提示信息 + if not self.data_points: + self._draw_no_data_message(painter, width, height) + return + + # 检查最后一个数据点的时间是否为交易时间 + last_time = self.data_points[-1][0] if self.data_points else "" + if last_time and not self.is_trading_time(last_time): + self._draw_non_trading_message(painter, width, height) + return + + # 绘制网格和坐标轴 + self._draw_grid(painter, width, height) + + # 绘制波形线 + self._draw_waveform(painter, width, height) + + # 绘制数据点 + self._draw_data_points(painter, width, height) + + def _draw_grid(self, painter: QPainter, width: int, height: int): + """绘制网格和坐标轴""" + # 设置网格颜色 + grid_color = QColor(100, 100, 100) + painter.setPen(QPen(grid_color, 1, Qt.DashLine)) + + # 绘制水平网格线 + for i in range(1, 5): + y = height * i // 5 + painter.drawLine(0, y, width, y) + + # 绘制垂直网格线(时间刻度) + time_points = ["9:30", "10:30", "11:30", "13:00", "14:00", "15:00"] + for time_str in time_points: + x = self.time_to_x_position(time_str, width) + painter.drawLine(x, 0, x, height) + + # 绘制坐标轴 + axis_color = QColor(200, 200, 200) + painter.setPen(QPen(axis_color, 2)) + painter.drawLine(0, height // 2, width, height // 2) # X轴 + painter.drawLine(0, 0, 0, height) # Y轴 + + # 绘制时间标签 + painter.setPen(QPen(QColor(150, 150, 150), 1)) + for time_str in time_points: + x = self.time_to_x_position(time_str, width) + painter.drawText(int(x) - 20, height - 5, time_str) + + def _draw_waveform(self, painter: QPainter, width: int, height: int): + """绘制波形线""" + if len(self.data_points) < 2: + return + + # 设置波形线样式 + waveform_color = QColor(0, 200, 255) + painter.setPen(QPen(waveform_color, 3)) + + points = [] + + # 计算Y轴范围(基准值±100点) + y_min = self.base_value - 100 + y_max = self.base_value + 100 + y_range = y_max - y_min + + for time_str, value in self.data_points: + x = self.time_to_x_position(time_str, width) + + # 计算Y坐标(从底部到顶部) + y_ratio = (value - y_min) / y_range + y = height - (y_ratio * height) + + points.append(QPointF(x, y)) + + # 绘制折线 + for i in range(len(points) - 1): + painter.drawLine(points[i], points[i + 1]) + + def _draw_data_points(self, painter: QPainter, width: int, height: int): + """绘制数据点""" + point_color = QColor(255, 100, 100) + painter.setPen(QPen(point_color, 1)) + painter.setBrush(QBrush(point_color)) + + # 计算Y轴范围 + y_min = self.base_value - 100 + y_max = self.base_value + 100 + y_range = y_max - y_min + + for time_str, value in self.data_points: + x = self.time_to_x_position(time_str, width) + y_ratio = (value - y_min) / y_range + y = height - (y_ratio * height) + + # 绘制数据点圆圈 + painter.drawEllipse(int(x) - 3, int(y) - 3, 6, 6) + + # 显示数值标签 + painter.setPen(QPen(QColor(200, 200, 200), 1)) + painter.drawText(int(x) + 5, int(y) - 5, f"{value:.2f}") + + def _draw_no_data_message(self, painter: QPainter, width: int, height: int): + """绘制无数据提示信息""" + # 设置提示信息颜色 + message_color = QColor(150, 150, 150) + painter.setPen(QPen(message_color, 2)) + + # 设置字体 + font = painter.font() + font.setPointSize(14) + font.setBold(True) + painter.setFont(font) + + # 绘制提示信息 + message = "等待交易时间数据..." + text_rect = painter.fontMetrics().boundingRect(message) + x = (width - text_rect.width()) // 2 + y = height // 2 + + painter.drawText(x, y, message) + + # 绘制交易时间说明 + font.setPointSize(10) + font.setBold(False) + painter.setFont(font) + + info = "交易时间: 9:30-11:30, 13:00-15:00" + info_rect = painter.fontMetrics().boundingRect(info) + x_info = (width - info_rect.width()) // 2 + y_info = y + 30 + + painter.drawText(x_info, y_info, info) + + def _draw_non_trading_message(self, painter: QPainter, width: int, height: int): + """绘制非交易时间提示信息或截图""" + # 如果有截图管理器,尝试显示截图 + if self.screenshot_manager: + screenshot_path = self._get_or_capture_screenshot() + if screenshot_path: + self._draw_screenshot(painter, screenshot_path, width, height) + return + + # 如果没有截图或截图失败,显示文本提示 + self._draw_text_message(painter, width, height) + + def _get_or_capture_screenshot(self) -> str: + """获取或捕获最新的截图""" + try: + # 检查是否有最新的截图 + latest_screenshot = self.screenshot_manager.get_latest_screenshot() + + # 如果截图文件存在且较新(5分钟内),使用现有截图 + if latest_screenshot and os.path.exists(latest_screenshot): + file_time = datetime.fromtimestamp(os.path.getmtime(latest_screenshot)) + current_time = datetime.now() + + # 如果截图在5分钟内,直接使用 + if (current_time - file_time).total_seconds() < 300: # 5分钟 + return latest_screenshot + + # 否则捕获新的截图 + logger.info("开始捕获上海证券交易所网站截图") + new_screenshot = self.screenshot_manager.capture_chart_screenshot() + + if new_screenshot: + # 清理旧截图 + self.screenshot_manager.cleanup_old_screenshots() + return new_screenshot + + except Exception as e: + logger.error(f"获取截图失败: {e}") + + return "" + + def _draw_screenshot(self, painter: QPainter, screenshot_path: str, width: int, height: int): + """绘制截图""" + try: + # 加载截图 + pixmap = QPixmap(screenshot_path) + if pixmap.isNull(): + logger.warning(f"截图加载失败: {screenshot_path}") + self._draw_text_message(painter, width, height) + return + + # 计算缩放比例,保持宽高比 + pixmap_width = pixmap.width() + pixmap_height = pixmap.height() + + # 计算缩放比例,使截图适应显示区域 + scale_x = width / pixmap_width + scale_y = height / pixmap_height + scale = min(scale_x, scale_y) * 0.8 # 留出边距 + + # 计算显示尺寸 + display_width = int(pixmap_width * scale) + display_height = int(pixmap_height * scale) + + # 计算显示位置(居中) + x = (width - display_width) // 2 + y = (height - display_height) // 2 + + # 绘制截图 + scaled_pixmap = pixmap.scaled(display_width, display_height) + painter.drawPixmap(x, y, scaled_pixmap) + + # 绘制标题 + font = painter.font() + font.setPointSize(12) + font.setBold(True) + painter.setFont(font) + + title_color = QColor(255, 255, 255) + painter.setPen(QPen(title_color, 2)) + + title = "上海证券交易所实时图表" + title_rect = painter.fontMetrics().boundingRect(title) + title_x = (width - title_rect.width()) // 2 + title_y = y - 10 + + painter.drawText(title_x, title_y, title) + + # 绘制更新时间 + font.setPointSize(8) + font.setBold(False) + painter.setFont(font) + + update_time = datetime.now().strftime("更新时间: %H:%M:%S") + time_rect = painter.fontMetrics().boundingRect(update_time) + time_x = (width - time_rect.width()) // 2 + time_y = y + display_height + 20 + + painter.drawText(time_x, time_y, update_time) + + except Exception as e: + logger.error(f"绘制截图失败: {e}") + self._draw_text_message(painter, width, height) + + def _draw_text_message(self, painter: QPainter, width: int, height: int): + """绘制文本提示信息""" + # 设置提示信息颜色 + message_color = QColor(200, 100, 100) + painter.setPen(QPen(message_color, 2)) + + # 设置字体 + font = painter.font() + font.setPointSize(14) + font.setBold(True) + painter.setFont(font) + + # 绘制提示信息 + message = "非交易时间" + text_rect = painter.fontMetrics().boundingRect(message) + x = (width - text_rect.width()) // 2 + y = height // 2 + + painter.drawText(x, y, message) + + # 绘制交易时间说明 + font.setPointSize(10) + font.setBold(False) + painter.setFont(font) + + info = "交易时间: 9:30-11:30, 13:00-15:00" + info_rect = painter.fontMetrics().boundingRect(info) + x_info = (width - info_rect.width()) // 2 + y_info = y + 30 + + painter.drawText(x_info, y_info, info) + + # 绘制当前时间 + current_time = datetime.now().strftime("%H:%M") + time_info = f"当前时间: {current_time}" + time_rect = painter.fontMetrics().boundingRect(time_info) + x_time = (width - time_rect.width()) // 2 + y_time = y_info + 25 + + painter.drawText(x_time, y_time, time_info) \ No newline at end of file