feat: 新增股票数据波形图和截图功能
refactor: 重构数据库和LLM分析器逻辑 fix: 修复爬虫解析和UI显示问题 docs: 更新配置文件和注释 style: 优化代码格式和日志输出
This commit is contained in:
25
build.bat
Normal file
25
build.bat
Normal file
@@ -0,0 +1,25 @@
|
||||
@echo off
|
||||
echo ========================================
|
||||
echo 股吧人气指示器 - 打包工具
|
||||
echo ========================================
|
||||
|
||||
REM 检查 pyinstaller 是否安装
|
||||
pip show pyinstaller >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo 正在安装 pyinstaller...
|
||||
pip install pyinstaller
|
||||
)
|
||||
|
||||
echo 开始打包...
|
||||
pyinstaller build.spec
|
||||
|
||||
if exist "dist\guba-indicator.exe" (
|
||||
echo ========================================
|
||||
echo 打包成功!
|
||||
echo 可执行文件位置: dist\guba-indicator.exe
|
||||
echo ========================================
|
||||
) else (
|
||||
echo 打包失败,请检查错误信息
|
||||
)
|
||||
|
||||
pause
|
||||
58
build.spec
Normal file
58
build.spec
Normal file
@@ -0,0 +1,58 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
import sys
|
||||
import os
|
||||
|
||||
block_cipher = None
|
||||
|
||||
# 添加项目路径
|
||||
project_path = os.path.abspath('.')
|
||||
if project_path not in sys.path:
|
||||
sys.path.insert(0, project_path)
|
||||
|
||||
a = Analysis(
|
||||
['main.py'],
|
||||
pathex=['.', project_path],
|
||||
binaries=[],
|
||||
datas=[],
|
||||
hiddenimports=[],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=['tkinter', 'IPython', 'pytest', 'matplotlib', 'pandas', 'scipy', 'numpy'],
|
||||
win_no_prefer_redirects=False,
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
)
|
||||
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
[],
|
||||
exclude_binaries=True,
|
||||
name='guba-indicator',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
console=False,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
icon='indicator.ico',
|
||||
)
|
||||
|
||||
coll = COLLECT(
|
||||
exe,
|
||||
a.binaries,
|
||||
a.zipfiles,
|
||||
a.datas,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
name='guba-indicator',
|
||||
)
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
@@ -12,10 +13,10 @@ class ConfigManager:
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"llm_api": {
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"base_url": "https://integrate.api.nvidia.com/v1",
|
||||
"api_key": "",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"timeout": 30,
|
||||
"model": "deepseek-ai/deepseek-r1",
|
||||
"timeout": 120,
|
||||
"retry_times": 3
|
||||
},
|
||||
"spider": {
|
||||
@@ -46,19 +47,24 @@ class ConfigManager:
|
||||
def __init__(self, config_path: str = "config.json"):
|
||||
self.config_path = Path(config_path)
|
||||
self.config = self._load_config()
|
||||
logger.info(f"配置管理器初始化完成,配置文件: {config_path}")
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
if self.config_path.exists():
|
||||
try:
|
||||
logger.info(f"从文件加载配置: {self.config_path}")
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
loaded_config = json.load(f)
|
||||
# 合并默认配置,确保所有键都存在
|
||||
return self._merge_config(self.DEFAULT_CONFIG, loaded_config)
|
||||
merged = self._merge_config(self.DEFAULT_CONFIG, loaded_config)
|
||||
logger.info("配置加载成功")
|
||||
return merged
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
print(f"配置文件加载失败,使用默认配置: {e}")
|
||||
logger.error(f"配置文件加载失败,使用默认配置: {e}")
|
||||
return self.DEFAULT_CONFIG.copy()
|
||||
else:
|
||||
logger.warning(f"配置文件不存在: {self.config_path},使用默认配置")
|
||||
return self.DEFAULT_CONFIG.copy()
|
||||
|
||||
def _merge_config(self, default: Dict, loaded: Dict) -> Dict:
|
||||
@@ -74,11 +80,13 @@ class ConfigManager:
|
||||
def save_config(self) -> bool:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
logger.debug(f"保存配置到文件: {self.config_path}")
|
||||
with open(self.config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.config, f, ensure_ascii=False, indent=4)
|
||||
logger.info("配置保存成功")
|
||||
return True
|
||||
except IOError as e:
|
||||
print(f"配置保存失败: {e}")
|
||||
logger.error(f"配置保存失败: {e}")
|
||||
return False
|
||||
|
||||
def get(self, *keys: str, default: Any = None) -> Any:
|
||||
@@ -118,6 +126,7 @@ class ConfigManager:
|
||||
self.config["llm_api"]["timeout"] = timeout
|
||||
if retry_times:
|
||||
self.config["llm_api"]["retry_times"] = retry_times
|
||||
logger.info("LLM API配置已更新")
|
||||
self.save_config()
|
||||
|
||||
def update_spider(self, target_url: str = None, xpath: str = None,
|
||||
@@ -136,6 +145,7 @@ class ConfigManager:
|
||||
self.config["spider"]["retry_times"] = retry_times
|
||||
if retry_interval:
|
||||
self.config["spider"]["retry_interval"] = retry_interval
|
||||
logger.info("爬虫配置已更新")
|
||||
self.save_config()
|
||||
|
||||
def update_ui(self, opacity: float = None, is_on_top: bool = None,
|
||||
@@ -149,6 +159,7 @@ class ConfigManager:
|
||||
self.config["ui"]["thresholds"]["cold"] = cold_threshold
|
||||
if warm_threshold is not None:
|
||||
self.config["ui"]["thresholds"]["warm"] = warm_threshold
|
||||
logger.info("UI配置已更新")
|
||||
self.save_config()
|
||||
|
||||
@property
|
||||
|
||||
48
database.py
48
database.py
@@ -7,6 +7,7 @@ import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
@@ -15,9 +16,11 @@ class DatabaseManager:
|
||||
def __init__(self, db_path: str = "guba.db"):
|
||||
self.db_path = Path(db_path)
|
||||
self._init_db()
|
||||
logger.info(f"数据库管理器初始化完成,数据库路径: {db_path}")
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库表"""
|
||||
logger.debug("初始化数据库表")
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -59,6 +62,7 @@ class DatabaseManager:
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.debug("数据库表初始化完成")
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""获取数据库连接"""
|
||||
@@ -83,6 +87,7 @@ class DatabaseManager:
|
||||
content_hash = self.hash_content(content)
|
||||
|
||||
if self.is_comment_exists(content_hash):
|
||||
logger.debug(f"评论已存在,跳过: {content[:30]}...")
|
||||
return None # 已存在
|
||||
|
||||
conn = self._get_connection()
|
||||
@@ -94,6 +99,7 @@ class DatabaseManager:
|
||||
comment_id = cursor.lastrowid
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"添加新评论,ID: {comment_id}")
|
||||
return comment_id
|
||||
|
||||
def add_comments_batch(self, comments: List[Dict]) -> List[int]:
|
||||
@@ -108,16 +114,23 @@ class DatabaseManager:
|
||||
content_hash = self.hash_content(content)
|
||||
|
||||
if self.is_comment_exists(content_hash):
|
||||
logger.debug(f"评论已存在,跳过: {content[:30]}...")
|
||||
continue
|
||||
|
||||
try:
|
||||
cursor.execute('''
|
||||
INSERT INTO comments (content, content_hash, url, created_at)
|
||||
INSERT OR IGNORE INTO comments (content, content_hash, url, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
''', (content, content_hash, url, datetime.now().isoformat()))
|
||||
if cursor.rowcount > 0:
|
||||
new_ids.append(cursor.lastrowid)
|
||||
except Exception as e:
|
||||
logger.warning(f"插入评论失败(可能已存在): {e}")
|
||||
continue
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info(f"批量添加评论完成,新增 {len(new_ids)} 条")
|
||||
return new_ids
|
||||
|
||||
def get_unanalyzed_comments(self, limit: int = 50) -> List[Dict]:
|
||||
@@ -132,7 +145,9 @@ class DatabaseManager:
|
||||
''', (limit,))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [{'id': row[0], 'content': row[1], 'url': row[2]} for row in rows]
|
||||
result = [{'id': row[0], 'content': row[1], 'url': row[2]} for row in rows]
|
||||
logger.debug(f"获取到 {len(result)} 条未分析评论")
|
||||
return result
|
||||
|
||||
def mark_analyzed(self, comment_id: int, sentiment_score: float, analysis_text: str):
|
||||
"""标记评论已分析"""
|
||||
@@ -154,6 +169,7 @@ class DatabaseManager:
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.debug(f"标记评论 {comment_id} 已分析,分数: {sentiment_score}")
|
||||
|
||||
def get_latest_sentiment_score(self) -> Optional[float]:
|
||||
"""获取最新的情感分数"""
|
||||
@@ -167,20 +183,34 @@ class DatabaseManager:
|
||||
''')
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
return row[0] if row else None
|
||||
score = row[0] if row else None
|
||||
logger.debug(f"最新情感分数: {score}")
|
||||
return score
|
||||
|
||||
def get_all_scores(self) -> List[float]:
|
||||
"""获取所有已分析的分数"""
|
||||
def get_all_scores(self, limit: int = None) -> List[float]:
|
||||
"""获取已分析的分数,可指定数量限制"""
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
if limit:
|
||||
cursor.execute('''
|
||||
SELECT sentiment_score FROM comments
|
||||
WHERE analyzed = 1 AND sentiment_score IS NOT NULL
|
||||
ORDER BY analyzed_at DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT sentiment_score FROM comments
|
||||
WHERE analyzed = 1 AND sentiment_score IS NOT NULL
|
||||
ORDER BY analyzed_at DESC
|
||||
''')
|
||||
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [row[0] for row in rows if row[0] is not None]
|
||||
scores = [row[0] for row in rows if row[0] is not None]
|
||||
logger.debug(f"获取到 {len(scores)} 个分数")
|
||||
return scores
|
||||
|
||||
def get_comment_count(self) -> int:
|
||||
"""获取评论总数"""
|
||||
@@ -189,6 +219,7 @@ class DatabaseManager:
|
||||
cursor.execute('SELECT COUNT(*) FROM comments')
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
logger.debug(f"评论总数: {count}")
|
||||
return count
|
||||
|
||||
def get_analyzed_count(self) -> int:
|
||||
@@ -198,6 +229,7 @@ class DatabaseManager:
|
||||
cursor.execute('SELECT COUNT(*) FROM comments WHERE analyzed = 1')
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
logger.debug(f"已分析评论数: {count}")
|
||||
return count
|
||||
|
||||
def get_recent_comments(self, limit: int = 10) -> List[Dict]:
|
||||
@@ -212,8 +244,10 @@ class DatabaseManager:
|
||||
''', (limit,))
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
return [
|
||||
result = [
|
||||
{'id': row[0], 'content': row[1][:50] + '...' if len(row[1]) > 50 else row[1],
|
||||
'score': row[2], 'analyzed_at': row[3]}
|
||||
for row in rows
|
||||
]
|
||||
logger.debug(f"获取到 {len(result)} 条最近评论")
|
||||
return result
|
||||
|
||||
BIN
indicator.ico
Normal file
BIN
indicator.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.2 KiB |
101
llm_analyzer.py
101
llm_analyzer.py
@@ -1,11 +1,13 @@
|
||||
"""
|
||||
大模型分析模块 - 调用LLM API分析评论情感
|
||||
支持 OpenAI 兼容 API,包括 NVIDIA API
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
from typing import Dict, Optional, Tuple
|
||||
from openai import OpenAI, OpenAIError
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class LLMAnalyzer:
|
||||
@@ -31,26 +33,34 @@ class LLMAnalyzer:
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.base_url = config.get('base_url', 'https://api.openai.com/v1')
|
||||
self.base_url = config.get('base_url', '')
|
||||
self.api_key = config.get('api_key', '')
|
||||
self.model = config.get('model', 'gpt-3.5-turbo')
|
||||
self.timeout = config.get('timeout', 30)
|
||||
self.model = config.get('model', '')
|
||||
self.timeout = config.get('timeout', 120)
|
||||
self.retry_times = config.get('retry_times', 3)
|
||||
|
||||
self.client = None
|
||||
if self.api_key:
|
||||
self.last_result = None # 保存最后一次分析结果
|
||||
|
||||
logger.info(f"LLM分析器配置 - base_url: {self.base_url}, model: {self.model}, timeout: {self.timeout}s, retry: {self.retry_times}次")
|
||||
|
||||
if self.base_url and self.api_key:
|
||||
self._init_client()
|
||||
else:
|
||||
logger.warning("LLM API 未配置,base_url 或 api_key 为空")
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化OpenAI客户端"""
|
||||
try:
|
||||
logger.info(f"初始化LLM客户端: {self.base_url}")
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout
|
||||
)
|
||||
logger.info("LLM客户端初始化成功")
|
||||
except Exception as e:
|
||||
print(f"初始化LLM客户端失败: {e}")
|
||||
logger.error(f"初始化LLM客户端失败: {e}")
|
||||
|
||||
def update_config(self, config: Dict):
|
||||
"""更新配置"""
|
||||
@@ -61,7 +71,7 @@ class LLMAnalyzer:
|
||||
self.timeout = config.get('timeout', self.timeout)
|
||||
self.retry_times = config.get('retry_times', self.retry_times)
|
||||
|
||||
if self.api_key:
|
||||
if self.base_url and self.api_key:
|
||||
self._init_client()
|
||||
|
||||
def analyze(self, comment: str) -> Tuple[Optional[int], Optional[str]]:
|
||||
@@ -70,13 +80,21 @@ class LLMAnalyzer:
|
||||
返回 (score, label)
|
||||
"""
|
||||
if not self.client:
|
||||
logger.error("LLM客户端未初始化,请检查API配置")
|
||||
return None, "LLM未配置"
|
||||
|
||||
if not comment or not comment.strip():
|
||||
logger.warning("评论内容为空")
|
||||
return None, "评论为空"
|
||||
|
||||
logger.debug(f"开始分析评论: {comment[:50]}...")
|
||||
logger.debug(f"使用模型: {self.model}, 超时设置: {self.timeout}秒")
|
||||
|
||||
for attempt in range(self.retry_times):
|
||||
try:
|
||||
logger.info(f"API调用尝试 {attempt + 1}/{self.retry_times}")
|
||||
|
||||
logger.debug(f"发送请求到 {self.base_url}")
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
@@ -84,23 +102,45 @@ class LLMAnalyzer:
|
||||
{"role": "user", "content": f"请分析以下评论的情感倾向:\n\n{comment}"}
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=200
|
||||
max_tokens=500,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
result_text = response.choices[0].message.content.strip()
|
||||
# 处理 deepseek-r1 的特殊结构(可能有 reasoning_content)
|
||||
message = response.choices[0].message
|
||||
|
||||
# 获取推理过程(如果有)
|
||||
reasoning = getattr(message, 'reasoning_content', None)
|
||||
if reasoning:
|
||||
logger.debug(f"推理过程: {reasoning[:100]}...")
|
||||
|
||||
# 获取最终回答
|
||||
result_text = message.content.strip() if message.content else ""
|
||||
logger.debug(f"API返回原始内容: {result_text[:100]}...")
|
||||
|
||||
score, label = self._parse_response(result_text)
|
||||
|
||||
# 保存最后结果
|
||||
self.last_result = {
|
||||
'score': score,
|
||||
'label': label,
|
||||
'reasoning': reasoning,
|
||||
'raw_response': result_text
|
||||
}
|
||||
|
||||
if score is not None:
|
||||
logger.info(f"分析完成: {score}分 - {label}")
|
||||
return score, label
|
||||
|
||||
except OpenAIError as e:
|
||||
print(f"API调用失败 (尝试 {attempt + 1}/{self.retry_times}): {e}")
|
||||
if attempt < self.retry_times - 1:
|
||||
time.sleep(2 ** attempt) # 指数退避
|
||||
except Exception as e:
|
||||
print(f"分析过程出错: {e}")
|
||||
break
|
||||
logger.warning(f"API调用失败 (尝试 {attempt + 1}/{self.retry_times}): {type(e).__name__}: {e}")
|
||||
logger.debug(f"错误详情: {str(e)}")
|
||||
if attempt < self.retry_times - 1:
|
||||
wait_time = 2 ** attempt
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time) # 指数退避
|
||||
|
||||
logger.error(f"所有 {self.retry_times} 次重试均失败")
|
||||
return None, "分析失败"
|
||||
|
||||
def _parse_response(self, response: str) -> Tuple[Optional[int], Optional[str]]:
|
||||
@@ -113,40 +153,57 @@ class LLMAnalyzer:
|
||||
|
||||
# 验证分数范围
|
||||
score = max(0, min(100, int(score)))
|
||||
logger.debug(f"JSON解析成功: {score} - {label}")
|
||||
|
||||
return score, label
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 尝试从文本中提取
|
||||
pass
|
||||
logger.debug("JSON解析失败,尝试文本提取")
|
||||
|
||||
# 尝试从文本中提取数字
|
||||
# 尝试从文本中提取
|
||||
numbers = re.findall(r'\b(\d{1,3})\b', response)
|
||||
if numbers:
|
||||
score = int(numbers[0])
|
||||
score = max(0, min(100, score))
|
||||
|
||||
# 提取标签
|
||||
label_match = re.search(r'["']([^"']+)["']', response)
|
||||
label_match = re.search(r'"([^"]+)"', response)
|
||||
if label_match:
|
||||
label = label_match.group(1)
|
||||
else:
|
||||
label = response.split('\n')[0][:20] if response else '无法判断'
|
||||
|
||||
logger.debug(f"文本提取成功: {score} - {label}")
|
||||
return score, label
|
||||
|
||||
logger.warning("无法解析响应")
|
||||
return None, "解析失败"
|
||||
|
||||
def get_last_result(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取最后一次分析结果"""
|
||||
return self.last_result
|
||||
|
||||
def analyze_batch(self, comments: list, delay: float = 1.0) -> list:
|
||||
"""
|
||||
批量分析评论
|
||||
delay: 每次调用之间的延迟(秒)
|
||||
"""
|
||||
logger.info(f"开始批量分析 {len(comments)} 条评论,每次间隔 {delay} 秒")
|
||||
results = []
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for i, comment in enumerate(comments):
|
||||
print(f"分析评论 {i + 1}/{len(comments)}...")
|
||||
logger.info(f"正在分析第 {i + 1}/{len(comments)} 条评论")
|
||||
score, label = self.analyze(comment)
|
||||
|
||||
if score is not None:
|
||||
success_count += 1
|
||||
logger.debug(f"第 {i + 1} 条评论分析成功: {score}分 - {label}")
|
||||
else:
|
||||
fail_count += 1
|
||||
logger.warning(f"第 {i + 1} 条评论分析失败: {label}")
|
||||
|
||||
results.append({
|
||||
'content': comment,
|
||||
'score': score,
|
||||
@@ -154,8 +211,10 @@ class LLMAnalyzer:
|
||||
})
|
||||
|
||||
if delay > 0 and i < len(comments) - 1:
|
||||
logger.debug(f"等待 {delay} 秒后继续...")
|
||||
time.sleep(delay)
|
||||
|
||||
logger.info(f"批量分析完成,成功 {success_count} 条,失败 {fail_count} 条")
|
||||
return results
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
|
||||
179
main.py
179
main.py
@@ -2,11 +2,11 @@
|
||||
主程序入口 - 股吧人气指示器
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from PySide6.QtWidgets import QApplication
|
||||
from PySide6.QtCore import QTimer, Signal, QObject
|
||||
from PySide6.QtCore import QTimer, Signal, QObject, QThread
|
||||
from loguru import logger
|
||||
|
||||
from config_manager import ConfigManager
|
||||
from database import DatabaseManager
|
||||
@@ -20,8 +20,10 @@ class BackendWorker(QObject):
|
||||
|
||||
fetch_finished = Signal(list)
|
||||
analysis_finished = Signal(float)
|
||||
analysis_result = Signal(dict) # 传递分析结果详情
|
||||
error_occurred = Signal(str)
|
||||
status_update = Signal(str)
|
||||
stock_data_fetched = Signal(str, float) # 股票数据获取信号
|
||||
|
||||
def __init__(self, config_manager: ConfigManager, db_manager: DatabaseManager,
|
||||
spider: SpiderManager, analyzer: LLMAnalyzer):
|
||||
@@ -32,92 +34,136 @@ class BackendWorker(QObject):
|
||||
self.analyzer = analyzer
|
||||
self.running = False
|
||||
self.last_fetch_time = 0
|
||||
self.fetch_interval = 60 # 默认60秒
|
||||
self.fetch_interval = 15 # 默认15秒
|
||||
self.no_new_content_count = 0 # 无新内容计数
|
||||
self.is_running_cycle = False # 防止并发执行
|
||||
self.fetch_count = 0 # 爬取次数统计
|
||||
self.analysis_count = 0 # API分析次数统计
|
||||
logger.info("BackendWorker 初始化完成")
|
||||
|
||||
def start(self):
|
||||
"""启动后台任务"""
|
||||
self.running = True
|
||||
self._run_cycle()
|
||||
logger.info("后台任务已启动")
|
||||
# 启动时立即执行第一次任务
|
||||
QTimer.singleShot(1000, self._run_cycle)
|
||||
|
||||
def stop(self):
|
||||
"""停止后台任务"""
|
||||
self.running = False
|
||||
logger.info("后台任务已停止")
|
||||
|
||||
def _run_cycle(self):
|
||||
"""运行一个周期"""
|
||||
if self.is_running_cycle:
|
||||
logger.debug("上一个周期仍在执行,跳过本次")
|
||||
return
|
||||
|
||||
self.is_running_cycle = True
|
||||
|
||||
if not self.running:
|
||||
self.is_running_cycle = False
|
||||
return
|
||||
|
||||
try:
|
||||
# 1. 爬取评论
|
||||
logger.info("开始爬取评论")
|
||||
self.status_update.emit("正在爬取评论...")
|
||||
self.fetch_count += 1
|
||||
comments = self.spider.fetch()
|
||||
|
||||
if not comments:
|
||||
self.no_new_content_count += 1
|
||||
interval = self.fetch_interval * (1 + min(self.no_new_content_count * 0.5, 2))
|
||||
# 爬取失败,使用固定间隔重试,不计入无新内容计数
|
||||
interval = 15 # 固定15秒重试
|
||||
logger.warning(f"未获取到新评论,{int(interval)}秒后重试")
|
||||
self.status_update.emit(f"无新内容,{int(interval)}秒后重试...")
|
||||
self.is_running_cycle = False
|
||||
QTimer.singleShot(int(interval * 1000), self._run_cycle)
|
||||
return
|
||||
|
||||
# 2. 写入数据库
|
||||
logger.info(f"获取到 {len(comments)} 条评论")
|
||||
self.status_update.emit(f"获取到 {len(comments)} 条评论...")
|
||||
new_ids = self.db.add_comments_batch(comments)
|
||||
|
||||
if new_ids:
|
||||
self.no_new_content_count = 0
|
||||
logger.info(f"新增 {len(new_ids)} 条评论到数据库")
|
||||
self.status_update.emit(f"新增 {len(new_ids)} 条评论")
|
||||
|
||||
# 3. 获取未分析评论并分析
|
||||
unanalyzed = self.db.get_unanalyzed_comments(limit=10)
|
||||
logger.info(f"获取到 {len(unanalyzed)} 条未分析评论")
|
||||
|
||||
if unanalyzed:
|
||||
self.status_update.emit(f"开始分析 {len(unanalyzed)} 条评论...")
|
||||
self._analyze_comments(unanalyzed)
|
||||
else:
|
||||
self.no_new_content_count += 1
|
||||
logger.info("评论已存在,未新增")
|
||||
|
||||
# 4. 更新指示器
|
||||
self._update_indicator()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"运行错误: {str(e)}")
|
||||
self.error_occurred.emit(f"运行错误: {str(e)}")
|
||||
|
||||
# 安排下一次执行
|
||||
if self.running:
|
||||
interval = self.fetch_interval * (1 + min(self.no_new_content_count * 0.5, 2))
|
||||
interval = self.fetch_interval * (1 + min(self.no_new_content_count * 1.0, 4.0))
|
||||
logger.debug(f"下次执行将在 {int(interval)} 秒后")
|
||||
self.is_running_cycle = False
|
||||
QTimer.singleShot(int(interval * 1000), self._run_cycle)
|
||||
else:
|
||||
self.is_running_cycle = False
|
||||
|
||||
def _analyze_comments(self, comments):
|
||||
"""分析评论"""
|
||||
logger.info(f"开始分析 {len(comments)} 条评论")
|
||||
for i, comment in enumerate(comments):
|
||||
if not self.running:
|
||||
logger.warning("分析被中断")
|
||||
break
|
||||
|
||||
try:
|
||||
self.status_update.emit(f"分析 {i+1}/{len(comments)}...")
|
||||
logger.debug(f"分析第 {i+1} 条评论: {comment['content'][:50]}...")
|
||||
self.analysis_count += 1
|
||||
score, label = self.analyzer.analyze(comment['content'])
|
||||
|
||||
# 获取分析结果详情
|
||||
last_result = self.analyzer.get_last_result()
|
||||
|
||||
if score is not None:
|
||||
self.db.mark_analyzed(comment['id'], score, label)
|
||||
logger.info(f"评论 {comment['id']} 分析完成: {score}分 - {label}")
|
||||
# 更新状态显示为简洁格式
|
||||
self.status_update.emit(f"分析 {i+1}/{len(comments)}...返回{score}分")
|
||||
# 每条评论分析完成后立即更新指示器
|
||||
self._update_indicator()
|
||||
time.sleep(1.0) # 延迟,避免API限流
|
||||
else:
|
||||
self.db.mark_analyzed(comment['id'], 50, "无法判断")
|
||||
logger.warning(f"评论 {comment['id']} 无法判断")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析评论 {comment.get('id', 'unknown')} 失败: {str(e)}")
|
||||
self.error_occurred.emit(f"分析失败: {str(e)}")
|
||||
self.db.mark_analyzed(comment['id'], 50, "分析异常")
|
||||
|
||||
def _update_indicator(self):
|
||||
"""更新指示器显示"""
|
||||
scores = self.db.get_all_scores()
|
||||
# 获取最新的100条评论的分数
|
||||
scores = self.db.get_all_scores(limit=100)
|
||||
|
||||
if not scores:
|
||||
logger.debug("暂无分析分数")
|
||||
return
|
||||
|
||||
# 计算平均分
|
||||
avg_score = sum(scores) / len(scores)
|
||||
logger.info(f"当前平均分: {avg_score:.2f} (基于最新的 {len(scores)} 条评论)")
|
||||
|
||||
# 根据阈值确定标签
|
||||
thresholds = self.config.get('ui', 'thresholds', default={'cold': 30, 'warm': 70})
|
||||
@@ -131,68 +177,161 @@ class BackendWorker(QObject):
|
||||
else:
|
||||
label = "中性"
|
||||
|
||||
self.analysis_finished.emit(avg_score)
|
||||
logger.info(f"情感倾向: {label}")
|
||||
# 发送整数分数
|
||||
self.analysis_finished.emit(int(avg_score))
|
||||
|
||||
def fetch_stock_data(self):
|
||||
"""爬取股票数据并添加到波形图"""
|
||||
try:
|
||||
logger.info("开始爬取股票数据")
|
||||
stock_data = self.spider.fetch_sse_stock_data()
|
||||
|
||||
if stock_data and 'time' in stock_data and 'value' in stock_data:
|
||||
logger.info(f"成功获取股票数据: {stock_data}")
|
||||
# 发送股票数据到主窗口
|
||||
self.stock_data_fetched.emit(stock_data['time'], stock_data['value'])
|
||||
else:
|
||||
logger.warning("未能获取有效的股票数据")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"爬取股票数据失败: {str(e)}")
|
||||
|
||||
def manual_refresh(self):
|
||||
"""手动刷新"""
|
||||
logger.info("用户手动刷新")
|
||||
self.no_new_content_count = 0
|
||||
if not self.is_running_cycle:
|
||||
self._run_cycle()
|
||||
else:
|
||||
logger.info("上一个周期仍在执行,跳过手动刷新")
|
||||
|
||||
def update_fetch_interval(self, interval: int):
|
||||
"""更新爬取间隔"""
|
||||
logger.info(f"更新爬取间隔: {interval} 秒")
|
||||
self.fetch_interval = interval
|
||||
|
||||
|
||||
def setup_logging(log_path: str, level: str = "INFO"):
|
||||
"""配置日志"""
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, level.upper(), logging.INFO),
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(log_path, encoding='utf-8'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
logger.remove() # 移除默认的处理器
|
||||
|
||||
logger.add(
|
||||
log_path,
|
||||
rotation="10 MB",
|
||||
retention="7 days",
|
||||
level=level,
|
||||
encoding="utf-8",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}"
|
||||
)
|
||||
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
level=level,
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
logger.info("=== 股吧人气指示器启动 ===")
|
||||
|
||||
# 创建应用
|
||||
app = QApplication(sys.argv)
|
||||
app.setQuitOnLastWindowClosed(False) # 允许最小化到托盘
|
||||
|
||||
# 加载配置
|
||||
logger.info("加载配置文件...")
|
||||
config = ConfigManager("config.json")
|
||||
|
||||
# 配置日志
|
||||
log_config = config.logging_config
|
||||
setup_logging(log_config.get('path', 'guba.log'), log_config.get('level', 'INFO'))
|
||||
logger.info(f"日志配置完成: {log_config.get('path', 'guba.log')}, 级别: {log_config.get('level', 'INFO')}")
|
||||
|
||||
# 初始化组件
|
||||
logger.info("初始化数据库...")
|
||||
db = DatabaseManager(config.database_config.get('path', 'guba.db'))
|
||||
|
||||
logger.info("初始化爬虫...")
|
||||
spider = SpiderManager(config.spider_config)
|
||||
|
||||
logger.info("初始化LLM分析器...")
|
||||
analyzer = LLMAnalyzer(config.llm_api_config)
|
||||
|
||||
# 创建后台工作器
|
||||
worker = BackendWorker(config, db, spider, analyzer)
|
||||
worker.update_fetch_interval(config.spider_config.get('fetch_interval', 60))
|
||||
worker.update_fetch_interval(config.spider_config.get('fetch_interval', 15))
|
||||
|
||||
# 创建后台线程
|
||||
logger.info("创建后台线程...")
|
||||
worker_thread = QThread()
|
||||
worker.moveToThread(worker_thread)
|
||||
|
||||
# 线程启动时开始工作
|
||||
worker_thread.started.connect(worker.start)
|
||||
# 线程结束时停止工作
|
||||
worker_thread.finished.connect(worker.stop)
|
||||
|
||||
# 创建主窗口
|
||||
window = MainWindow(config)
|
||||
logger.info("创建主窗口...")
|
||||
window = MainWindow(config, spider)
|
||||
window.show()
|
||||
|
||||
# 连接信号
|
||||
worker.status_update.connect(window.update_status)
|
||||
worker.analysis_finished.connect(window.update_indicator)
|
||||
worker.error_occurred.connect(lambda msg: window.show_message("错误", msg))
|
||||
worker.stock_data_fetched.connect(window.add_waveform_data)
|
||||
|
||||
# 启动时从数据库初始化指示器显示
|
||||
worker._update_indicator()
|
||||
logger.info("初始化指示器显示完成")
|
||||
|
||||
# 设置按钮回调
|
||||
window.set_refresh_callback(worker.manual_refresh)
|
||||
window.set_config_callback(window.show_config)
|
||||
|
||||
# 启动后台任务
|
||||
worker.start()
|
||||
# 启动后台线程
|
||||
worker_thread.start()
|
||||
logger.info("后台线程已启动")
|
||||
|
||||
# 启动股票数据爬取定时器
|
||||
stock_timer = QTimer()
|
||||
stock_timer.timeout.connect(worker.fetch_stock_data)
|
||||
stock_timer.start(60000) # 每分钟爬取一次股票数据
|
||||
logger.info("股票数据爬取定时器已启动,间隔60秒")
|
||||
|
||||
# 确保应用退出时清理线程
|
||||
def cleanup():
|
||||
logger.info("清理资源,停止后台线程...")
|
||||
worker.stop()
|
||||
worker_thread.quit()
|
||||
worker_thread.wait()
|
||||
|
||||
# 显示统计信息
|
||||
logger.info("=== 程序运行统计 ===")
|
||||
logger.info(f"爬取网站次数: {worker.fetch_count}")
|
||||
logger.info(f"提交API分析次数: {worker.analysis_count}")
|
||||
logger.info("=== 统计结束 ===")
|
||||
|
||||
# 写入统计信息到文件
|
||||
try:
|
||||
stats_file = "statistics.txt"
|
||||
with open(stats_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== 程序运行统计 - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n")
|
||||
f.write(f"爬取网站次数: {worker.fetch_count}\n")
|
||||
f.write(f"提交API分析次数: {worker.analysis_count}\n")
|
||||
f.write("=== 统计结束 ===\n\n")
|
||||
logger.info(f"统计信息已写入到文件: {stats_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入统计文件失败: {str(e)}")
|
||||
|
||||
logger.info("后台线程已停止")
|
||||
|
||||
app.aboutToQuit.connect(cleanup)
|
||||
|
||||
logger.info("应用启动完成,进入主循环")
|
||||
# 运行应用
|
||||
sys.exit(app.exec())
|
||||
|
||||
|
||||
117
main_window.py
117
main_window.py
@@ -4,10 +4,13 @@ PySide6 GUI界面模块
|
||||
from PySide6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, QLabel,
|
||||
QPushButton, QSlider, QDialog, QFormLayout,
|
||||
QLineEdit, QSpinBox, QMessageBox, QSystemTrayIcon,
|
||||
QMenu, QTextEdit, QGroupBox, QDialogButtonBox)
|
||||
QMenu, QTextEdit, QGroupBox, QDialogButtonBox, QCheckBox)
|
||||
from PySide6.QtCore import Qt, QTimer, Signal, QPoint
|
||||
from PySide6.QtGui import QFont, QColor, QPainter, QBrush, QPen, QIcon, QAction
|
||||
from typing import Callable, Optional
|
||||
from loguru import logger
|
||||
|
||||
from waveform_widget import WaveformWidget
|
||||
|
||||
|
||||
class SentimentIndicator(QWidget):
|
||||
@@ -57,22 +60,38 @@ class SentimentIndicator(QWidget):
|
||||
|
||||
def _get_color(self, score: int) -> QColor:
|
||||
"""根据分数获取颜色"""
|
||||
if score < 30:
|
||||
# 冷色系 - 蓝色/青色
|
||||
ratio = score / 30
|
||||
return QColor(int(0 + 100 * ratio), int(150 + 50 * ratio), 255)
|
||||
elif score < 70:
|
||||
# 中性 - 灰色/绿色
|
||||
if score < 50:
|
||||
ratio = (score - 30) / 20
|
||||
return QColor(int(100 + 50 * ratio), int(200 + 20 * ratio), int(200 - 50 * ratio))
|
||||
# 45-55分区间:每1分一个颜色(从绿色到黄色渐变)
|
||||
if 45 <= score <= 55:
|
||||
# 在45-55分之间,每1分一个颜色
|
||||
ratio = (score - 45) / 10 # 0到1之间的比例
|
||||
# 从绿色(0, 200, 0)渐变到黄色(255, 255, 0)
|
||||
r = int(0 + 255 * ratio)
|
||||
g = 200 if ratio < 0.5 else int(200 + 55 * (ratio - 0.5) * 2)
|
||||
b = int(0 + 0 * ratio)
|
||||
return QColor(r, g, b)
|
||||
|
||||
# 45分以下:每5分一个颜色(从深蓝到绿色渐变)
|
||||
elif score < 45:
|
||||
# 将分数映射到0-8的区间(0-40分,每5分一个级别)
|
||||
level = score // 5
|
||||
# 从深蓝色(0, 100, 255)渐变到绿色(0, 200, 0)
|
||||
ratio = level / 8 # 0-8对应0-40分
|
||||
r = int(0 + 0 * ratio)
|
||||
g = int(100 + 100 * ratio)
|
||||
b = int(255 - 255 * ratio)
|
||||
return QColor(r, g, b)
|
||||
|
||||
# 55分以上:每5分一个颜色(从黄色到红色渐变)
|
||||
else:
|
||||
ratio = (score - 50) / 20
|
||||
return QColor(int(150 + 50 * ratio), int(220 - 20 * ratio), int(150 - 50 * ratio))
|
||||
else:
|
||||
# 暖色系 - 橙色/红色
|
||||
ratio = (score - 70) / 30
|
||||
return QColor(255, int(200 - 100 * ratio), int(50 + 50 * ratio))
|
||||
# 将分数映射到0-8的区间(60-100分,每5分一个级别)
|
||||
level = (score - 60) // 5
|
||||
level = max(0, min(8, level)) # 限制在0-8范围内
|
||||
# 从黄色(255, 255, 0)渐变到红色(255, 0, 0)
|
||||
ratio = level / 8
|
||||
r = 255
|
||||
g = int(255 - 255 * ratio)
|
||||
b = 0
|
||||
return QColor(r, g, b)
|
||||
|
||||
def get_description(self, score: int) -> str:
|
||||
"""获取描述文本"""
|
||||
@@ -125,7 +144,7 @@ class ConfigDialog(QDialog):
|
||||
self.user_agent_edit = QLineEdit(spider_config.get('user_agent', ''))
|
||||
self.interval_spin = QSpinBox()
|
||||
self.interval_spin.setRange(10, 3600)
|
||||
self.interval_spin.setValue(spider_config.get('fetch_interval', 60))
|
||||
self.interval_spin.setValue(spider_config.get('fetch_interval', 15))
|
||||
|
||||
layout.addRow("目标URL:", self.url_edit)
|
||||
layout.addRow("XPath:", self.xpath_edit)
|
||||
@@ -197,10 +216,14 @@ class ConfigDialog(QDialog):
|
||||
class MainWindow(QWidget):
|
||||
"""主窗口"""
|
||||
|
||||
def __init__(self, config_manager, parent=None):
|
||||
def __init__(self, config_manager, spider_manager=None, parent=None):
|
||||
super().__init__(parent)
|
||||
self.config_manager = config_manager
|
||||
self.setWindowTitle("股吧人气指示器")
|
||||
self.spider_manager = spider_manager
|
||||
|
||||
# 获取页面标题并设置窗口标题
|
||||
self._set_window_title()
|
||||
|
||||
self._init_ui()
|
||||
self._apply_config()
|
||||
|
||||
@@ -217,7 +240,7 @@ class MainWindow(QWidget):
|
||||
layout.setContentsMargins(10, 10, 10, 10)
|
||||
|
||||
# 标题
|
||||
self.title_label = QLabel("股吧人气")
|
||||
self.title_label = QLabel("上证指数sh000001")
|
||||
self.title_label.setAlignment(Qt.AlignCenter)
|
||||
title_font = QFont()
|
||||
title_font.setPointSize(14)
|
||||
@@ -236,26 +259,64 @@ class MainWindow(QWidget):
|
||||
status_font.setPointSize(10)
|
||||
self.status_label.setFont(status_font)
|
||||
|
||||
# 波形图组件
|
||||
self.waveform_widget = WaveformWidget()
|
||||
self.waveform_widget.setMinimumHeight(200)
|
||||
|
||||
# 按钮
|
||||
btn_layout = QHBoxLayout()
|
||||
self.refresh_btn = QPushButton("刷新")
|
||||
self.config_btn = QPushButton("配置")
|
||||
self.quit_btn = QPushButton("退出")
|
||||
self.quit_btn.clicked.connect(self.quit_app)
|
||||
btn_layout.addWidget(self.refresh_btn)
|
||||
btn_layout.addWidget(self.config_btn)
|
||||
btn_layout.addWidget(self.quit_btn)
|
||||
|
||||
# 添加到主布局
|
||||
layout.addWidget(self.title_label)
|
||||
layout.addWidget(self.indicator)
|
||||
layout.addWidget(self.score_label)
|
||||
layout.addWidget(self.status_label)
|
||||
layout.addWidget(self.waveform_widget)
|
||||
layout.addLayout(btn_layout)
|
||||
|
||||
# 设置窗口标志(无边框、可拖拽)
|
||||
self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint)
|
||||
self.setAttribute(Qt.WA_TranslucentBackground)
|
||||
|
||||
def _set_window_title(self):
|
||||
"""设置窗口标题"""
|
||||
logger.debug("设置窗口标题")
|
||||
|
||||
# 尝试从爬虫获取页面标题
|
||||
if hasattr(self, 'spider_manager') and self.spider_manager:
|
||||
try:
|
||||
page_title = self.spider_manager.get_page_title()
|
||||
if page_title:
|
||||
# 从页面标题中提取股票名称
|
||||
import re
|
||||
match = re.search(r'(上证指数sh\d+)', page_title)
|
||||
if match:
|
||||
stock_name = match.group(1)
|
||||
window_title = f"冷暖值 - {stock_name}"
|
||||
self.setWindowTitle(window_title)
|
||||
|
||||
# 同时更新主标题标签(如果已初始化)
|
||||
if hasattr(self, 'title_label'):
|
||||
self.title_label.setText("上证指数sh000001")
|
||||
logger.info(f"设置窗口标题: {window_title}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"获取页面标题失败: {e}")
|
||||
|
||||
# 如果获取失败,使用默认标题
|
||||
self.setWindowTitle("冷暖值 - 股吧人气")
|
||||
logger.info("使用默认窗口标题")
|
||||
|
||||
def _init_tray_icon(self):
|
||||
"""初始化系统托盘"""
|
||||
logger.debug("初始化系统托盘")
|
||||
self.tray_icon = QSystemTrayIcon(self)
|
||||
self.tray_icon.setToolTip("股吧人气指示器")
|
||||
|
||||
@@ -275,9 +336,11 @@ class MainWindow(QWidget):
|
||||
|
||||
self.tray_icon.setContextMenu(tray_menu)
|
||||
self.tray_icon.show()
|
||||
logger.info("系统托盘初始化完成")
|
||||
|
||||
def quit_app(self):
|
||||
"""退出应用"""
|
||||
logger.info("退出应用")
|
||||
self.close()
|
||||
import sys
|
||||
sys.exit(0)
|
||||
@@ -322,7 +385,7 @@ class MainWindow(QWidget):
|
||||
|
||||
context_menu.addAction(config_action)
|
||||
context_menu.addAction(quit_action)
|
||||
context_menu.exec(event.globalPosition().toPoint())
|
||||
context_menu.exec(event.globalPos())
|
||||
|
||||
def show_config(self):
|
||||
"""显示配置对话框"""
|
||||
@@ -336,23 +399,33 @@ class MainWindow(QWidget):
|
||||
label = self.indicator.get_description(score)
|
||||
self.indicator.set_value(score, label)
|
||||
self.score_label.setText(f"{score} - {label}")
|
||||
logger.debug(f"更新指示灯: {score}分 - {label}")
|
||||
|
||||
def update_status(self, text: str):
|
||||
"""更新状态"""
|
||||
self.status_label.setText(text)
|
||||
logger.debug(f"更新状态: {text}")
|
||||
|
||||
def set_refresh_callback(self, callback: Callable):
|
||||
"""设置刷新按钮回调"""
|
||||
self.refresh_btn.clicked.connect(callback)
|
||||
logger.debug("设置刷新按钮回调")
|
||||
|
||||
def set_config_callback(self, callback: Callable):
|
||||
"""设置配置按钮回调"""
|
||||
self.config_btn.clicked.connect(callback)
|
||||
logger.debug("设置配置按钮回调")
|
||||
|
||||
def show_message(self, title: str, message: str, icon=QMessageBox.Information):
|
||||
def show_message(self, title: str, message: str):
|
||||
"""显示消息"""
|
||||
logger.info(f"显示消息: {title} - {message}")
|
||||
QMessageBox.information(self, title, message)
|
||||
|
||||
def add_waveform_data(self, time_str: str, value: float):
|
||||
"""添加波形图数据点"""
|
||||
self.waveform_widget.add_data_point(time_str, value)
|
||||
logger.info(f"添加波形图数据点: 时间={time_str}, 值={value}")
|
||||
|
||||
|
||||
class QCheckBox(QPushButton):
|
||||
"""自定义复选框"""
|
||||
|
||||
@@ -3,3 +3,4 @@ requests>=2.31.0
|
||||
beautifulsoup4>=4.12.0
|
||||
lxml>=4.9.0
|
||||
openai>=1.0.0
|
||||
playwright>=1.40.0
|
||||
|
||||
185
screenshot_manager.py
Normal file
185
screenshot_manager.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
截图管理器 - 用于在非交易时间截取上海证券交易所网站图表
|
||||
"""
|
||||
import os
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeoutError
|
||||
except ImportError:
|
||||
logger.warning("playwright未安装,截图功能将不可用")
|
||||
|
||||
|
||||
class ScreenshotManager:
|
||||
"""截图管理器"""
|
||||
|
||||
def __init__(self, screenshot_dir: str = "screenshots"):
|
||||
"""初始化截图管理器"""
|
||||
self.screenshot_dir = screenshot_dir
|
||||
self.target_url = "https://www.sse.com.cn/"
|
||||
self.chart_xpath_pattern = "//*[@id=\"hq_area\"]"
|
||||
|
||||
# 创建截图目录
|
||||
os.makedirs(self.screenshot_dir, exist_ok=True)
|
||||
|
||||
logger.info(f"截图管理器初始化完成,截图目录: {self.screenshot_dir}")
|
||||
|
||||
def capture_chart_screenshot(self) -> str:
|
||||
"""
|
||||
截取上海证券交易所网站图表
|
||||
返回截图文件路径,失败时返回空字符串
|
||||
"""
|
||||
try:
|
||||
# 检查playwright是否可用
|
||||
if 'sync_playwright' not in globals():
|
||||
logger.error("playwright未安装,无法使用截图功能")
|
||||
return ""
|
||||
|
||||
with sync_playwright() as p:
|
||||
# 启动浏览器(无头模式,后台运行)
|
||||
browser = p.chromium.launch(headless=True)
|
||||
page = browser.new_page()
|
||||
|
||||
# 设置页面超时
|
||||
page.set_default_timeout(60000) # 60秒超时
|
||||
|
||||
# 访问目标网页
|
||||
logger.info(f"访问上海证券交易所网站: {self.target_url}")
|
||||
page.goto(self.target_url, wait_until="domcontentloaded")
|
||||
|
||||
# 等待页面加载完成
|
||||
page.wait_for_load_state("networkidle")
|
||||
|
||||
# 等待页面完全加载
|
||||
page.wait_for_timeout(5000) # 额外等待5秒
|
||||
|
||||
# 等待图表元素出现
|
||||
logger.info("等待图表元素加载...")
|
||||
|
||||
# 使用动态XPath模式查找图表元素
|
||||
chart_element = None
|
||||
selectors = [
|
||||
self.chart_xpath_pattern, # 主要选择器:highcharts-xxxxxxx-0格式
|
||||
"//*[contains(@class, 'highcharts')]",
|
||||
"//*[contains(@id, 'highcharts')]",
|
||||
"//svg",
|
||||
"//canvas",
|
||||
"//div[contains(@class, 'chart')]",
|
||||
"//div[contains(@class, 'graph')]"
|
||||
]
|
||||
|
||||
for selector in selectors:
|
||||
try:
|
||||
# 等待选择器出现
|
||||
page.wait_for_selector(selector, timeout=10000)
|
||||
elements = page.query_selector_all(selector)
|
||||
if elements:
|
||||
# 选择第一个可见的元素
|
||||
for element in elements:
|
||||
if element.is_visible():
|
||||
chart_element = element
|
||||
logger.info(f"找到图表元素: {selector}")
|
||||
break
|
||||
if chart_element:
|
||||
break
|
||||
except PlaywrightTimeoutError:
|
||||
logger.debug(f"选择器超时: {selector}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"选择器错误 {selector}: {e}")
|
||||
continue
|
||||
|
||||
if not chart_element:
|
||||
logger.warning("未找到任何图表元素,尝试截取整个页面")
|
||||
# 如果找不到图表元素,截取整个页面
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
screenshot_path = os.path.join(self.screenshot_dir, f"sse_page_{timestamp}.png")
|
||||
page.screenshot(path=screenshot_path)
|
||||
logger.info(f"截取整个页面: {screenshot_path}")
|
||||
browser.close()
|
||||
return screenshot_path
|
||||
|
||||
# 检查元素是否可见
|
||||
if not chart_element.is_visible():
|
||||
logger.warning("图表元素不可见,尝试滚动到元素位置")
|
||||
chart_element.scroll_into_view_if_needed()
|
||||
|
||||
# 生成截图文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
screenshot_path = os.path.join(self.screenshot_dir, f"sse_chart_{timestamp}.png")
|
||||
|
||||
# 截取图表元素
|
||||
logger.info("开始截取图表元素")
|
||||
|
||||
# 直接使用元素进行截图
|
||||
chart_element.screenshot(path=screenshot_path)
|
||||
|
||||
logger.info(f"✅ 图表截图完成,保存至: {screenshot_path}")
|
||||
|
||||
# 关闭浏览器
|
||||
browser.close()
|
||||
|
||||
return screenshot_path
|
||||
|
||||
except PlaywrightTimeoutError as e:
|
||||
logger.error(f"页面加载超时: {e}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"截图过程中发生错误: {e}")
|
||||
return ""
|
||||
|
||||
def get_latest_screenshot(self) -> str:
|
||||
"""获取最新的截图文件路径"""
|
||||
try:
|
||||
if not os.path.exists(self.screenshot_dir):
|
||||
return ""
|
||||
|
||||
# 获取所有截图文件
|
||||
screenshot_files = []
|
||||
for file in os.listdir(self.screenshot_dir):
|
||||
if file.startswith("sse_chart_") and file.endswith(".png"):
|
||||
file_path = os.path.join(self.screenshot_dir, file)
|
||||
screenshot_files.append((file_path, os.path.getmtime(file_path)))
|
||||
|
||||
if not screenshot_files:
|
||||
return ""
|
||||
|
||||
# 按修改时间排序,获取最新的文件
|
||||
screenshot_files.sort(key=lambda x: x[1], reverse=True)
|
||||
return screenshot_files[0][0]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最新截图失败: {e}")
|
||||
return ""
|
||||
|
||||
def cleanup_old_screenshots(self, keep_count: int = 10):
|
||||
"""清理旧的截图文件,只保留最新的几个"""
|
||||
try:
|
||||
if not os.path.exists(self.screenshot_dir):
|
||||
return
|
||||
|
||||
# 获取所有截图文件
|
||||
screenshot_files = []
|
||||
for file in os.listdir(self.screenshot_dir):
|
||||
if file.startswith("sse_chart_") and file.endswith(".png"):
|
||||
file_path = os.path.join(self.screenshot_dir, file)
|
||||
screenshot_files.append((file_path, os.path.getmtime(file_path)))
|
||||
|
||||
if len(screenshot_files) <= keep_count:
|
||||
return
|
||||
|
||||
# 按修改时间排序,删除旧的文件
|
||||
screenshot_files.sort(key=lambda x: x[1])
|
||||
files_to_delete = screenshot_files[:-keep_count]
|
||||
|
||||
for file_path, _ in files_to_delete:
|
||||
try:
|
||||
os.remove(file_path)
|
||||
logger.info(f"删除旧截图: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除截图文件失败 {file_path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理旧截图失败: {e}")
|
||||
BIN
screenshots/sse_page_20260109_173924.png
Normal file
BIN
screenshots/sse_page_20260109_173924.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 531 KiB |
BIN
screenshots/sse_page_20260109_174348.png
Normal file
BIN
screenshots/sse_page_20260109_174348.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 531 KiB |
153
spider.py
153
spider.py
@@ -8,6 +8,7 @@ from typing import List, Dict, Optional
|
||||
from urllib.parse import urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
import random
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SpiderManager:
|
||||
@@ -21,6 +22,7 @@ class SpiderManager:
|
||||
})
|
||||
self.retry_times = config.get('retry_times', 3)
|
||||
self.retry_interval = config.get('retry_interval', 5)
|
||||
logger.info(f"爬虫管理器初始化完成,目标URL: {config.get('target_url', '')}")
|
||||
|
||||
def fetch(self, url: str = None, xpath: str = None) -> List[Dict]:
|
||||
"""
|
||||
@@ -31,13 +33,18 @@ class SpiderManager:
|
||||
target_xpath = xpath or self.config.get('xpath', '')
|
||||
|
||||
if not target_url:
|
||||
logger.warning("未设置目标URL")
|
||||
return []
|
||||
|
||||
logger.info(f"开始抓取: {target_url}")
|
||||
html = self._fetch_with_retry(target_url)
|
||||
if not html:
|
||||
logger.warning("网页获取失败")
|
||||
return []
|
||||
|
||||
return self._parse_comments(html, target_xpath, target_url)
|
||||
comments = self._parse_comments(html, target_xpath, target_url)
|
||||
logger.info(f"解析完成,获取到 {len(comments)} 条评论")
|
||||
return comments
|
||||
|
||||
def _fetch_with_retry(self, url: str, max_retries: int = None) -> Optional[str]:
|
||||
"""带重试的网页获取"""
|
||||
@@ -45,18 +52,49 @@ class SpiderManager:
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.debug(f"尝试 {attempt + 1}/{max_retries} 获取网页")
|
||||
response = self.session.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
response.encoding = response.apparent_encoding
|
||||
logger.debug(f"网页获取成功,大小: {len(response.text)} 字节")
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
print(f"请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
|
||||
logger.warning(f"请求失败 (尝试 {attempt + 1}/{max_retries}): {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(self.retry_interval + random.uniform(0, 2))
|
||||
else:
|
||||
logger.error(f"所有重试均失败: {url}")
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_page_title(self, url: str = None) -> str:
|
||||
"""获取页面标题"""
|
||||
target_url = url or self.config.get('target_url', '')
|
||||
if not target_url:
|
||||
logger.warning("未设置目标URL")
|
||||
return ""
|
||||
|
||||
logger.info(f"获取页面标题: {target_url}")
|
||||
html = self._fetch_with_retry(target_url)
|
||||
if not html:
|
||||
logger.warning("网页获取失败")
|
||||
return ""
|
||||
|
||||
try:
|
||||
# 使用 lxml 解析页面标题
|
||||
tree = etree.HTML(html)
|
||||
title_elements = tree.xpath('//title/text()')
|
||||
if title_elements:
|
||||
title = title_elements[0].strip()
|
||||
logger.info(f"获取到页面标题: {title}")
|
||||
return title
|
||||
else:
|
||||
logger.warning("未找到页面标题")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"解析页面标题失败: {e}")
|
||||
return ""
|
||||
|
||||
def _parse_comments(self, html: str, xpath: str, base_url: str) -> List[Dict]:
|
||||
"""解析评论"""
|
||||
comments = []
|
||||
@@ -65,10 +103,11 @@ class SpiderManager:
|
||||
# 使用 lxml 解析
|
||||
tree = etree.HTML(html)
|
||||
elements = tree.xpath(xpath)
|
||||
logger.debug(f"XPath 匹配到 {len(elements)} 个元素")
|
||||
|
||||
for elem in elements:
|
||||
try:
|
||||
text = elem.text_content().strip()
|
||||
text = etree.tostring(elem, method='text', encoding='unicode').strip()
|
||||
if text:
|
||||
# 获取链接的 href(如果存在)
|
||||
href = elem.get('href')
|
||||
@@ -79,11 +118,11 @@ class SpiderManager:
|
||||
'url': full_url
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"解析元素失败: {e}")
|
||||
logger.error(f"解析元素失败: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"XPath解析失败: {e}")
|
||||
logger.error(f"XPath解析失败: {e}")
|
||||
# 备选解析方法
|
||||
comments = self._fallback_parse(html, base_url)
|
||||
|
||||
@@ -93,6 +132,7 @@ class SpiderManager:
|
||||
"""备选解析方法 - 使用 BeautifulSoup"""
|
||||
comments = []
|
||||
try:
|
||||
logger.debug("使用备选解析方法")
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
# 尝试查找常见的评论元素
|
||||
@@ -106,11 +146,112 @@ class SpiderManager:
|
||||
'content': text,
|
||||
'url': base_url
|
||||
})
|
||||
logger.debug(f"备选解析获取到 {len(comments)} 条评论")
|
||||
|
||||
except Exception as e:
|
||||
print(f"备选解析失败: {e}")
|
||||
logger.error(f"备选解析失败: {e}")
|
||||
|
||||
return comments
|
||||
|
||||
def is_trading_time(self) -> bool:
|
||||
"""判断当前是否为交易时间"""
|
||||
from datetime import datetime, time
|
||||
|
||||
current_time = datetime.now().time()
|
||||
|
||||
# 上午交易时间: 9:30-11:30
|
||||
morning_start = time(9, 30)
|
||||
morning_end = time(11, 30)
|
||||
|
||||
# 下午交易时间: 13:00-15:00
|
||||
afternoon_start = time(13, 0)
|
||||
afternoon_end = time(15, 0)
|
||||
|
||||
# 判断是否在交易时间内
|
||||
is_trading = ((morning_start <= current_time <= morning_end) or
|
||||
(afternoon_start <= current_time <= afternoon_end))
|
||||
|
||||
logger.debug(f"当前时间 {current_time.strftime('%H:%M')} 是否为交易时间: {is_trading}")
|
||||
return is_trading
|
||||
|
||||
def fetch_sse_stock_data(self) -> Dict[str, float]:
|
||||
"""
|
||||
爬取上海证券交易所股票数据
|
||||
返回包含时间和数值的字典
|
||||
"""
|
||||
# 检查是否为交易时间
|
||||
if not self.is_trading_time():
|
||||
logger.info("当前为非交易时间,跳过股票数据爬取")
|
||||
return {}
|
||||
|
||||
sse_url = "https://www.sse.com.cn/"
|
||||
xpath = "//*[@id=\"hq_area\"]"
|
||||
|
||||
logger.info(f"开始爬取上海证券交易所数据: {sse_url}")
|
||||
|
||||
html = self._fetch_with_retry(sse_url)
|
||||
if not html:
|
||||
logger.warning("上海证券交易所网页获取失败")
|
||||
return {}
|
||||
|
||||
try:
|
||||
# 使用 lxml 解析
|
||||
tree = etree.HTML(html)
|
||||
|
||||
# 尝试获取股票数值
|
||||
elements = tree.xpath(xpath)
|
||||
if not elements:
|
||||
logger.warning("未找到股票数据元素,尝试备用XPath")
|
||||
# 尝试备用XPath
|
||||
backup_xpaths = [
|
||||
"//*[@id='hq_controller']//td[contains(@class, 'price')]//text()",
|
||||
"//*[contains(@class, 'stock-price')]//text()",
|
||||
"//*[contains(@class, 'price')]//text()"
|
||||
]
|
||||
|
||||
for backup_xpath in backup_xpaths:
|
||||
elements = tree.xpath(backup_xpath)
|
||||
if elements:
|
||||
logger.info(f"使用备用XPath获取到数据")
|
||||
break
|
||||
|
||||
if elements:
|
||||
# 获取文本内容并尝试转换为数值
|
||||
text_content = etree.tostring(elements[0], method='text', encoding='unicode').strip()
|
||||
logger.info(f"获取到股票数据文本: {text_content}")
|
||||
|
||||
# 提取数值(可能包含逗号、小数点等)
|
||||
import re
|
||||
# 匹配数字(包括小数点和逗号分隔符)
|
||||
numbers = re.findall(r'[\d,]+(?:\.\d+)?', text_content)
|
||||
if numbers:
|
||||
# 去除逗号并转换为浮点数
|
||||
value_str = numbers[0].replace(',', '')
|
||||
try:
|
||||
stock_value = float(value_str)
|
||||
|
||||
# 获取当前时间
|
||||
from datetime import datetime
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
|
||||
logger.info(f"成功获取股票数据: 时间={current_time}, 值={stock_value}")
|
||||
|
||||
return {
|
||||
'time': current_time,
|
||||
'value': stock_value
|
||||
}
|
||||
except ValueError as e:
|
||||
logger.error(f"数值转换失败: {value_str}, 错误: {e}")
|
||||
else:
|
||||
logger.warning(f"未找到有效数值: {text_content}")
|
||||
else:
|
||||
logger.warning("未找到股票数据元素")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析上海证券交易所数据失败: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
def set_user_agent(self, user_agent: str):
|
||||
"""更新User-Agent"""
|
||||
self.session.headers.update({'User-Agent': user_agent})
|
||||
|
||||
32
test_api.py
Normal file
32
test_api.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import os
|
||||
from openai import OpenAI
|
||||
|
||||
# 设置 API key 为环境变量
|
||||
os.environ["NVIDIA_API_KEY"] = "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"
|
||||
|
||||
client = OpenAI(
|
||||
base_url="https://integrate.api.nvidia.com/v1",
|
||||
api_key=os.environ["NVIDIA_API_KEY"]
|
||||
)
|
||||
|
||||
print("正在测试 API 连接...")
|
||||
print("发送消息: 天气")
|
||||
print("-" * 50)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="deepseek-ai/deepseek-r1",
|
||||
messages=[{"role": "user", "content": "天气"}],
|
||||
temperature=0.6,
|
||||
top_p=0.7,
|
||||
max_tokens=4096,
|
||||
stream=False
|
||||
)
|
||||
|
||||
reasoning = getattr(completion.choices[0].message, "reasoning_content", None)
|
||||
if reasoning:
|
||||
print("推理过程:")
|
||||
print(reasoning)
|
||||
print("-" * 50)
|
||||
|
||||
print("回答:")
|
||||
print(completion.choices[0].message.content)
|
||||
35
test_nvidia_api.py
Normal file
35
test_nvidia_api.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
测试 NVIDIA API 连接
|
||||
"""
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="https://integrate.api.nvidia.com/v1",
|
||||
api_key="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"
|
||||
)
|
||||
|
||||
try:
|
||||
print("开始测试 API 调用...")
|
||||
completion = client.chat.completions.create(
|
||||
model="deepseek-ai/deepseek-r1",
|
||||
messages=[{"role": "user", "content": "你好,请简单介绍一下你自己"}],
|
||||
temperature=0.6,
|
||||
top_p=0.7,
|
||||
max_tokens=4096,
|
||||
stream=False,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
reasoning = getattr(completion.choices[0].message, "reasoning_content", None)
|
||||
if reasoning:
|
||||
print("推理过程:")
|
||||
print(reasoning)
|
||||
print("\n" + "="*50 + "\n")
|
||||
|
||||
print("回答内容:")
|
||||
print(completion.choices[0].message.content)
|
||||
print("\n" + "="*50)
|
||||
print("API 调用成功!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"API 调用失败: {type(e).__name__}: {e}")
|
||||
43
test_screenshot.py
Normal file
43
test_screenshot.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
测试截图功能
|
||||
"""
|
||||
import os
|
||||
from screenshot_manager import ScreenshotManager
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def test_screenshot():
|
||||
"""测试截图功能"""
|
||||
logger.info("开始测试截图功能")
|
||||
|
||||
# 创建截图管理器
|
||||
screenshot_manager = ScreenshotManager()
|
||||
|
||||
# 测试截图
|
||||
screenshot_path = screenshot_manager.capture_chart_screenshot()
|
||||
|
||||
if screenshot_path:
|
||||
logger.info(f"✅ 截图成功: {screenshot_path}")
|
||||
|
||||
# 检查文件是否存在
|
||||
if os.path.exists(screenshot_path):
|
||||
file_size = os.path.getsize(screenshot_path)
|
||||
logger.info(f"截图文件大小: {file_size} 字节")
|
||||
|
||||
# 获取最新截图
|
||||
latest_screenshot = screenshot_manager.get_latest_screenshot()
|
||||
logger.info(f"最新截图: {latest_screenshot}")
|
||||
|
||||
# 清理旧截图
|
||||
screenshot_manager.cleanup_old_screenshots(keep_count=2)
|
||||
logger.info("清理旧截图完成")
|
||||
else:
|
||||
logger.error("截图文件不存在")
|
||||
else:
|
||||
logger.error("截图失败")
|
||||
|
||||
logger.info("截图功能测试完成")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_screenshot()
|
||||
429
waveform_widget.py
Normal file
429
waveform_widget.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
波形图组件 - 用于绘制股票数据波形图
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
from datetime import datetime, time
|
||||
from PySide6.QtWidgets import QWidget
|
||||
from PySide6.QtGui import QPainter, QPen, QColor, QBrush, QPixmap
|
||||
from PySide6.QtCore import QPointF, QTimer
|
||||
from loguru import logger
|
||||
|
||||
# 尝试导入截图管理器
|
||||
try:
|
||||
from screenshot_manager import ScreenshotManager
|
||||
except ImportError:
|
||||
ScreenshotManager = None
|
||||
logger.warning("截图管理器导入失败,截图功能将不可用")
|
||||
|
||||
|
||||
class WaveformWidget(QWidget):
|
||||
"""波形图控件"""
|
||||
|
||||
def __init__(self, parent=None):
|
||||
super().__init__(parent)
|
||||
self.data_points = [] # 存储数据点 [(time, value)]
|
||||
self.base_value = 0 # 基准值
|
||||
self.screenshot_manager = None
|
||||
self.latest_screenshot_path = ""
|
||||
self.setMinimumSize(600, 300)
|
||||
|
||||
# 设置背景色
|
||||
self.setStyleSheet("background-color: #1e1e1e;")
|
||||
|
||||
# 初始化截图管理器
|
||||
if ScreenshotManager:
|
||||
self.screenshot_manager = ScreenshotManager()
|
||||
logger.info("截图管理器初始化完成")
|
||||
|
||||
logger.info("WaveformWidget 初始化完成")
|
||||
|
||||
def time_to_x_position(self, time_str: str, total_width: int) -> float:
|
||||
"""
|
||||
将时间转换为X轴位置
|
||||
时间折算关系:
|
||||
- 9:30 -> 最左侧 (x=0)
|
||||
- 11:30 -> 中间 (x=total_width/2)
|
||||
- 13:00 -> 中间 (x=total_width/2)
|
||||
- 15:00 -> 最右侧 (x=total_width)
|
||||
- 10:30 -> 左侧四分之一 (x=total_width/4)
|
||||
- 14:00 -> 右侧四分之一 (x=total_width*3/4)
|
||||
"""
|
||||
try:
|
||||
# 解析时间字符串
|
||||
current_time = datetime.strptime(time_str, "%H:%M").time()
|
||||
|
||||
# 定义关键时间点
|
||||
market_start = time(9, 30) # 9:30
|
||||
market_mid1 = time(11, 30) # 11:30
|
||||
market_mid2 = time(13, 0) # 13:00
|
||||
market_end = time(15, 0) # 15:00
|
||||
|
||||
# 计算总交易时间(分钟)
|
||||
morning_duration = (market_mid1.hour - market_start.hour) * 60 + \
|
||||
(market_mid1.minute - market_start.minute)
|
||||
afternoon_duration = (market_end.hour - market_mid2.hour) * 60 + \
|
||||
(market_end.minute - market_mid2.minute)
|
||||
total_duration = morning_duration + afternoon_duration
|
||||
|
||||
# 计算当前时间相对于开盘时间的分钟数
|
||||
if current_time <= market_mid1:
|
||||
# 上午交易时段
|
||||
minutes_from_start = (current_time.hour - market_start.hour) * 60 + \
|
||||
(current_time.minute - market_start.minute)
|
||||
# 上午时段占一半宽度
|
||||
x_ratio = minutes_from_start / morning_duration * 0.5
|
||||
elif current_time >= market_mid2:
|
||||
# 下午交易时段
|
||||
minutes_from_start = (current_time.hour - market_mid2.hour) * 60 + \
|
||||
(current_time.minute - market_mid2.minute)
|
||||
# 下午时段占一半宽度,从中间开始
|
||||
x_ratio = 0.5 + minutes_from_start / afternoon_duration * 0.5
|
||||
else:
|
||||
# 午休时间,统一放在中间
|
||||
x_ratio = 0.5
|
||||
|
||||
return x_ratio * total_width
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"时间转换错误: {time_str}, 错误: {e}")
|
||||
return total_width / 2 # 默认返回中间位置
|
||||
|
||||
def is_trading_time(self, time_str: str) -> bool:
|
||||
"""判断是否为交易时间"""
|
||||
try:
|
||||
current_time = datetime.strptime(time_str, "%H:%M").time()
|
||||
|
||||
# 上午交易时间: 9:30-11:30
|
||||
morning_start = time(9, 30)
|
||||
morning_end = time(11, 30)
|
||||
|
||||
# 下午交易时间: 13:00-15:00
|
||||
afternoon_start = time(13, 0)
|
||||
afternoon_end = time(15, 0)
|
||||
|
||||
# 判断是否在交易时间内
|
||||
is_trading = ((morning_start <= current_time <= morning_end) or
|
||||
(afternoon_start <= current_time <= afternoon_end))
|
||||
|
||||
logger.debug(f"时间 {time_str} 是否为交易时间: {is_trading}")
|
||||
return is_trading
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"时间判断错误: {time_str}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def add_data_point(self, time_str: str, value: float):
|
||||
"""添加数据点"""
|
||||
# 检查是否为交易时间
|
||||
if not self.is_trading_time(time_str):
|
||||
logger.info(f"非交易时间 {time_str},跳过数据点添加")
|
||||
return
|
||||
|
||||
# 如果是第一个数据点,设置基准值
|
||||
if not self.data_points:
|
||||
self.base_value = value
|
||||
logger.info(f"设置基准值: {self.base_value}")
|
||||
|
||||
self.data_points.append((time_str, value))
|
||||
logger.info(f"添加数据点: 时间={time_str}, 值={value}")
|
||||
|
||||
# 限制数据点数量,避免内存过大
|
||||
if len(self.data_points) > 100:
|
||||
self.data_points = self.data_points[-100:]
|
||||
|
||||
# 触发重绘
|
||||
self.update()
|
||||
|
||||
def clear_data(self):
|
||||
"""清除所有数据"""
|
||||
self.data_points.clear()
|
||||
self.base_value = 0
|
||||
self.update()
|
||||
logger.info("波形图数据已清除")
|
||||
|
||||
def paintEvent(self, event):
|
||||
"""绘制波形图"""
|
||||
painter = QPainter(self)
|
||||
painter.setRenderHint(QPainter.Antialiasing)
|
||||
|
||||
width = self.width()
|
||||
height = self.height()
|
||||
|
||||
# 如果没有数据点,显示提示信息
|
||||
if not self.data_points:
|
||||
self._draw_no_data_message(painter, width, height)
|
||||
return
|
||||
|
||||
# 检查最后一个数据点的时间是否为交易时间
|
||||
last_time = self.data_points[-1][0] if self.data_points else ""
|
||||
if last_time and not self.is_trading_time(last_time):
|
||||
self._draw_non_trading_message(painter, width, height)
|
||||
return
|
||||
|
||||
# 绘制网格和坐标轴
|
||||
self._draw_grid(painter, width, height)
|
||||
|
||||
# 绘制波形线
|
||||
self._draw_waveform(painter, width, height)
|
||||
|
||||
# 绘制数据点
|
||||
self._draw_data_points(painter, width, height)
|
||||
|
||||
def _draw_grid(self, painter: QPainter, width: int, height: int):
|
||||
"""绘制网格和坐标轴"""
|
||||
# 设置网格颜色
|
||||
grid_color = QColor(100, 100, 100)
|
||||
painter.setPen(QPen(grid_color, 1, Qt.DashLine))
|
||||
|
||||
# 绘制水平网格线
|
||||
for i in range(1, 5):
|
||||
y = height * i // 5
|
||||
painter.drawLine(0, y, width, y)
|
||||
|
||||
# 绘制垂直网格线(时间刻度)
|
||||
time_points = ["9:30", "10:30", "11:30", "13:00", "14:00", "15:00"]
|
||||
for time_str in time_points:
|
||||
x = self.time_to_x_position(time_str, width)
|
||||
painter.drawLine(x, 0, x, height)
|
||||
|
||||
# 绘制坐标轴
|
||||
axis_color = QColor(200, 200, 200)
|
||||
painter.setPen(QPen(axis_color, 2))
|
||||
painter.drawLine(0, height // 2, width, height // 2) # X轴
|
||||
painter.drawLine(0, 0, 0, height) # Y轴
|
||||
|
||||
# 绘制时间标签
|
||||
painter.setPen(QPen(QColor(150, 150, 150), 1))
|
||||
for time_str in time_points:
|
||||
x = self.time_to_x_position(time_str, width)
|
||||
painter.drawText(int(x) - 20, height - 5, time_str)
|
||||
|
||||
def _draw_waveform(self, painter: QPainter, width: int, height: int):
|
||||
"""绘制波形线"""
|
||||
if len(self.data_points) < 2:
|
||||
return
|
||||
|
||||
# 设置波形线样式
|
||||
waveform_color = QColor(0, 200, 255)
|
||||
painter.setPen(QPen(waveform_color, 3))
|
||||
|
||||
points = []
|
||||
|
||||
# 计算Y轴范围(基准值±100点)
|
||||
y_min = self.base_value - 100
|
||||
y_max = self.base_value + 100
|
||||
y_range = y_max - y_min
|
||||
|
||||
for time_str, value in self.data_points:
|
||||
x = self.time_to_x_position(time_str, width)
|
||||
|
||||
# 计算Y坐标(从底部到顶部)
|
||||
y_ratio = (value - y_min) / y_range
|
||||
y = height - (y_ratio * height)
|
||||
|
||||
points.append(QPointF(x, y))
|
||||
|
||||
# 绘制折线
|
||||
for i in range(len(points) - 1):
|
||||
painter.drawLine(points[i], points[i + 1])
|
||||
|
||||
def _draw_data_points(self, painter: QPainter, width: int, height: int):
|
||||
"""绘制数据点"""
|
||||
point_color = QColor(255, 100, 100)
|
||||
painter.setPen(QPen(point_color, 1))
|
||||
painter.setBrush(QBrush(point_color))
|
||||
|
||||
# 计算Y轴范围
|
||||
y_min = self.base_value - 100
|
||||
y_max = self.base_value + 100
|
||||
y_range = y_max - y_min
|
||||
|
||||
for time_str, value in self.data_points:
|
||||
x = self.time_to_x_position(time_str, width)
|
||||
y_ratio = (value - y_min) / y_range
|
||||
y = height - (y_ratio * height)
|
||||
|
||||
# 绘制数据点圆圈
|
||||
painter.drawEllipse(int(x) - 3, int(y) - 3, 6, 6)
|
||||
|
||||
# 显示数值标签
|
||||
painter.setPen(QPen(QColor(200, 200, 200), 1))
|
||||
painter.drawText(int(x) + 5, int(y) - 5, f"{value:.2f}")
|
||||
|
||||
def _draw_no_data_message(self, painter: QPainter, width: int, height: int):
|
||||
"""绘制无数据提示信息"""
|
||||
# 设置提示信息颜色
|
||||
message_color = QColor(150, 150, 150)
|
||||
painter.setPen(QPen(message_color, 2))
|
||||
|
||||
# 设置字体
|
||||
font = painter.font()
|
||||
font.setPointSize(14)
|
||||
font.setBold(True)
|
||||
painter.setFont(font)
|
||||
|
||||
# 绘制提示信息
|
||||
message = "等待交易时间数据..."
|
||||
text_rect = painter.fontMetrics().boundingRect(message)
|
||||
x = (width - text_rect.width()) // 2
|
||||
y = height // 2
|
||||
|
||||
painter.drawText(x, y, message)
|
||||
|
||||
# 绘制交易时间说明
|
||||
font.setPointSize(10)
|
||||
font.setBold(False)
|
||||
painter.setFont(font)
|
||||
|
||||
info = "交易时间: 9:30-11:30, 13:00-15:00"
|
||||
info_rect = painter.fontMetrics().boundingRect(info)
|
||||
x_info = (width - info_rect.width()) // 2
|
||||
y_info = y + 30
|
||||
|
||||
painter.drawText(x_info, y_info, info)
|
||||
|
||||
def _draw_non_trading_message(self, painter: QPainter, width: int, height: int):
|
||||
"""绘制非交易时间提示信息或截图"""
|
||||
# 如果有截图管理器,尝试显示截图
|
||||
if self.screenshot_manager:
|
||||
screenshot_path = self._get_or_capture_screenshot()
|
||||
if screenshot_path:
|
||||
self._draw_screenshot(painter, screenshot_path, width, height)
|
||||
return
|
||||
|
||||
# 如果没有截图或截图失败,显示文本提示
|
||||
self._draw_text_message(painter, width, height)
|
||||
|
||||
def _get_or_capture_screenshot(self) -> str:
|
||||
"""获取或捕获最新的截图"""
|
||||
try:
|
||||
# 检查是否有最新的截图
|
||||
latest_screenshot = self.screenshot_manager.get_latest_screenshot()
|
||||
|
||||
# 如果截图文件存在且较新(5分钟内),使用现有截图
|
||||
if latest_screenshot and os.path.exists(latest_screenshot):
|
||||
file_time = datetime.fromtimestamp(os.path.getmtime(latest_screenshot))
|
||||
current_time = datetime.now()
|
||||
|
||||
# 如果截图在5分钟内,直接使用
|
||||
if (current_time - file_time).total_seconds() < 300: # 5分钟
|
||||
return latest_screenshot
|
||||
|
||||
# 否则捕获新的截图
|
||||
logger.info("开始捕获上海证券交易所网站截图")
|
||||
new_screenshot = self.screenshot_manager.capture_chart_screenshot()
|
||||
|
||||
if new_screenshot:
|
||||
# 清理旧截图
|
||||
self.screenshot_manager.cleanup_old_screenshots()
|
||||
return new_screenshot
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取截图失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
def _draw_screenshot(self, painter: QPainter, screenshot_path: str, width: int, height: int):
|
||||
"""绘制截图"""
|
||||
try:
|
||||
# 加载截图
|
||||
pixmap = QPixmap(screenshot_path)
|
||||
if pixmap.isNull():
|
||||
logger.warning(f"截图加载失败: {screenshot_path}")
|
||||
self._draw_text_message(painter, width, height)
|
||||
return
|
||||
|
||||
# 计算缩放比例,保持宽高比
|
||||
pixmap_width = pixmap.width()
|
||||
pixmap_height = pixmap.height()
|
||||
|
||||
# 计算缩放比例,使截图适应显示区域
|
||||
scale_x = width / pixmap_width
|
||||
scale_y = height / pixmap_height
|
||||
scale = min(scale_x, scale_y) * 0.8 # 留出边距
|
||||
|
||||
# 计算显示尺寸
|
||||
display_width = int(pixmap_width * scale)
|
||||
display_height = int(pixmap_height * scale)
|
||||
|
||||
# 计算显示位置(居中)
|
||||
x = (width - display_width) // 2
|
||||
y = (height - display_height) // 2
|
||||
|
||||
# 绘制截图
|
||||
scaled_pixmap = pixmap.scaled(display_width, display_height)
|
||||
painter.drawPixmap(x, y, scaled_pixmap)
|
||||
|
||||
# 绘制标题
|
||||
font = painter.font()
|
||||
font.setPointSize(12)
|
||||
font.setBold(True)
|
||||
painter.setFont(font)
|
||||
|
||||
title_color = QColor(255, 255, 255)
|
||||
painter.setPen(QPen(title_color, 2))
|
||||
|
||||
title = "上海证券交易所实时图表"
|
||||
title_rect = painter.fontMetrics().boundingRect(title)
|
||||
title_x = (width - title_rect.width()) // 2
|
||||
title_y = y - 10
|
||||
|
||||
painter.drawText(title_x, title_y, title)
|
||||
|
||||
# 绘制更新时间
|
||||
font.setPointSize(8)
|
||||
font.setBold(False)
|
||||
painter.setFont(font)
|
||||
|
||||
update_time = datetime.now().strftime("更新时间: %H:%M:%S")
|
||||
time_rect = painter.fontMetrics().boundingRect(update_time)
|
||||
time_x = (width - time_rect.width()) // 2
|
||||
time_y = y + display_height + 20
|
||||
|
||||
painter.drawText(time_x, time_y, update_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"绘制截图失败: {e}")
|
||||
self._draw_text_message(painter, width, height)
|
||||
|
||||
def _draw_text_message(self, painter: QPainter, width: int, height: int):
|
||||
"""绘制文本提示信息"""
|
||||
# 设置提示信息颜色
|
||||
message_color = QColor(200, 100, 100)
|
||||
painter.setPen(QPen(message_color, 2))
|
||||
|
||||
# 设置字体
|
||||
font = painter.font()
|
||||
font.setPointSize(14)
|
||||
font.setBold(True)
|
||||
painter.setFont(font)
|
||||
|
||||
# 绘制提示信息
|
||||
message = "非交易时间"
|
||||
text_rect = painter.fontMetrics().boundingRect(message)
|
||||
x = (width - text_rect.width()) // 2
|
||||
y = height // 2
|
||||
|
||||
painter.drawText(x, y, message)
|
||||
|
||||
# 绘制交易时间说明
|
||||
font.setPointSize(10)
|
||||
font.setBold(False)
|
||||
painter.setFont(font)
|
||||
|
||||
info = "交易时间: 9:30-11:30, 13:00-15:00"
|
||||
info_rect = painter.fontMetrics().boundingRect(info)
|
||||
x_info = (width - info_rect.width()) // 2
|
||||
y_info = y + 30
|
||||
|
||||
painter.drawText(x_info, y_info, info)
|
||||
|
||||
# 绘制当前时间
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
time_info = f"当前时间: {current_time}"
|
||||
time_rect = painter.fontMetrics().boundingRect(time_info)
|
||||
x_time = (width - time_rect.width()) // 2
|
||||
y_time = y_info + 25
|
||||
|
||||
painter.drawText(x_time, y_time, time_info)
|
||||
Reference in New Issue
Block a user