feat: 新增股票数据波形图和截图功能

refactor: 重构数据库和LLM分析器逻辑

fix: 修复爬虫解析和UI显示问题

docs: 更新配置文件和注释

style: 优化代码格式和日志输出
This commit is contained in:
2026-01-12 09:19:38 +08:00
parent 5b8b9ec35a
commit 96f206ea78
18 changed files with 1358 additions and 93 deletions

0
api.txt Normal file
View File

25
build.bat Normal file
View File

@@ -0,0 +1,25 @@
@echo off
echo ========================================
echo 股吧人气指示器 - 打包工具
echo ========================================
REM 检查 pyinstaller 是否安装
pip show pyinstaller >nul 2>&1
if errorlevel 1 (
echo 正在安装 pyinstaller...
pip install pyinstaller
)
echo 开始打包...
pyinstaller build.spec
if exist "dist\guba-indicator.exe" (
echo ========================================
echo 打包成功!
echo 可执行文件位置: dist\guba-indicator.exe
echo ========================================
) else (
echo 打包失败,请检查错误信息
)
pause

58
build.spec Normal file
View File

@@ -0,0 +1,58 @@
# -*- mode: python ; coding: utf-8 -*-
import sys
import os
block_cipher = None
# 添加项目路径
project_path = os.path.abspath('.')
if project_path not in sys.path:
sys.path.insert(0, project_path)
a = Analysis(
['main.py'],
pathex=['.', project_path],
binaries=[],
datas=[],
hiddenimports=[],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=['tkinter', 'IPython', 'pytest', 'matplotlib', 'pandas', 'scipy', 'numpy'],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
[],
exclude_binaries=True,
name='guba-indicator',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
console=False,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
icon='indicator.ico',
)
coll = COLLECT(
exe,
a.binaries,
a.zipfiles,
a.datas,
strip=False,
upx=True,
upx_exclude=[],
name='guba-indicator',
)

View File

@@ -5,6 +5,7 @@ import json
import os
from typing import Any, Dict
from pathlib import Path
from loguru import logger
class ConfigManager:
@@ -12,10 +13,10 @@ class ConfigManager:
DEFAULT_CONFIG = {
"llm_api": {
"base_url": "https://api.openai.com/v1",
"base_url": "https://integrate.api.nvidia.com/v1",
"api_key": "",
"model": "gpt-3.5-turbo",
"timeout": 30,
"model": "deepseek-ai/deepseek-r1",
"timeout": 120,
"retry_times": 3
},
"spider": {
@@ -46,19 +47,24 @@ class ConfigManager:
def __init__(self, config_path: str = "config.json"):
self.config_path = Path(config_path)
self.config = self._load_config()
logger.info(f"配置管理器初始化完成,配置文件: {config_path}")
def _load_config(self) -> Dict[str, Any]:
"""加载配置文件"""
if self.config_path.exists():
try:
logger.info(f"从文件加载配置: {self.config_path}")
with open(self.config_path, 'r', encoding='utf-8') as f:
loaded_config = json.load(f)
# 合并默认配置,确保所有键都存在
return self._merge_config(self.DEFAULT_CONFIG, loaded_config)
merged = self._merge_config(self.DEFAULT_CONFIG, loaded_config)
logger.info("配置加载成功")
return merged
except (json.JSONDecodeError, IOError) as e:
print(f"配置文件加载失败,使用默认配置: {e}")
logger.error(f"配置文件加载失败,使用默认配置: {e}")
return self.DEFAULT_CONFIG.copy()
else:
logger.warning(f"配置文件不存在: {self.config_path},使用默认配置")
return self.DEFAULT_CONFIG.copy()
def _merge_config(self, default: Dict, loaded: Dict) -> Dict:
@@ -74,11 +80,13 @@ class ConfigManager:
def save_config(self) -> bool:
"""保存配置到文件"""
try:
logger.debug(f"保存配置到文件: {self.config_path}")
with open(self.config_path, 'w', encoding='utf-8') as f:
json.dump(self.config, f, ensure_ascii=False, indent=4)
logger.info("配置保存成功")
return True
except IOError as e:
print(f"配置保存失败: {e}")
logger.error(f"配置保存失败: {e}")
return False
def get(self, *keys: str, default: Any = None) -> Any:
@@ -118,6 +126,7 @@ class ConfigManager:
self.config["llm_api"]["timeout"] = timeout
if retry_times:
self.config["llm_api"]["retry_times"] = retry_times
logger.info("LLM API配置已更新")
self.save_config()
def update_spider(self, target_url: str = None, xpath: str = None,
@@ -136,6 +145,7 @@ class ConfigManager:
self.config["spider"]["retry_times"] = retry_times
if retry_interval:
self.config["spider"]["retry_interval"] = retry_interval
logger.info("爬虫配置已更新")
self.save_config()
def update_ui(self, opacity: float = None, is_on_top: bool = None,
@@ -149,6 +159,7 @@ class ConfigManager:
self.config["ui"]["thresholds"]["cold"] = cold_threshold
if warm_threshold is not None:
self.config["ui"]["thresholds"]["warm"] = warm_threshold
logger.info("UI配置已更新")
self.save_config()
@property

View File

@@ -7,6 +7,7 @@ import json
from datetime import datetime
from typing import List, Dict, Optional, Tuple
from pathlib import Path
from loguru import logger
class DatabaseManager:
@@ -15,9 +16,11 @@ class DatabaseManager:
def __init__(self, db_path: str = "guba.db"):
self.db_path = Path(db_path)
self._init_db()
logger.info(f"数据库管理器初始化完成,数据库路径: {db_path}")
def _init_db(self):
"""初始化数据库表"""
logger.debug("初始化数据库表")
conn = self._get_connection()
cursor = conn.cursor()
@@ -59,6 +62,7 @@ class DatabaseManager:
conn.commit()
conn.close()
logger.debug("数据库表初始化完成")
def _get_connection(self) -> sqlite3.Connection:
"""获取数据库连接"""
@@ -83,6 +87,7 @@ class DatabaseManager:
content_hash = self.hash_content(content)
if self.is_comment_exists(content_hash):
logger.debug(f"评论已存在,跳过: {content[:30]}...")
return None # 已存在
conn = self._get_connection()
@@ -94,6 +99,7 @@ class DatabaseManager:
comment_id = cursor.lastrowid
conn.commit()
conn.close()
logger.info(f"添加新评论ID: {comment_id}")
return comment_id
def add_comments_batch(self, comments: List[Dict]) -> List[int]:
@@ -108,16 +114,23 @@ class DatabaseManager:
content_hash = self.hash_content(content)
if self.is_comment_exists(content_hash):
logger.debug(f"评论已存在,跳过: {content[:30]}...")
continue
try:
cursor.execute('''
INSERT INTO comments (content, content_hash, url, created_at)
INSERT OR IGNORE INTO comments (content, content_hash, url, created_at)
VALUES (?, ?, ?, ?)
''', (content, content_hash, url, datetime.now().isoformat()))
if cursor.rowcount > 0:
new_ids.append(cursor.lastrowid)
except Exception as e:
logger.warning(f"插入评论失败(可能已存在): {e}")
continue
conn.commit()
conn.close()
logger.info(f"批量添加评论完成,新增 {len(new_ids)}")
return new_ids
def get_unanalyzed_comments(self, limit: int = 50) -> List[Dict]:
@@ -132,7 +145,9 @@ class DatabaseManager:
''', (limit,))
rows = cursor.fetchall()
conn.close()
return [{'id': row[0], 'content': row[1], 'url': row[2]} for row in rows]
result = [{'id': row[0], 'content': row[1], 'url': row[2]} for row in rows]
logger.debug(f"获取到 {len(result)} 条未分析评论")
return result
def mark_analyzed(self, comment_id: int, sentiment_score: float, analysis_text: str):
"""标记评论已分析"""
@@ -154,6 +169,7 @@ class DatabaseManager:
conn.commit()
conn.close()
logger.debug(f"标记评论 {comment_id} 已分析,分数: {sentiment_score}")
def get_latest_sentiment_score(self) -> Optional[float]:
"""获取最新的情感分数"""
@@ -167,20 +183,34 @@ class DatabaseManager:
''')
row = cursor.fetchone()
conn.close()
return row[0] if row else None
score = row[0] if row else None
logger.debug(f"最新情感分数: {score}")
return score
def get_all_scores(self) -> List[float]:
"""获取所有已分析的分数"""
def get_all_scores(self, limit: int = None) -> List[float]:
"""获取已分析的分数,可指定数量限制"""
conn = self._get_connection()
cursor = conn.cursor()
if limit:
cursor.execute('''
SELECT sentiment_score FROM comments
WHERE analyzed = 1 AND sentiment_score IS NOT NULL
ORDER BY analyzed_at DESC
LIMIT ?
''', (limit,))
else:
cursor.execute('''
SELECT sentiment_score FROM comments
WHERE analyzed = 1 AND sentiment_score IS NOT NULL
ORDER BY analyzed_at DESC
''')
rows = cursor.fetchall()
conn.close()
return [row[0] for row in rows if row[0] is not None]
scores = [row[0] for row in rows if row[0] is not None]
logger.debug(f"获取到 {len(scores)} 个分数")
return scores
def get_comment_count(self) -> int:
"""获取评论总数"""
@@ -189,6 +219,7 @@ class DatabaseManager:
cursor.execute('SELECT COUNT(*) FROM comments')
count = cursor.fetchone()[0]
conn.close()
logger.debug(f"评论总数: {count}")
return count
def get_analyzed_count(self) -> int:
@@ -198,6 +229,7 @@ class DatabaseManager:
cursor.execute('SELECT COUNT(*) FROM comments WHERE analyzed = 1')
count = cursor.fetchone()[0]
conn.close()
logger.debug(f"已分析评论数: {count}")
return count
def get_recent_comments(self, limit: int = 10) -> List[Dict]:
@@ -212,8 +244,10 @@ class DatabaseManager:
''', (limit,))
rows = cursor.fetchall()
conn.close()
return [
result = [
{'id': row[0], 'content': row[1][:50] + '...' if len(row[1]) > 50 else row[1],
'score': row[2], 'analyzed_at': row[3]}
for row in rows
]
logger.debug(f"获取到 {len(result)} 条最近评论")
return result

BIN
indicator.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

View File

@@ -1,11 +1,13 @@
"""
大模型分析模块 - 调用LLM API分析评论情感
支持 OpenAI 兼容 API包括 NVIDIA API
"""
import json
import time
import re
from typing import Dict, Optional, Tuple
from openai import OpenAI, OpenAIError
from typing import Dict, Optional, Tuple, Any
from openai import OpenAI
from loguru import logger
class LLMAnalyzer:
@@ -31,26 +33,34 @@ class LLMAnalyzer:
def __init__(self, config: Dict):
self.config = config
self.base_url = config.get('base_url', 'https://api.openai.com/v1')
self.base_url = config.get('base_url', '')
self.api_key = config.get('api_key', '')
self.model = config.get('model', 'gpt-3.5-turbo')
self.timeout = config.get('timeout', 30)
self.model = config.get('model', '')
self.timeout = config.get('timeout', 120)
self.retry_times = config.get('retry_times', 3)
self.client = None
if self.api_key:
self.last_result = None # 保存最后一次分析结果
logger.info(f"LLM分析器配置 - base_url: {self.base_url}, model: {self.model}, timeout: {self.timeout}s, retry: {self.retry_times}")
if self.base_url and self.api_key:
self._init_client()
else:
logger.warning("LLM API 未配置base_url 或 api_key 为空")
def _init_client(self):
"""初始化OpenAI客户端"""
try:
logger.info(f"初始化LLM客户端: {self.base_url}")
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=self.timeout
)
logger.info("LLM客户端初始化成功")
except Exception as e:
print(f"初始化LLM客户端失败: {e}")
logger.error(f"初始化LLM客户端失败: {e}")
def update_config(self, config: Dict):
"""更新配置"""
@@ -61,7 +71,7 @@ class LLMAnalyzer:
self.timeout = config.get('timeout', self.timeout)
self.retry_times = config.get('retry_times', self.retry_times)
if self.api_key:
if self.base_url and self.api_key:
self._init_client()
def analyze(self, comment: str) -> Tuple[Optional[int], Optional[str]]:
@@ -70,13 +80,21 @@ class LLMAnalyzer:
返回 (score, label)
"""
if not self.client:
logger.error("LLM客户端未初始化请检查API配置")
return None, "LLM未配置"
if not comment or not comment.strip():
logger.warning("评论内容为空")
return None, "评论为空"
logger.debug(f"开始分析评论: {comment[:50]}...")
logger.debug(f"使用模型: {self.model}, 超时设置: {self.timeout}")
for attempt in range(self.retry_times):
try:
logger.info(f"API调用尝试 {attempt + 1}/{self.retry_times}")
logger.debug(f"发送请求到 {self.base_url}")
response = self.client.chat.completions.create(
model=self.model,
messages=[
@@ -84,23 +102,45 @@ class LLMAnalyzer:
{"role": "user", "content": f"请分析以下评论的情感倾向:\n\n{comment}"}
],
temperature=0.3,
max_tokens=200
max_tokens=500,
timeout=self.timeout
)
result_text = response.choices[0].message.content.strip()
# 处理 deepseek-r1 的特殊结构(可能有 reasoning_content
message = response.choices[0].message
# 获取推理过程(如果有)
reasoning = getattr(message, 'reasoning_content', None)
if reasoning:
logger.debug(f"推理过程: {reasoning[:100]}...")
# 获取最终回答
result_text = message.content.strip() if message.content else ""
logger.debug(f"API返回原始内容: {result_text[:100]}...")
score, label = self._parse_response(result_text)
# 保存最后结果
self.last_result = {
'score': score,
'label': label,
'reasoning': reasoning,
'raw_response': result_text
}
if score is not None:
logger.info(f"分析完成: {score}分 - {label}")
return score, label
except OpenAIError as e:
print(f"API调用失败 (尝试 {attempt + 1}/{self.retry_times}): {e}")
if attempt < self.retry_times - 1:
time.sleep(2 ** attempt) # 指数退避
except Exception as e:
print(f"分析过程出错: {e}")
break
logger.warning(f"API调用失败 (尝试 {attempt + 1}/{self.retry_times}): {type(e).__name__}: {e}")
logger.debug(f"错误详情: {str(e)}")
if attempt < self.retry_times - 1:
wait_time = 2 ** attempt
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time) # 指数退避
logger.error(f"所有 {self.retry_times} 次重试均失败")
return None, "分析失败"
def _parse_response(self, response: str) -> Tuple[Optional[int], Optional[str]]:
@@ -113,40 +153,57 @@ class LLMAnalyzer:
# 验证分数范围
score = max(0, min(100, int(score)))
logger.debug(f"JSON解析成功: {score} - {label}")
return score, label
except json.JSONDecodeError:
# 尝试文本提取
pass
logger.debug("JSON解析失败尝试文本提取")
# 尝试从文本中提取数字
# 尝试从文本中提取
numbers = re.findall(r'\b(\d{1,3})\b', response)
if numbers:
score = int(numbers[0])
score = max(0, min(100, score))
# 提取标签
label_match = re.search(r'["']([^"']+)["']', response)
label_match = re.search(r'"([^"]+)"', response)
if label_match:
label = label_match.group(1)
else:
label = response.split('\n')[0][:20] if response else '无法判断'
logger.debug(f"文本提取成功: {score} - {label}")
return score, label
logger.warning("无法解析响应")
return None, "解析失败"
def get_last_result(self) -> Optional[Dict[str, Any]]:
"""获取最后一次分析结果"""
return self.last_result
def analyze_batch(self, comments: list, delay: float = 1.0) -> list:
"""
批量分析评论
delay: 每次调用之间的延迟(秒)
"""
logger.info(f"开始批量分析 {len(comments)} 条评论,每次间隔 {delay}")
results = []
success_count = 0
fail_count = 0
for i, comment in enumerate(comments):
print(f"分析评论 {i + 1}/{len(comments)}...")
logger.info(f"正在分析第 {i + 1}/{len(comments)} 条评论")
score, label = self.analyze(comment)
if score is not None:
success_count += 1
logger.debug(f"{i + 1} 条评论分析成功: {score}分 - {label}")
else:
fail_count += 1
logger.warning(f"{i + 1} 条评论分析失败: {label}")
results.append({
'content': comment,
'score': score,
@@ -154,8 +211,10 @@ class LLMAnalyzer:
})
if delay > 0 and i < len(comments) - 1:
logger.debug(f"等待 {delay} 秒后继续...")
time.sleep(delay)
logger.info(f"批量分析完成,成功 {success_count} 条,失败 {fail_count}")
return results
def is_configured(self) -> bool:

179
main.py
View File

@@ -2,11 +2,11 @@
主程序入口 - 股吧人气指示器
"""
import sys
import logging
import time
from datetime import datetime
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QTimer, Signal, QObject
from PySide6.QtCore import QTimer, Signal, QObject, QThread
from loguru import logger
from config_manager import ConfigManager
from database import DatabaseManager
@@ -20,8 +20,10 @@ class BackendWorker(QObject):
fetch_finished = Signal(list)
analysis_finished = Signal(float)
analysis_result = Signal(dict) # 传递分析结果详情
error_occurred = Signal(str)
status_update = Signal(str)
stock_data_fetched = Signal(str, float) # 股票数据获取信号
def __init__(self, config_manager: ConfigManager, db_manager: DatabaseManager,
spider: SpiderManager, analyzer: LLMAnalyzer):
@@ -32,92 +34,136 @@ class BackendWorker(QObject):
self.analyzer = analyzer
self.running = False
self.last_fetch_time = 0
self.fetch_interval = 60 # 默认60
self.fetch_interval = 15 # 默认15
self.no_new_content_count = 0 # 无新内容计数
self.is_running_cycle = False # 防止并发执行
self.fetch_count = 0 # 爬取次数统计
self.analysis_count = 0 # API分析次数统计
logger.info("BackendWorker 初始化完成")
def start(self):
"""启动后台任务"""
self.running = True
self._run_cycle()
logger.info("后台任务已启动")
# 启动时立即执行第一次任务
QTimer.singleShot(1000, self._run_cycle)
def stop(self):
"""停止后台任务"""
self.running = False
logger.info("后台任务已停止")
def _run_cycle(self):
"""运行一个周期"""
if self.is_running_cycle:
logger.debug("上一个周期仍在执行,跳过本次")
return
self.is_running_cycle = True
if not self.running:
self.is_running_cycle = False
return
try:
# 1. 爬取评论
logger.info("开始爬取评论")
self.status_update.emit("正在爬取评论...")
self.fetch_count += 1
comments = self.spider.fetch()
if not comments:
self.no_new_content_count += 1
interval = self.fetch_interval * (1 + min(self.no_new_content_count * 0.5, 2))
# 爬取失败,使用固定间隔重试,不计入无新内容计数
interval = 15 # 固定15秒重试
logger.warning(f"未获取到新评论,{int(interval)}秒后重试")
self.status_update.emit(f"无新内容,{int(interval)}秒后重试...")
self.is_running_cycle = False
QTimer.singleShot(int(interval * 1000), self._run_cycle)
return
# 2. 写入数据库
logger.info(f"获取到 {len(comments)} 条评论")
self.status_update.emit(f"获取到 {len(comments)} 条评论...")
new_ids = self.db.add_comments_batch(comments)
if new_ids:
self.no_new_content_count = 0
logger.info(f"新增 {len(new_ids)} 条评论到数据库")
self.status_update.emit(f"新增 {len(new_ids)} 条评论")
# 3. 获取未分析评论并分析
unanalyzed = self.db.get_unanalyzed_comments(limit=10)
logger.info(f"获取到 {len(unanalyzed)} 条未分析评论")
if unanalyzed:
self.status_update.emit(f"开始分析 {len(unanalyzed)} 条评论...")
self._analyze_comments(unanalyzed)
else:
self.no_new_content_count += 1
logger.info("评论已存在,未新增")
# 4. 更新指示器
self._update_indicator()
except Exception as e:
logger.error(f"运行错误: {str(e)}")
self.error_occurred.emit(f"运行错误: {str(e)}")
# 安排下一次执行
if self.running:
interval = self.fetch_interval * (1 + min(self.no_new_content_count * 0.5, 2))
interval = self.fetch_interval * (1 + min(self.no_new_content_count * 1.0, 4.0))
logger.debug(f"下次执行将在 {int(interval)} 秒后")
self.is_running_cycle = False
QTimer.singleShot(int(interval * 1000), self._run_cycle)
else:
self.is_running_cycle = False
def _analyze_comments(self, comments):
"""分析评论"""
logger.info(f"开始分析 {len(comments)} 条评论")
for i, comment in enumerate(comments):
if not self.running:
logger.warning("分析被中断")
break
try:
self.status_update.emit(f"分析 {i+1}/{len(comments)}...")
logger.debug(f"分析第 {i+1} 条评论: {comment['content'][:50]}...")
self.analysis_count += 1
score, label = self.analyzer.analyze(comment['content'])
# 获取分析结果详情
last_result = self.analyzer.get_last_result()
if score is not None:
self.db.mark_analyzed(comment['id'], score, label)
logger.info(f"评论 {comment['id']} 分析完成: {score}分 - {label}")
# 更新状态显示为简洁格式
self.status_update.emit(f"分析 {i+1}/{len(comments)}...返回{score}")
# 每条评论分析完成后立即更新指示器
self._update_indicator()
time.sleep(1.0) # 延迟避免API限流
else:
self.db.mark_analyzed(comment['id'], 50, "无法判断")
logger.warning(f"评论 {comment['id']} 无法判断")
except Exception as e:
logger.error(f"分析评论 {comment.get('id', 'unknown')} 失败: {str(e)}")
self.error_occurred.emit(f"分析失败: {str(e)}")
self.db.mark_analyzed(comment['id'], 50, "分析异常")
def _update_indicator(self):
"""更新指示器显示"""
scores = self.db.get_all_scores()
# 获取最新的100条评论的分数
scores = self.db.get_all_scores(limit=100)
if not scores:
logger.debug("暂无分析分数")
return
# 计算平均分
avg_score = sum(scores) / len(scores)
logger.info(f"当前平均分: {avg_score:.2f} (基于最新的 {len(scores)} 条评论)")
# 根据阈值确定标签
thresholds = self.config.get('ui', 'thresholds', default={'cold': 30, 'warm': 70})
@@ -131,68 +177,161 @@ class BackendWorker(QObject):
else:
label = "中性"
self.analysis_finished.emit(avg_score)
logger.info(f"情感倾向: {label}")
# 发送整数分数
self.analysis_finished.emit(int(avg_score))
def fetch_stock_data(self):
"""爬取股票数据并添加到波形图"""
try:
logger.info("开始爬取股票数据")
stock_data = self.spider.fetch_sse_stock_data()
if stock_data and 'time' in stock_data and 'value' in stock_data:
logger.info(f"成功获取股票数据: {stock_data}")
# 发送股票数据到主窗口
self.stock_data_fetched.emit(stock_data['time'], stock_data['value'])
else:
logger.warning("未能获取有效的股票数据")
except Exception as e:
logger.error(f"爬取股票数据失败: {str(e)}")
def manual_refresh(self):
"""手动刷新"""
logger.info("用户手动刷新")
self.no_new_content_count = 0
if not self.is_running_cycle:
self._run_cycle()
else:
logger.info("上一个周期仍在执行,跳过手动刷新")
def update_fetch_interval(self, interval: int):
"""更新爬取间隔"""
logger.info(f"更新爬取间隔: {interval}")
self.fetch_interval = interval
def setup_logging(log_path: str, level: str = "INFO"):
"""配置日志"""
logging.basicConfig(
level=getattr(logging, level.upper(), logging.INFO),
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_path, encoding='utf-8'),
logging.StreamHandler()
]
logger.remove() # 移除默认的处理器
logger.add(
log_path,
rotation="10 MB",
retention="7 days",
level=level,
encoding="utf-8",
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
)
logger.add(
sys.stdout,
level=level,
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
)
def main():
"""主函数"""
logger.info("=== 股吧人气指示器启动 ===")
# 创建应用
app = QApplication(sys.argv)
app.setQuitOnLastWindowClosed(False) # 允许最小化到托盘
# 加载配置
logger.info("加载配置文件...")
config = ConfigManager("config.json")
# 配置日志
log_config = config.logging_config
setup_logging(log_config.get('path', 'guba.log'), log_config.get('level', 'INFO'))
logger.info(f"日志配置完成: {log_config.get('path', 'guba.log')}, 级别: {log_config.get('level', 'INFO')}")
# 初始化组件
logger.info("初始化数据库...")
db = DatabaseManager(config.database_config.get('path', 'guba.db'))
logger.info("初始化爬虫...")
spider = SpiderManager(config.spider_config)
logger.info("初始化LLM分析器...")
analyzer = LLMAnalyzer(config.llm_api_config)
# 创建后台工作器
worker = BackendWorker(config, db, spider, analyzer)
worker.update_fetch_interval(config.spider_config.get('fetch_interval', 60))
worker.update_fetch_interval(config.spider_config.get('fetch_interval', 15))
# 创建后台线程
logger.info("创建后台线程...")
worker_thread = QThread()
worker.moveToThread(worker_thread)
# 线程启动时开始工作
worker_thread.started.connect(worker.start)
# 线程结束时停止工作
worker_thread.finished.connect(worker.stop)
# 创建主窗口
window = MainWindow(config)
logger.info("创建主窗口...")
window = MainWindow(config, spider)
window.show()
# 连接信号
worker.status_update.connect(window.update_status)
worker.analysis_finished.connect(window.update_indicator)
worker.error_occurred.connect(lambda msg: window.show_message("错误", msg))
worker.stock_data_fetched.connect(window.add_waveform_data)
# 启动时从数据库初始化指示器显示
worker._update_indicator()
logger.info("初始化指示器显示完成")
# 设置按钮回调
window.set_refresh_callback(worker.manual_refresh)
window.set_config_callback(window.show_config)
# 启动后台任务
worker.start()
# 启动后台线程
worker_thread.start()
logger.info("后台线程已启动")
# 启动股票数据爬取定时器
stock_timer = QTimer()
stock_timer.timeout.connect(worker.fetch_stock_data)
stock_timer.start(60000) # 每分钟爬取一次股票数据
logger.info("股票数据爬取定时器已启动间隔60秒")
# 确保应用退出时清理线程
def cleanup():
logger.info("清理资源,停止后台线程...")
worker.stop()
worker_thread.quit()
worker_thread.wait()
# 显示统计信息
logger.info("=== 程序运行统计 ===")
logger.info(f"爬取网站次数: {worker.fetch_count}")
logger.info(f"提交API分析次数: {worker.analysis_count}")
logger.info("=== 统计结束 ===")
# 写入统计信息到文件
try:
stats_file = "statistics.txt"
with open(stats_file, "a", encoding="utf-8") as f:
f.write(f"=== 程序运行统计 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n")
f.write(f"爬取网站次数: {worker.fetch_count}\n")
f.write(f"提交API分析次数: {worker.analysis_count}\n")
f.write("=== 统计结束 ===\n\n")
logger.info(f"统计信息已写入到文件: {stats_file}")
except Exception as e:
logger.error(f"写入统计文件失败: {str(e)}")
logger.info("后台线程已停止")
app.aboutToQuit.connect(cleanup)
logger.info("应用启动完成,进入主循环")
# 运行应用
sys.exit(app.exec())

View File

@@ -4,10 +4,13 @@ PySide6 GUI界面模块
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel,
QPushButton, QSlider, QDialog, QFormLayout,
QLineEdit, QSpinBox, QMessageBox, QSystemTrayIcon,
QMenu, QTextEdit, QGroupBox, QDialogButtonBox)
QMenu, QTextEdit, QGroupBox, QDialogButtonBox, QCheckBox)
from PySide6.QtCore import Qt, QTimer, Signal, QPoint
from PySide6.QtGui import QFont, QColor, QPainter, QBrush, QPen, QIcon, QAction
from typing import Callable, Optional
from loguru import logger
from waveform_widget import WaveformWidget
class SentimentIndicator(QWidget):
@@ -57,22 +60,38 @@ class SentimentIndicator(QWidget):
def _get_color(self, score: int) -> QColor:
"""根据分数获取颜色"""
if score < 30:
# 冷色系 - 蓝色/青色
ratio = score / 30
return QColor(int(0 + 100 * ratio), int(150 + 50 * ratio), 255)
elif score < 70:
# 中性 - 灰色/绿色
if score < 50:
ratio = (score - 30) / 20
return QColor(int(100 + 50 * ratio), int(200 + 20 * ratio), int(200 - 50 * ratio))
# 45-55分区间每1分一个颜色从绿色到黄色渐变
if 45 <= score <= 55:
# 在45-55分之间每1分一个颜色
ratio = (score - 45) / 10 # 0到1之间的比例
# 从绿色(0, 200, 0)渐变到黄色(255, 255, 0)
r = int(0 + 255 * ratio)
g = 200 if ratio < 0.5 else int(200 + 55 * (ratio - 0.5) * 2)
b = int(0 + 0 * ratio)
return QColor(r, g, b)
# 45分以下每5分一个颜色从深蓝到绿色渐变
elif score < 45:
# 将分数映射到0-8的区间0-40分每5分一个级别
level = score // 5
# 从深蓝色(0, 100, 255)渐变到绿色(0, 200, 0)
ratio = level / 8 # 0-8对应0-40分
r = int(0 + 0 * ratio)
g = int(100 + 100 * ratio)
b = int(255 - 255 * ratio)
return QColor(r, g, b)
# 55分以上每5分一个颜色从黄色到红色渐变
else:
ratio = (score - 50) / 20
return QColor(int(150 + 50 * ratio), int(220 - 20 * ratio), int(150 - 50 * ratio))
else:
# 暖色系 - 橙色/红色
ratio = (score - 70) / 30
return QColor(255, int(200 - 100 * ratio), int(50 + 50 * ratio))
# 将分数映射到0-8的区间60-100分每5分一个级别
level = (score - 60) // 5
level = max(0, min(8, level)) # 限制在0-8范围内
# 从黄色(255, 255, 0)渐变到红色(255, 0, 0)
ratio = level / 8
r = 255
g = int(255 - 255 * ratio)
b = 0
return QColor(r, g, b)
def get_description(self, score: int) -> str:
"""获取描述文本"""
@@ -125,7 +144,7 @@ class ConfigDialog(QDialog):
self.user_agent_edit = QLineEdit(spider_config.get('user_agent', ''))
self.interval_spin = QSpinBox()
self.interval_spin.setRange(10, 3600)
self.interval_spin.setValue(spider_config.get('fetch_interval', 60))
self.interval_spin.setValue(spider_config.get('fetch_interval', 15))
layout.addRow("目标URL:", self.url_edit)
layout.addRow("XPath:", self.xpath_edit)
@@ -197,10 +216,14 @@ class ConfigDialog(QDialog):
class MainWindow(QWidget):
"""主窗口"""
def __init__(self, config_manager, parent=None):
def __init__(self, config_manager, spider_manager=None, parent=None):
super().__init__(parent)
self.config_manager = config_manager
self.setWindowTitle("股吧人气指示器")
self.spider_manager = spider_manager
# 获取页面标题并设置窗口标题
self._set_window_title()
self._init_ui()
self._apply_config()
@@ -217,7 +240,7 @@ class MainWindow(QWidget):
layout.setContentsMargins(10, 10, 10, 10)
# 标题
self.title_label = QLabel("股吧人气")
self.title_label = QLabel("上证指数sh000001")
self.title_label.setAlignment(Qt.AlignCenter)
title_font = QFont()
title_font.setPointSize(14)
@@ -236,26 +259,64 @@ class MainWindow(QWidget):
status_font.setPointSize(10)
self.status_label.setFont(status_font)
# 波形图组件
self.waveform_widget = WaveformWidget()
self.waveform_widget.setMinimumHeight(200)
# 按钮
btn_layout = QHBoxLayout()
self.refresh_btn = QPushButton("刷新")
self.config_btn = QPushButton("配置")
self.quit_btn = QPushButton("退出")
self.quit_btn.clicked.connect(self.quit_app)
btn_layout.addWidget(self.refresh_btn)
btn_layout.addWidget(self.config_btn)
btn_layout.addWidget(self.quit_btn)
# 添加到主布局
layout.addWidget(self.title_label)
layout.addWidget(self.indicator)
layout.addWidget(self.score_label)
layout.addWidget(self.status_label)
layout.addWidget(self.waveform_widget)
layout.addLayout(btn_layout)
# 设置窗口标志(无边框、可拖拽)
self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint)
self.setAttribute(Qt.WA_TranslucentBackground)
def _set_window_title(self):
"""设置窗口标题"""
logger.debug("设置窗口标题")
# 尝试从爬虫获取页面标题
if hasattr(self, 'spider_manager') and self.spider_manager:
try:
page_title = self.spider_manager.get_page_title()
if page_title:
# 从页面标题中提取股票名称
import re
match = re.search(r'(上证指数sh\d+)', page_title)
if match:
stock_name = match.group(1)
window_title = f"冷暖值 - {stock_name}"
self.setWindowTitle(window_title)
# 同时更新主标题标签(如果已初始化)
if hasattr(self, 'title_label'):
self.title_label.setText("上证指数sh000001")
logger.info(f"设置窗口标题: {window_title}")
return
except Exception as e:
logger.error(f"获取页面标题失败: {e}")
# 如果获取失败,使用默认标题
self.setWindowTitle("冷暖值 - 股吧人气")
logger.info("使用默认窗口标题")
def _init_tray_icon(self):
"""初始化系统托盘"""
logger.debug("初始化系统托盘")
self.tray_icon = QSystemTrayIcon(self)
self.tray_icon.setToolTip("股吧人气指示器")
@@ -275,9 +336,11 @@ class MainWindow(QWidget):
self.tray_icon.setContextMenu(tray_menu)
self.tray_icon.show()
logger.info("系统托盘初始化完成")
def quit_app(self):
"""退出应用"""
logger.info("退出应用")
self.close()
import sys
sys.exit(0)
@@ -322,7 +385,7 @@ class MainWindow(QWidget):
context_menu.addAction(config_action)
context_menu.addAction(quit_action)
context_menu.exec(event.globalPosition().toPoint())
context_menu.exec(event.globalPos())
def show_config(self):
"""显示配置对话框"""
@@ -336,23 +399,33 @@ class MainWindow(QWidget):
label = self.indicator.get_description(score)
self.indicator.set_value(score, label)
self.score_label.setText(f"{score} - {label}")
logger.debug(f"更新指示灯: {score}分 - {label}")
def update_status(self, text: str):
"""更新状态"""
self.status_label.setText(text)
logger.debug(f"更新状态: {text}")
def set_refresh_callback(self, callback: Callable):
"""设置刷新按钮回调"""
self.refresh_btn.clicked.connect(callback)
logger.debug("设置刷新按钮回调")
def set_config_callback(self, callback: Callable):
"""设置配置按钮回调"""
self.config_btn.clicked.connect(callback)
logger.debug("设置配置按钮回调")
def show_message(self, title: str, message: str, icon=QMessageBox.Information):
def show_message(self, title: str, message: str):
"""显示消息"""
logger.info(f"显示消息: {title} - {message}")
QMessageBox.information(self, title, message)
def add_waveform_data(self, time_str: str, value: float):
"""添加波形图数据点"""
self.waveform_widget.add_data_point(time_str, value)
logger.info(f"添加波形图数据点: 时间={time_str}, 值={value}")
class QCheckBox(QPushButton):
"""自定义复选框"""

View File

@@ -3,3 +3,4 @@ requests>=2.31.0
beautifulsoup4>=4.12.0
lxml>=4.9.0
openai>=1.0.0
playwright>=1.40.0

185
screenshot_manager.py Normal file
View File

@@ -0,0 +1,185 @@
"""
截图管理器 - 用于在非交易时间截取上海证券交易所网站图表
"""
import os
from datetime import datetime
from loguru import logger
try:
from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError
except ImportError:
logger.warning("playwright未安装截图功能将不可用")
class ScreenshotManager:
"""截图管理器"""
def __init__(self, screenshot_dir: str = "screenshots"):
"""初始化截图管理器"""
self.screenshot_dir = screenshot_dir
self.target_url = "https://www.sse.com.cn/"
self.chart_xpath_pattern = "//*[@id=\"hq_area\"]"
# 创建截图目录
os.makedirs(self.screenshot_dir, exist_ok=True)
logger.info(f"截图管理器初始化完成,截图目录: {self.screenshot_dir}")
def capture_chart_screenshot(self) -> str:
"""
截取上海证券交易所网站图表
返回截图文件路径,失败时返回空字符串
"""
try:
# 检查playwright是否可用
if 'sync_playwright' not in globals():
logger.error("playwright未安装无法使用截图功能")
return ""
with sync_playwright() as p:
# 启动浏览器(无头模式,后台运行)
browser = p.chromium.launch(headless=True)
page = browser.new_page()
# 设置页面超时
page.set_default_timeout(60000) # 60秒超时
# 访问目标网页
logger.info(f"访问上海证券交易所网站: {self.target_url}")
page.goto(self.target_url, wait_until="domcontentloaded")
# 等待页面加载完成
page.wait_for_load_state("networkidle")
# 等待页面完全加载
page.wait_for_timeout(5000) # 额外等待5秒
# 等待图表元素出现
logger.info("等待图表元素加载...")
# 使用动态XPath模式查找图表元素
chart_element = None
selectors = [
self.chart_xpath_pattern, # 主要选择器highcharts-xxxxxxx-0格式
"//*[contains(@class, 'highcharts')]",
"//*[contains(@id, 'highcharts')]",
"//svg",
"//canvas",
"//div[contains(@class, 'chart')]",
"//div[contains(@class, 'graph')]"
]
for selector in selectors:
try:
# 等待选择器出现
page.wait_for_selector(selector, timeout=10000)
elements = page.query_selector_all(selector)
if elements:
# 选择第一个可见的元素
for element in elements:
if element.is_visible():
chart_element = element
logger.info(f"找到图表元素: {selector}")
break
if chart_element:
break
except PlaywrightTimeoutError:
logger.debug(f"选择器超时: {selector}")
continue
except Exception as e:
logger.debug(f"选择器错误 {selector}: {e}")
continue
if not chart_element:
logger.warning("未找到任何图表元素,尝试截取整个页面")
# 如果找不到图表元素,截取整个页面
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png")
page.screenshot(path=screenshot_path)
logger.info(f"截取整个页面: {screenshot_path}")
browser.close()
return screenshot_path
# 检查元素是否可见
if not chart_element.is_visible():
logger.warning("图表元素不可见,尝试滚动到元素位置")
chart_element.scroll_into_view_if_needed()
# 生成截图文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png")
# 截取图表元素
logger.info("开始截取图表元素")
# 直接使用元素进行截图
chart_element.screenshot(path=screenshot_path)
logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}")
# 关闭浏览器
browser.close()
return screenshot_path
except PlaywrightTimeoutError as e:
logger.error(f"页面加载超时: {e}")
return ""
except Exception as e:
logger.error(f"截图过程中发生错误: {e}")
return ""
def get_latest_screenshot(self) -> str:
"""获取最新的截图文件路径"""
try:
if not os.path.exists(self.screenshot_dir):
return ""
# 获取所有截图文件
screenshot_files = []
for file in os.listdir(self.screenshot_dir):
if file.startswith("sse_chart_") and file.endswith(".png"):
file_path = os.path.join(self.screenshot_dir, file)
screenshot_files.append((file_path, os.path.getmtime(file_path)))
if not screenshot_files:
return ""
# 按修改时间排序,获取最新的文件
screenshot_files.sort(key=lambda x: x[1], reverse=True)
return screenshot_files[0][0]
except Exception as e:
logger.error(f"获取最新截图失败: {e}")
return ""
def cleanup_old_screenshots(self, keep_count: int = 10):
"""清理旧的截图文件,只保留最新的几个"""
try:
if not os.path.exists(self.screenshot_dir):
return
# 获取所有截图文件
screenshot_files = []
for file in os.listdir(self.screenshot_dir):
if file.startswith("sse_chart_") and file.endswith(".png"):
file_path = os.path.join(self.screenshot_dir, file)
screenshot_files.append((file_path, os.path.getmtime(file_path)))
if len(screenshot_files) <= keep_count:
return
# 按修改时间排序,删除旧的文件
screenshot_files.sort(key=lambda x: x[1])
files_to_delete = screenshot_files[:-keep_count]
for file_path, _ in files_to_delete:
try:
os.remove(file_path)
logger.info(f"删除旧截图: {file_path}")
except Exception as e:
logger.error(f"删除截图文件失败 {file_path}: {e}")
except Exception as e:
logger.error(f"清理旧截图失败: {e}")

Binary file not shown.

After

Width:  |  Height:  |  Size: 531 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 531 KiB

153
spider.py
View File

@@ -8,6 +8,7 @@ from typing import List, Dict, Optional
from urllib.parse import urljoin
from bs4 import BeautifulSoup
import random
from loguru import logger
class SpiderManager:
@@ -21,6 +22,7 @@ class SpiderManager:
})
self.retry_times = config.get('retry_times', 3)
self.retry_interval = config.get('retry_interval', 5)
logger.info(f"爬虫管理器初始化完成目标URL: {config.get('target_url', '')}")
def fetch(self, url: str = None, xpath: str = None) -> List[Dict]:
"""
@@ -31,13 +33,18 @@ class SpiderManager:
target_xpath = xpath or self.config.get('xpath', '')
if not target_url:
logger.warning("未设置目标URL")
return []
logger.info(f"开始抓取: {target_url}")
html = self._fetch_with_retry(target_url)
if not html:
logger.warning("网页获取失败")
return []
return self._parse_comments(html, target_xpath, target_url)
comments = self._parse_comments(html, target_xpath, target_url)
logger.info(f"解析完成,获取到 {len(comments)} 条评论")
return comments
def _fetch_with_retry(self, url: str, max_retries: int = None) -> Optional[str]:
"""带重试的网页获取"""
@@ -45,18 +52,49 @@ class SpiderManager:
for attempt in range(max_retries):
try:
logger.debug(f"尝试 {attempt + 1}/{max_retries} 获取网页")
response = self.session.get(url, timeout=30)
response.raise_for_status()
response.encoding = response.apparent_encoding
logger.debug(f"网页获取成功,大小: {len(response.text)} 字节")
return response.text
except requests.RequestException as e:
print(f"请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
time.sleep(self.retry_interval + random.uniform(0, 2))
else:
logger.error(f"所有重试均失败: {url}")
return None
return None
def get_page_title(self, url: str = None) -> str:
"""获取页面标题"""
target_url = url or self.config.get('target_url', '')
if not target_url:
logger.warning("未设置目标URL")
return ""
logger.info(f"获取页面标题: {target_url}")
html = self._fetch_with_retry(target_url)
if not html:
logger.warning("网页获取失败")
return ""
try:
# 使用 lxml 解析页面标题
tree = etree.HTML(html)
title_elements = tree.xpath('//title/text()')
if title_elements:
title = title_elements[0].strip()
logger.info(f"获取到页面标题: {title}")
return title
else:
logger.warning("未找到页面标题")
return ""
except Exception as e:
logger.error(f"解析页面标题失败: {e}")
return ""
def _parse_comments(self, html: str, xpath: str, base_url: str) -> List[Dict]:
"""解析评论"""
comments = []
@@ -65,10 +103,11 @@ class SpiderManager:
# 使用 lxml 解析
tree = etree.HTML(html)
elements = tree.xpath(xpath)
logger.debug(f"XPath 匹配到 {len(elements)} 个元素")
for elem in elements:
try:
text = elem.text_content().strip()
text = etree.tostring(elem, method='text', encoding='unicode').strip()
if text:
# 获取链接的 href如果存在
href = elem.get('href')
@@ -79,11 +118,11 @@ class SpiderManager:
'url': full_url
})
except Exception as e:
print(f"解析元素失败: {e}")
logger.error(f"解析元素失败: {e}")
continue
except Exception as e:
print(f"XPath解析失败: {e}")
logger.error(f"XPath解析失败: {e}")
# 备选解析方法
comments = self._fallback_parse(html, base_url)
@@ -93,6 +132,7 @@ class SpiderManager:
"""备选解析方法 - 使用 BeautifulSoup"""
comments = []
try:
logger.debug("使用备选解析方法")
soup = BeautifulSoup(html, 'lxml')
# 尝试查找常见的评论元素
@@ -106,11 +146,112 @@ class SpiderManager:
'content': text,
'url': base_url
})
logger.debug(f"备选解析获取到 {len(comments)} 条评论")
except Exception as e:
print(f"备选解析失败: {e}")
logger.error(f"备选解析失败: {e}")
return comments
def is_trading_time(self) -> bool:
"""判断当前是否为交易时间"""
from datetime import datetime, time
current_time = datetime.now().time()
# 上午交易时间: 9:30-11:30
morning_start = time(9, 30)
morning_end = time(11, 30)
# 下午交易时间: 13:00-15:00
afternoon_start = time(13, 0)
afternoon_end = time(15, 0)
# 判断是否在交易时间内
is_trading = ((morning_start <= current_time <= morning_end) or
(afternoon_start <= current_time <= afternoon_end))
logger.debug(f"当前时间 {current_time.strftime('%H:%M')} 是否为交易时间: {is_trading}")
return is_trading
def fetch_sse_stock_data(self) -> Dict[str, float]:
"""
爬取上海证券交易所股票数据
返回包含时间和数值的字典
"""
# 检查是否为交易时间
if not self.is_trading_time():
logger.info("当前为非交易时间,跳过股票数据爬取")
return {}
sse_url = "https://www.sse.com.cn/"
xpath = "//*[@id=\"hq_area\"]"
logger.info(f"开始爬取上海证券交易所数据: {sse_url}")
html = self._fetch_with_retry(sse_url)
if not html:
logger.warning("上海证券交易所网页获取失败")
return {}
try:
# 使用 lxml 解析
tree = etree.HTML(html)
# 尝试获取股票数值
elements = tree.xpath(xpath)
if not elements:
logger.warning("未找到股票数据元素尝试备用XPath")
# 尝试备用XPath
backup_xpaths = [
"//*[@id='hq_controller']//td[contains(@class, 'price')]//text()",
"//*[contains(@class, 'stock-price')]//text()",
"//*[contains(@class, 'price')]//text()"
]
for backup_xpath in backup_xpaths:
elements = tree.xpath(backup_xpath)
if elements:
logger.info(f"使用备用XPath获取到数据")
break
if elements:
# 获取文本内容并尝试转换为数值
text_content = etree.tostring(elements[0], method='text', encoding='unicode').strip()
logger.info(f"获取到股票数据文本: {text_content}")
# 提取数值(可能包含逗号、小数点等)
import re
# 匹配数字(包括小数点和逗号分隔符)
numbers = re.findall(r'[\d,]+(?:\.\d+)?', text_content)
if numbers:
# 去除逗号并转换为浮点数
value_str = numbers[0].replace(',', '')
try:
stock_value = float(value_str)
# 获取当前时间
from datetime import datetime
current_time = datetime.now().strftime("%H:%M")
logger.info(f"成功获取股票数据: 时间={current_time}, 值={stock_value}")
return {
'time': current_time,
'value': stock_value
}
except ValueError as e:
logger.error(f"数值转换失败: {value_str}, 错误: {e}")
else:
logger.warning(f"未找到有效数值: {text_content}")
else:
logger.warning("未找到股票数据元素")
except Exception as e:
logger.error(f"解析上海证券交易所数据失败: {e}")
return {}
def set_user_agent(self, user_agent: str):
"""更新User-Agent"""
self.session.headers.update({'User-Agent': user_agent})

32
test_api.py Normal file
View File

@@ -0,0 +1,32 @@
import os
from openai import OpenAI
# 设置 API key 为环境变量
os.environ["NVIDIA_API_KEY"] = "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"
client = OpenAI(
base_url="https://integrate.api.nvidia.com/v1",
api_key=os.environ["NVIDIA_API_KEY"]
)
print("正在测试 API 连接...")
print("发送消息: 天气")
print("-" * 50)
completion = client.chat.completions.create(
model="deepseek-ai/deepseek-r1",
messages=[{"role": "user", "content": "天气"}],
temperature=0.6,
top_p=0.7,
max_tokens=4096,
stream=False
)
reasoning = getattr(completion.choices[0].message, "reasoning_content", None)
if reasoning:
print("推理过程:")
print(reasoning)
print("-" * 50)
print("回答:")
print(completion.choices[0].message.content)

35
test_nvidia_api.py Normal file
View File

@@ -0,0 +1,35 @@
"""
测试 NVIDIA API 连接
"""
from openai import OpenAI
client = OpenAI(
base_url="https://integrate.api.nvidia.com/v1",
api_key="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"
)
try:
print("开始测试 API 调用...")
completion = client.chat.completions.create(
model="deepseek-ai/deepseek-r1",
messages=[{"role": "user", "content": "你好,请简单介绍一下你自己"}],
temperature=0.6,
top_p=0.7,
max_tokens=4096,
stream=False,
timeout=120
)
reasoning = getattr(completion.choices[0].message, "reasoning_content", None)
if reasoning:
print("推理过程:")
print(reasoning)
print("\n" + "="*50 + "\n")
print("回答内容:")
print(completion.choices[0].message.content)
print("\n" + "="*50)
print("API 调用成功!")
except Exception as e:
print(f"API 调用失败: {type(e).__name__}: {e}")

43
test_screenshot.py Normal file
View File

@@ -0,0 +1,43 @@
"""
测试截图功能
"""
import os
from screenshot_manager import ScreenshotManager
from loguru import logger
def test_screenshot():
"""测试截图功能"""
logger.info("开始测试截图功能")
# 创建截图管理器
screenshot_manager = ScreenshotManager()
# 测试截图
screenshot_path = screenshot_manager.capture_chart_screenshot()
if screenshot_path:
logger.info(f"✅ 截图成功: {screenshot_path}")
# 检查文件是否存在
if os.path.exists(screenshot_path):
file_size = os.path.getsize(screenshot_path)
logger.info(f"截图文件大小: {file_size} 字节")
# 获取最新截图
latest_screenshot = screenshot_manager.get_latest_screenshot()
logger.info(f"最新截图: {latest_screenshot}")
# 清理旧截图
screenshot_manager.cleanup_old_screenshots(keep_count=2)
logger.info("清理旧截图完成")
else:
logger.error("截图文件不存在")
else:
logger.error("截图失败")
logger.info("截图功能测试完成")
if __name__ == "__main__":
test_screenshot()

429
waveform_widget.py Normal file
View File

@@ -0,0 +1,429 @@
"""
波形图组件 - 用于绘制股票数据波形图
"""
import math
import os
from datetime import datetime, time
from PySide6.QtWidgets import QWidget
from PySide6.QtGui import QPainter, QPen, QColor, QBrush, QPixmap
from PySide6.QtCore import QPointF, QTimer
from loguru import logger
# 尝试导入截图管理器
try:
from screenshot_manager import ScreenshotManager
except ImportError:
ScreenshotManager = None
logger.warning("截图管理器导入失败,截图功能将不可用")
class WaveformWidget(QWidget):
"""波形图控件"""
def __init__(self, parent=None):
super().__init__(parent)
self.data_points = [] # 存储数据点 [(time, value)]
self.base_value = 0 # 基准值
self.screenshot_manager = None
self.latest_screenshot_path = ""
self.setMinimumSize(600, 300)
# 设置背景色
self.setStyleSheet("background-color: #1e1e1e;")
# 初始化截图管理器
if ScreenshotManager:
self.screenshot_manager = ScreenshotManager()
logger.info("截图管理器初始化完成")
logger.info("WaveformWidget 初始化完成")
def time_to_x_position(self, time_str: str, total_width: int) -> float:
"""
将时间转换为X轴位置
时间折算关系:
- 9:30 -> 最左侧 (x=0)
- 11:30 -> 中间 (x=total_width/2)
- 13:00 -> 中间 (x=total_width/2)
- 15:00 -> 最右侧 (x=total_width)
- 10:30 -> 左侧四分之一 (x=total_width/4)
- 14:00 -> 右侧四分之一 (x=total_width*3/4)
"""
try:
# 解析时间字符串
current_time = datetime.strptime(time_str, "%H:%M").time()
# 定义关键时间点
market_start = time(9, 30) # 9:30
market_mid1 = time(11, 30) # 11:30
market_mid2 = time(13, 0) # 13:00
market_end = time(15, 0) # 15:00
# 计算总交易时间(分钟)
morning_duration = (market_mid1.hour - market_start.hour) * 60 + \
(market_mid1.minute - market_start.minute)
afternoon_duration = (market_end.hour - market_mid2.hour) * 60 + \
(market_end.minute - market_mid2.minute)
total_duration = morning_duration + afternoon_duration
# 计算当前时间相对于开盘时间的分钟数
if current_time <= market_mid1:
# 上午交易时段
minutes_from_start = (current_time.hour - market_start.hour) * 60 + \
(current_time.minute - market_start.minute)
# 上午时段占一半宽度
x_ratio = minutes_from_start / morning_duration * 0.5
elif current_time >= market_mid2:
# 下午交易时段
minutes_from_start = (current_time.hour - market_mid2.hour) * 60 + \
(current_time.minute - market_mid2.minute)
# 下午时段占一半宽度,从中间开始
x_ratio = 0.5 + minutes_from_start / afternoon_duration * 0.5
else:
# 午休时间,统一放在中间
x_ratio = 0.5
return x_ratio * total_width
except Exception as e:
logger.error(f"时间转换错误: {time_str}, 错误: {e}")
return total_width / 2 # 默认返回中间位置
def is_trading_time(self, time_str: str) -> bool:
"""判断是否为交易时间"""
try:
current_time = datetime.strptime(time_str, "%H:%M").time()
# 上午交易时间: 9:30-11:30
morning_start = time(9, 30)
morning_end = time(11, 30)
# 下午交易时间: 13:00-15:00
afternoon_start = time(13, 0)
afternoon_end = time(15, 0)
# 判断是否在交易时间内
is_trading = ((morning_start <= current_time <= morning_end) or
(afternoon_start <= current_time <= afternoon_end))
logger.debug(f"时间 {time_str} 是否为交易时间: {is_trading}")
return is_trading
except Exception as e:
logger.error(f"时间判断错误: {time_str}, 错误: {e}")
return False
def add_data_point(self, time_str: str, value: float):
"""添加数据点"""
# 检查是否为交易时间
if not self.is_trading_time(time_str):
logger.info(f"非交易时间 {time_str},跳过数据点添加")
return
# 如果是第一个数据点,设置基准值
if not self.data_points:
self.base_value = value
logger.info(f"设置基准值: {self.base_value}")
self.data_points.append((time_str, value))
logger.info(f"添加数据点: 时间={time_str}, 值={value}")
# 限制数据点数量,避免内存过大
if len(self.data_points) > 100:
self.data_points = self.data_points[-100:]
# 触发重绘
self.update()
def clear_data(self):
"""清除所有数据"""
self.data_points.clear()
self.base_value = 0
self.update()
logger.info("波形图数据已清除")
def paintEvent(self, event):
"""绘制波形图"""
painter = QPainter(self)
painter.setRenderHint(QPainter.Antialiasing)
width = self.width()
height = self.height()
# 如果没有数据点,显示提示信息
if not self.data_points:
self._draw_no_data_message(painter, width, height)
return
# 检查最后一个数据点的时间是否为交易时间
last_time = self.data_points[-1][0] if self.data_points else ""
if last_time and not self.is_trading_time(last_time):
self._draw_non_trading_message(painter, width, height)
return
# 绘制网格和坐标轴
self._draw_grid(painter, width, height)
# 绘制波形线
self._draw_waveform(painter, width, height)
# 绘制数据点
self._draw_data_points(painter, width, height)
def _draw_grid(self, painter: QPainter, width: int, height: int):
"""绘制网格和坐标轴"""
# 设置网格颜色
grid_color = QColor(100, 100, 100)
painter.setPen(QPen(grid_color, 1, Qt.DashLine))
# 绘制水平网格线
for i in range(1, 5):
y = height * i // 5
painter.drawLine(0, y, width, y)
# 绘制垂直网格线(时间刻度)
time_points = ["9:30", "10:30", "11:30", "13:00", "14:00", "15:00"]
for time_str in time_points:
x = self.time_to_x_position(time_str, width)
painter.drawLine(x, 0, x, height)
# 绘制坐标轴
axis_color = QColor(200, 200, 200)
painter.setPen(QPen(axis_color, 2))
painter.drawLine(0, height // 2, width, height // 2) # X轴
painter.drawLine(0, 0, 0, height) # Y轴
# 绘制时间标签
painter.setPen(QPen(QColor(150, 150, 150), 1))
for time_str in time_points:
x = self.time_to_x_position(time_str, width)
painter.drawText(int(x) - 20, height - 5, time_str)
def _draw_waveform(self, painter: QPainter, width: int, height: int):
"""绘制波形线"""
if len(self.data_points) < 2:
return
# 设置波形线样式
waveform_color = QColor(0, 200, 255)
painter.setPen(QPen(waveform_color, 3))
points = []
# 计算Y轴范围基准值±100点
y_min = self.base_value - 100
y_max = self.base_value + 100
y_range = y_max - y_min
for time_str, value in self.data_points:
x = self.time_to_x_position(time_str, width)
# 计算Y坐标从底部到顶部
y_ratio = (value - y_min) / y_range
y = height - (y_ratio * height)
points.append(QPointF(x, y))
# 绘制折线
for i in range(len(points) - 1):
painter.drawLine(points[i], points[i + 1])
def _draw_data_points(self, painter: QPainter, width: int, height: int):
"""绘制数据点"""
point_color = QColor(255, 100, 100)
painter.setPen(QPen(point_color, 1))
painter.setBrush(QBrush(point_color))
# 计算Y轴范围
y_min = self.base_value - 100
y_max = self.base_value + 100
y_range = y_max - y_min
for time_str, value in self.data_points:
x = self.time_to_x_position(time_str, width)
y_ratio = (value - y_min) / y_range
y = height - (y_ratio * height)
# 绘制数据点圆圈
painter.drawEllipse(int(x) - 3, int(y) - 3, 6, 6)
# 显示数值标签
painter.setPen(QPen(QColor(200, 200, 200), 1))
painter.drawText(int(x) + 5, int(y) - 5, f"{value:.2f}")
def _draw_no_data_message(self, painter: QPainter, width: int, height: int):
"""绘制无数据提示信息"""
# 设置提示信息颜色
message_color = QColor(150, 150, 150)
painter.setPen(QPen(message_color, 2))
# 设置字体
font = painter.font()
font.setPointSize(14)
font.setBold(True)
painter.setFont(font)
# 绘制提示信息
message = "等待交易时间数据..."
text_rect = painter.fontMetrics().boundingRect(message)
x = (width - text_rect.width()) // 2
y = height // 2
painter.drawText(x, y, message)
# 绘制交易时间说明
font.setPointSize(10)
font.setBold(False)
painter.setFont(font)
info = "交易时间: 9:30-11:30, 13:00-15:00"
info_rect = painter.fontMetrics().boundingRect(info)
x_info = (width - info_rect.width()) // 2
y_info = y + 30
painter.drawText(x_info, y_info, info)
def _draw_non_trading_message(self, painter: QPainter, width: int, height: int):
"""绘制非交易时间提示信息或截图"""
# 如果有截图管理器,尝试显示截图
if self.screenshot_manager:
screenshot_path = self._get_or_capture_screenshot()
if screenshot_path:
self._draw_screenshot(painter, screenshot_path, width, height)
return
# 如果没有截图或截图失败,显示文本提示
self._draw_text_message(painter, width, height)
def _get_or_capture_screenshot(self) -> str:
"""获取或捕获最新的截图"""
try:
# 检查是否有最新的截图
latest_screenshot = self.screenshot_manager.get_latest_screenshot()
# 如果截图文件存在且较新5分钟内使用现有截图
if latest_screenshot and os.path.exists(latest_screenshot):
file_time = datetime.fromtimestamp(os.path.getmtime(latest_screenshot))
current_time = datetime.now()
# 如果截图在5分钟内直接使用
if (current_time - file_time).total_seconds() < 300: # 5分钟
return latest_screenshot
# 否则捕获新的截图
logger.info("开始捕获上海证券交易所网站截图")
new_screenshot = self.screenshot_manager.capture_chart_screenshot()
if new_screenshot:
# 清理旧截图
self.screenshot_manager.cleanup_old_screenshots()
return new_screenshot
except Exception as e:
logger.error(f"获取截图失败: {e}")
return ""
def _draw_screenshot(self, painter: QPainter, screenshot_path: str, width: int, height: int):
"""绘制截图"""
try:
# 加载截图
pixmap = QPixmap(screenshot_path)
if pixmap.isNull():
logger.warning(f"截图加载失败: {screenshot_path}")
self._draw_text_message(painter, width, height)
return
# 计算缩放比例,保持宽高比
pixmap_width = pixmap.width()
pixmap_height = pixmap.height()
# 计算缩放比例,使截图适应显示区域
scale_x = width / pixmap_width
scale_y = height / pixmap_height
scale = min(scale_x, scale_y) * 0.8 # 留出边距
# 计算显示尺寸
display_width = int(pixmap_width * scale)
display_height = int(pixmap_height * scale)
# 计算显示位置(居中)
x = (width - display_width) // 2
y = (height - display_height) // 2
# 绘制截图
scaled_pixmap = pixmap.scaled(display_width, display_height)
painter.drawPixmap(x, y, scaled_pixmap)
# 绘制标题
font = painter.font()
font.setPointSize(12)
font.setBold(True)
painter.setFont(font)
title_color = QColor(255, 255, 255)
painter.setPen(QPen(title_color, 2))
title = "上海证券交易所实时图表"
title_rect = painter.fontMetrics().boundingRect(title)
title_x = (width - title_rect.width()) // 2
title_y = y - 10
painter.drawText(title_x, title_y, title)
# 绘制更新时间
font.setPointSize(8)
font.setBold(False)
painter.setFont(font)
update_time = datetime.now().strftime("更新时间: %H:%M:%S")
time_rect = painter.fontMetrics().boundingRect(update_time)
time_x = (width - time_rect.width()) // 2
time_y = y + display_height + 20
painter.drawText(time_x, time_y, update_time)
except Exception as e:
logger.error(f"绘制截图失败: {e}")
self._draw_text_message(painter, width, height)
def _draw_text_message(self, painter: QPainter, width: int, height: int):
"""绘制文本提示信息"""
# 设置提示信息颜色
message_color = QColor(200, 100, 100)
painter.setPen(QPen(message_color, 2))
# 设置字体
font = painter.font()
font.setPointSize(14)
font.setBold(True)
painter.setFont(font)
# 绘制提示信息
message = "非交易时间"
text_rect = painter.fontMetrics().boundingRect(message)
x = (width - text_rect.width()) // 2
y = height // 2
painter.drawText(x, y, message)
# 绘制交易时间说明
font.setPointSize(10)
font.setBold(False)
painter.setFont(font)
info = "交易时间: 9:30-11:30, 13:00-15:00"
info_rect = painter.fontMetrics().boundingRect(info)
x_info = (width - info_rect.width()) // 2
y_info = y + 30
painter.drawText(x_info, y_info, info)
# 绘制当前时间
current_time = datetime.now().strftime("%H:%M")
time_info = f"当前时间: {current_time}"
time_rect = painter.fontMetrics().boundingRect(time_info)
x_time = (width - time_rect.width()) // 2
y_time = y_info + 25
painter.drawText(x_time, y_time, time_info)