V2.0 五大核心增强: 锚点定位/原生图表/插件架构/WebSocket/LLM智能
This commit is contained in:
251
ppt_manager_v2/orchestrator.py
Normal file
251
ppt_manager_v2/orchestrator.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
import yaml
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from core.anchor_engine import AnchorEngine
|
||||
from core.native_chart import NativeChartManager
|
||||
from core.conditional_renderer import ConditionalRenderer
|
||||
from plugins.base_generator import GeneratorPluginManager
|
||||
from ai.llm_analyst import LLMAnalyst
|
||||
|
||||
class Orchestrator:
|
||||
def __init__(self, config_path: str = None):
|
||||
self.base_dir = Path(__file__).parent
|
||||
|
||||
if config_path:
|
||||
self.config_path = Path(config_path)
|
||||
else:
|
||||
self.config_path = self.base_dir / "config" / "project_config_v2.yaml"
|
||||
|
||||
self.config = self._load_config()
|
||||
self.prs = None
|
||||
self.results: Dict[str, Any] = {}
|
||||
self.progress_callback = None
|
||||
|
||||
log_file = self.base_dir / "logs" / f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
logger.add(str(log_file), rotation="10 MB", level="INFO", encoding="utf-8")
|
||||
|
||||
self.plugin_manager = GeneratorPluginManager(self.base_dir / "plugins" / "generators")
|
||||
self.plugin_manager.discover_plugins()
|
||||
|
||||
self.anchor_engine = AnchorEngine()
|
||||
self.chart_manager = NativeChartManager()
|
||||
self.conditional_renderer = ConditionalRenderer()
|
||||
|
||||
ai_config = self.config.get('ai', {})
|
||||
self.llm = LLMAnalyst(ai_config)
|
||||
|
||||
def _load_config(self) -> Dict:
|
||||
if self.config_path.exists():
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
return {}
|
||||
|
||||
def set_progress_callback(self, callback):
|
||||
self.progress_callback = callback
|
||||
|
||||
def _report_progress(self, step: str, message: str, percent: int = 0):
|
||||
logger.info(f"[{step}] {message}")
|
||||
if self.progress_callback:
|
||||
self.progress_callback({
|
||||
'step': step,
|
||||
'message': message,
|
||||
'percent': percent,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
})
|
||||
|
||||
def load_template(self, template_path: str = None) -> bool:
|
||||
template_file = self.base_dir / "template_ppt" / (template_path or self.config.get('template', ''))
|
||||
|
||||
if not template_file.exists():
|
||||
self._report_progress("LOAD", f"创建新演示文稿(模板不存在: {template_file.name})", 5)
|
||||
from pptx import Presentation
|
||||
self.prs = Presentation()
|
||||
|
||||
for i in range(8):
|
||||
layout_idx = 5 if i == 0 else 1
|
||||
if layout_idx < len(self.prs.slide_layouts):
|
||||
slide = self.prs.slides.add_slide(self.prs.slide_layouts[layout_idx])
|
||||
if slide.shapes.title:
|
||||
slide.shapes.title.text = f"示例页 {i+1}"
|
||||
if i == 3:
|
||||
if slide.shapes.title:
|
||||
slide.shapes.title.text = "GDP趋势图"
|
||||
for shp in slide.shapes:
|
||||
if shp != slide.shapes.title and hasattr(shp, 'name'):
|
||||
shp.name = "chart_gdp"
|
||||
break
|
||||
if i == 4:
|
||||
if slide.shapes.title:
|
||||
slide.shapes.title.text = "CPI/PPI走势图"
|
||||
for shp in slide.shapes:
|
||||
if shp != slide.shapes.title and hasattr(shp, 'name'):
|
||||
shp.name = "chart_cpi"
|
||||
break
|
||||
logger.debug(f"创建示例页 {i+1}")
|
||||
else:
|
||||
self._report_progress("LOAD", f"加载模板: {template_file.name}", 5)
|
||||
self.prs = self.anchor_engine.load_presentation(template_file)
|
||||
|
||||
self.anchor_engine.prs = self.prs
|
||||
self.chart_manager.set_presentation(self.prs)
|
||||
self.conditional_renderer.prs = self.prs
|
||||
|
||||
anchors_found = self.anchor_engine.scan_anchors()
|
||||
self._report_progress("LOAD", f"发现 {len(anchors_found)} 个可绑定锚点", 10)
|
||||
return True
|
||||
|
||||
def run_plugins(self, params: Dict[str, Any] = None) -> Dict:
|
||||
self._report_progress("PLUGINS", "开始执行数据插件...", 15)
|
||||
|
||||
run_params = {**self.config.get('params', {}), **(params or {})}
|
||||
plugin_results = {}
|
||||
|
||||
total = len(self.plugin_manager.plugins)
|
||||
for idx, (pid, plugin_class) in enumerate(self.plugin_manager.plugins.items()):
|
||||
self._report_progress("PLUGINS", f"执行插件: {plugin_class.generator_name}", 20 + int(idx/total*30))
|
||||
|
||||
try:
|
||||
plugin = plugin_class(run_params)
|
||||
if plugin.fetch_data(run_params):
|
||||
render_result = plugin.render()
|
||||
plugin_results[pid] = {
|
||||
'success': True,
|
||||
'data': plugin.get_data(),
|
||||
'render': render_result
|
||||
}
|
||||
logger.success(f"插件 [{pid}] 执行成功")
|
||||
else:
|
||||
plugin_results[pid] = {'success': False}
|
||||
except Exception as e:
|
||||
logger.exception(f"插件 [{pid}] 执行失败: {e}")
|
||||
plugin_results[pid] = {'success': False, 'error': str(e)}
|
||||
|
||||
self.results['plugins'] = plugin_results
|
||||
return plugin_results
|
||||
|
||||
def update_native_charts(self) -> int:
|
||||
self._report_progress("CHARTS", "开始更新原生图表...", 55)
|
||||
|
||||
updates_count = 0
|
||||
anchors_config = self.config.get('anchors', [])
|
||||
|
||||
for anchor_cfg in anchors_config:
|
||||
anchor_type = anchor_cfg.get('type', '')
|
||||
anchor_name = anchor_cfg.get('name', '')
|
||||
|
||||
if 'native_chart' in anchor_type:
|
||||
plugin_id = anchor_cfg.get('plugin')
|
||||
chart_type = anchor_cfg.get('chart_type', 'line')
|
||||
|
||||
if plugin_id and plugin_id in self.results.get('plugins', {}):
|
||||
plugin_result = self.results['plugins'][plugin_id]
|
||||
if plugin_result.get('success'):
|
||||
render = plugin_result['render']
|
||||
|
||||
success = self.chart_manager.update_chart_by_anchor(
|
||||
anchor_name=render.get('anchor', anchor_name),
|
||||
categories=render.get('categories', []),
|
||||
series_data=render.get('series', {}),
|
||||
chart_type=chart_type
|
||||
)
|
||||
if success:
|
||||
updates_count += 1
|
||||
|
||||
self._report_progress("CHARTS", f"完成 {updates_count} 个原生图表更新", 70)
|
||||
return updates_count
|
||||
|
||||
def ai_generate_summary(self) -> str:
|
||||
self._report_progress("AI", "LLM正在生成分析摘要...", 75)
|
||||
|
||||
context = {}
|
||||
for pid, presult in self.results.get('plugins', {}).items():
|
||||
if presult.get('success') and 'data' in presult:
|
||||
df = presult['data']
|
||||
if hasattr(df, 'iloc'):
|
||||
for col in df.columns:
|
||||
try:
|
||||
context[f"{pid}_{col}"] = round(float(df[col].iloc[-1]), 2)
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
summary = self.llm.generate_analysis(context, max_words=200)
|
||||
self.results['llm_summary'] = summary
|
||||
|
||||
for anchor_cfg in self.config.get('anchors', []):
|
||||
if anchor_cfg.get('type') == 'text_replace' and 'summary' in anchor_cfg.get('name', ''):
|
||||
self.anchor_engine.replace_text_anchor(anchor_cfg['name'], summary)
|
||||
if anchor_cfg.get('type') == 'ai_text_generate':
|
||||
anchor_name = anchor_cfg.get('name', 'text_llm_summary')
|
||||
self.anchor_engine.replace_text_anchor(anchor_name, summary)
|
||||
|
||||
logger.success(f"LLM分析完成: {summary[:50]}...")
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM分析失败: {e}")
|
||||
return ""
|
||||
|
||||
def process_conditions(self):
|
||||
self._report_progress("CONDITIONS", "处理条件渲染规则...", 80)
|
||||
|
||||
context = {}
|
||||
for pid, presult in self.results.get('plugins', {}).items():
|
||||
if presult.get('success') and 'data' in presult:
|
||||
df = presult['data']
|
||||
if hasattr(df, 'iloc'):
|
||||
for col in df.columns:
|
||||
try:
|
||||
val = float(df[col].iloc[-1])
|
||||
context[col] = val
|
||||
if 'GDP' in col: context['gdp_growth'] = val
|
||||
if 'unemploy' in col.lower(): context['unemployment_rate'] = val
|
||||
except:
|
||||
pass
|
||||
|
||||
self.conditional_renderer.set_context(context)
|
||||
|
||||
slide_conditions = self.config.get('slide_conditions', [])
|
||||
for cond_cfg in slide_conditions:
|
||||
if 'condition' in cond_cfg:
|
||||
cond_result = self.conditional_renderer.evaluate_condition(cond_cfg['condition'])
|
||||
logger.info(f"条件规则: {cond_cfg['condition']} -> {'满足' if cond_result else '不满足'}")
|
||||
|
||||
return True
|
||||
|
||||
def save(self, output_name: str = None) -> str:
|
||||
self._report_progress("SAVE", "正在保存最终PPT...", 95)
|
||||
|
||||
output_dir = self.base_dir / self.config.get('output_dir', 'output')
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not output_name:
|
||||
output_name = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pptx"
|
||||
|
||||
output_path = output_dir / output_name
|
||||
self.prs.save(str(output_path))
|
||||
|
||||
self._report_progress("DONE", f"PPT已保存: {output_name}", 100)
|
||||
logger.success(f"=" * 50)
|
||||
logger.success(f"生成完成: {output_path}")
|
||||
logger.success(f"=" * 50)
|
||||
|
||||
return str(output_path)
|
||||
|
||||
def run_full_pipeline(self, template_path: str = None, params: Dict = None) -> str:
|
||||
self.load_template(template_path)
|
||||
self.run_plugins(params)
|
||||
self.update_native_charts()
|
||||
self.ai_generate_summary()
|
||||
self.process_conditions()
|
||||
return self.save()
|
||||
|
||||
if __name__ == "__main__":
|
||||
orch = Orchestrator()
|
||||
orch.run_full_pipeline()
|
||||
Reference in New Issue
Block a user