252 lines
11 KiB
Python
252 lines
11 KiB
Python
import sys
from collections.abc import Mapping
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import yaml
from loguru import logger

# Put this file's directory first on sys.path so the project-local packages
# imported below (core, plugins, ai) resolve from here.
sys.path.insert(0, str(Path(__file__).parent))

from core.anchor_engine import AnchorEngine
from core.native_chart import NativeChartManager
from core.conditional_renderer import ConditionalRenderer
from plugins.base_generator import GeneratorPluginManager
from ai.llm_analyst import LLMAnalyst
|
||
|
||
class Orchestrator:
    """Drive the end-to-end PPT report pipeline.

    ``run_full_pipeline`` runs the stages in order: load (or synthesize) a
    template, execute the data-generator plugins, push plugin render output
    into native charts, ask the LLM for a summary, evaluate slide conditions,
    and save the resulting ``.pptx``. Each stage reports progress through
    ``_report_progress`` (a log line plus the optional progress callback).
    """

    # Handler id returned by ``logger.add`` for the most recently attached
    # run-log file; removed before a new sink is added (see _attach_run_log).
    _log_sink_id: Optional[int] = None

    # Demo-only tweaks keyed by 0-based slide index:
    # (title text, name given to the first non-title shape). The names look
    # like chart anchors (e.g. ``chart_gdp``) so charts can presumably be
    # bound without a real template -- confirm against the anchor config.
    _DEMO_CHART_SLIDES: Dict[int, Tuple[str, str]] = {
        3: ("GDP趋势图", "chart_gdp"),
        4: ("CPI/PPI走势图", "chart_cpi"),
    }

    def __init__(self, config_path: Optional[str] = None):
        """Set up config, run logging, plugins and rendering components.

        Args:
            config_path: YAML config file. Defaults to
                ``<this dir>/config/project_config_v2.yaml``.
        """
        self.base_dir = Path(__file__).parent
        self.config_path = (
            Path(config_path)
            if config_path
            else self.base_dir / "config" / "project_config_v2.yaml"
        )

        # Attach the run log first so warnings from config loading land in it.
        self._attach_run_log()

        self.config = self._load_config()
        self.prs = None  # presentation object; set by load_template()
        self.results: Dict[str, Any] = {}
        self.progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None

        self.plugin_manager = GeneratorPluginManager(self.base_dir / "plugins" / "generators")
        self.plugin_manager.discover_plugins()

        self.anchor_engine = AnchorEngine()
        self.chart_manager = NativeChartManager()
        self.conditional_renderer = ConditionalRenderer()

        # ``ai:`` present but null in YAML would otherwise hand None to LLMAnalyst.
        self.llm = LLMAnalyst(self.config.get('ai') or {})

    def _attach_run_log(self) -> None:
        """Send log output to a new ``<base_dir>/logs/run_<timestamp>.log`` file.

        loguru's ``logger`` is process-wide, so every ``logger.add`` registers
        one more sink. The sink added by the previous Orchestrator is removed
        first; otherwise each new instance would keep writing every message
        into all earlier run logs as well (and keep those files open).
        """
        previous_id = Orchestrator._log_sink_id
        if previous_id is not None:
            try:
                logger.remove(previous_id)
            except ValueError:
                # Already removed elsewhere, e.g. by a global logger.remove().
                pass
        log_file = self.base_dir / "logs" / f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        log_file.parent.mkdir(parents=True, exist_ok=True)
        Orchestrator._log_sink_id = logger.add(
            str(log_file), rotation="10 MB", level="INFO", encoding="utf-8"
        )

    def _load_config(self) -> Dict:
        """Read the YAML config at ``self.config_path``.

        Returns:
            The parsed mapping, or an empty dict when the file is missing,
            empty, or its top level is not a mapping -- so callers can always
            use ``.get`` on the result.
        """
        if not self.config_path.exists():
            logger.warning(f"配置文件不存在, 使用空配置: {self.config_path}")
            return {}
        with open(self.config_path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        if data is None:
            return {}
        if not isinstance(data, dict):
            logger.warning(f"配置文件顶层不是键值映射, 已忽略: {self.config_path}")
            return {}
        return data

    def set_progress_callback(self, callback: Optional[Callable[[Dict[str, Any]], None]]) -> None:
        """Register ``callback`` to receive progress dicts (see ``_report_progress``)."""
        self.progress_callback = callback

    def _report_progress(self, step: str, message: str, percent: int = 0) -> None:
        """Log ``message`` under ``step`` and forward it to the callback, if any.

        The callback receives a dict with keys ``step``, ``message``,
        ``percent`` and an ISO-8601 ``timestamp``.
        """
        logger.info(f"[{step}] {message}")
        if self.progress_callback:
            self.progress_callback({
                'step': step,
                'message': message,
                'percent': percent,
                'timestamp': datetime.now().isoformat(),
            })

    def load_template(self, template_path: Optional[str] = None) -> bool:
        """Load the template into ``self.prs`` and share it with all components.

        Args:
            template_path: File name resolved against ``<base_dir>/template_ppt``
                (an absolute path is used as-is); falls back to
                ``config['template']``.

        When no such file exists, an in-memory demo presentation is built
        instead. Always returns True.
        """
        template_name = template_path or self.config.get('template') or ''
        template_file = self.base_dir / "template_ppt" / template_name

        # is_file(), not exists(): with no template configured the path above
        # is the template_ppt directory itself, which must not be loaded.
        if template_file.is_file():
            self._report_progress("LOAD", f"加载模板: {template_file.name}", 5)
            self.prs = self.anchor_engine.load_presentation(template_file)
        else:
            self._report_progress("LOAD", f"创建新演示文稿(模板不存在: {template_file.name})", 5)
            self.prs = self._build_demo_presentation()

        # All components operate on the same presentation object.
        self.anchor_engine.prs = self.prs
        self.chart_manager.set_presentation(self.prs)
        self.conditional_renderer.prs = self.prs

        anchors_found = self.anchor_engine.scan_anchors()
        self._report_progress("LOAD", f"发现 {len(anchors_found)} 个可绑定锚点", 10)
        return True

    def _build_demo_presentation(self, slide_count: int = 8):
        """Create the demo presentation used when no template file exists.

        Slide 0 uses layout 5 and the others layout 1 of python-pptx's default
        template; a slide is skipped if that layout index is missing. Titles
        read ``示例页 <n>``, except on slides listed in ``_DEMO_CHART_SLIDES``,
        which get their own title and have their first non-title shape renamed.
        """
        from pptx import Presentation  # only needed on this fallback path

        prs = Presentation()
        layouts = prs.slide_layouts
        for i in range(slide_count):
            layout_idx = 5 if i == 0 else 1
            if layout_idx >= len(layouts):
                continue

            slide = prs.slides.add_slide(layouts[layout_idx])
            title_shape = slide.shapes.title
            chart_spec = self._DEMO_CHART_SLIDES.get(i)

            if title_shape is not None:
                title_shape.text = chart_spec[0] if chart_spec else f"示例页 {i+1}"

            if chart_spec:
                for shp in slide.shapes:
                    if shp != title_shape:
                        shp.name = chart_spec[1]
                        break

            logger.debug(f"创建示例页 {i+1}")
        return prs

    def run_plugins(self, params: Optional[Dict[str, Any]] = None) -> Dict:
        """Instantiate and run every discovered generator plugin.

        Args:
            params: Overrides merged on top of ``config['params']``.

        Returns:
            ``{plugin_id: result}``, where ``result`` is
            ``{'success': True, 'data': ..., 'render': ...}`` on success,
            ``{'success': False}`` when ``fetch_data`` returns falsy, or
            ``{'success': False, 'error': str}`` when the plugin raised.
            The same dict is stored in ``self.results['plugins']``.
        """
        self._report_progress("PLUGINS", "开始执行数据插件...", 15)

        # Call-time params override config['params'] (either may be missing/null).
        run_params = {**(self.config.get('params') or {}), **(params or {})}
        plugin_results: Dict[Any, Dict[str, Any]] = {}

        plugins = self.plugin_manager.plugins
        total = len(plugins)
        for idx, (pid, plugin_class) in enumerate(plugins.items()):
            # Plugin execution occupies the 20-50% band of overall progress.
            percent = 20 + int(idx / total * 30)
            self._report_progress("PLUGINS", f"执行插件: {plugin_class.generator_name}", percent)

            try:
                plugin = plugin_class(run_params)
                if plugin.fetch_data(run_params):
                    render_result = plugin.render()
                    plugin_results[pid] = {
                        'success': True,
                        'data': plugin.get_data(),
                        'render': render_result,
                    }
                    logger.success(f"插件 [{pid}] 执行成功")
                else:
                    logger.warning(f"插件 [{pid}] 未获取到数据")
                    plugin_results[pid] = {'success': False}
            except Exception as e:
                # Isolate plugin failures: record the error and continue with the rest.
                logger.exception(f"插件 [{pid}] 执行失败: {e}")
                plugin_results[pid] = {'success': False, 'error': str(e)}

        self.results['plugins'] = plugin_results
        return plugin_results

    def update_native_charts(self) -> int:
        """Feed successful plugin render output into the template's native charts.

        Only anchors whose ``type`` contains ``native_chart`` and whose
        ``plugin`` produced a successful result are used. A chart whose plugin
        returned no usable render mapping, or whose update raises, is logged
        and skipped so the remaining charts are still processed.

        Returns:
            Number of charts ``chart_manager`` reported as updated.
        """
        self._report_progress("CHARTS", "开始更新原生图表...", 55)

        updates_count = 0
        plugin_results = self.results.get('plugins', {})

        for anchor_cfg in self.config.get('anchors') or []:
            if 'native_chart' not in (anchor_cfg.get('type') or ''):
                continue

            anchor_name = anchor_cfg.get('name', '')
            plugin_id = anchor_cfg.get('plugin')
            plugin_result = plugin_results.get(plugin_id) if plugin_id else None
            if not plugin_result or not plugin_result.get('success'):
                continue

            render = plugin_result.get('render')
            # render() is plugin code; a None/non-mapping result used to crash here.
            if not isinstance(render, Mapping):
                logger.warning(f"插件 [{plugin_id}] 未返回有效的渲染数据, 跳过图表 {anchor_name}")
                continue

            try:
                success = self.chart_manager.update_chart_by_anchor(
                    anchor_name=render.get('anchor', anchor_name),
                    categories=render.get('categories', []),
                    series_data=render.get('series', {}),
                    chart_type=anchor_cfg.get('chart_type', 'line'),
                )
            except Exception as e:
                logger.exception(f"图表 [{anchor_name}] 更新失败: {e}")
                continue

            if success:
                updates_count += 1

        self._report_progress("CHARTS", f"完成 {updates_count} 个原生图表更新", 70)
        return updates_count

    def _iter_latest_values(self) -> Iterator[Tuple[Any, Any, float]]:
        """Yield ``(plugin_id, column, value)`` for each column's last cell as float.

        Only successful plugins whose ``data`` has both ``iloc`` and
        ``columns`` (DataFrame-like) are read; e.g. a pandas Series is skipped
        instead of raising ``AttributeError``. Columns with no rows, or whose
        last cell cannot be converted to float, are skipped (NaN passes through).
        """
        for pid, presult in self.results.get('plugins', {}).items():
            if not presult.get('success') or 'data' not in presult:
                continue
            df = presult['data']
            if not (hasattr(df, 'iloc') and hasattr(df, 'columns')):
                continue
            for col in df.columns:
                try:
                    value = float(df[col].iloc[-1])
                except (TypeError, ValueError, LookupError, OverflowError):
                    continue
                yield pid, col, value

    def ai_generate_summary(self) -> str:
        """Ask the LLM for a short analysis of the latest plugin values.

        The LLM context maps ``"<plugin_id>_<column>"`` to the column's last
        numeric value rounded to 2 decimals. The summary is stored in
        ``self.results['llm_summary']`` and written into every ``text_replace``
        anchor whose name contains ``summary`` and every ``ai_text_generate``
        anchor (default name ``text_llm_summary``).

        Returns:
            The summary, or ``""`` if generation or anchor replacement failed.
        """
        self._report_progress("AI", "LLM正在生成分析摘要...", 75)

        context = {
            f"{pid}_{col}": round(value, 2)
            for pid, col, value in self._iter_latest_values()
        }

        # Best-effort stage: an LLM failure must not abort report generation.
        try:
            summary = self.llm.generate_analysis(context, max_words=200)
            self.results['llm_summary'] = summary

            for anchor_cfg in self.config.get('anchors') or []:
                anchor_type = anchor_cfg.get('type')
                anchor_name = anchor_cfg.get('name') or ''
                if anchor_type == 'text_replace' and 'summary' in anchor_name:
                    self.anchor_engine.replace_text_anchor(anchor_name, summary)
                elif anchor_type == 'ai_text_generate':
                    self.anchor_engine.replace_text_anchor(anchor_name or 'text_llm_summary', summary)

            logger.success(f"LLM分析完成: {summary[:50]}...")
            return summary
        except Exception as e:
            logger.warning(f"LLM分析失败: {e}")
            return ""

    def process_conditions(self):
        """Evaluate ``config['slide_conditions']`` against the latest plugin values.

        The renderer context maps each column label to its latest numeric
        value, plus the aliases ``gdp_growth`` (any column containing ``GDP``)
        and ``unemployment_rate`` (any column containing ``unemploy``,
        case-insensitive); later columns overwrite earlier ones. Each outcome
        is logged and recorded in ``self.results['conditions']``.

        Returns:
            Always True.
        """
        self._report_progress("CONDITIONS", "处理条件渲染规则...", 80)

        context: Dict[Any, float] = {}
        for _pid, col, value in self._iter_latest_values():
            context[col] = value
            col_text = str(col)  # column labels are not guaranteed to be str
            if 'GDP' in col_text:
                context['gdp_growth'] = value
            if 'unemploy' in col_text.lower():
                context['unemployment_rate'] = value

        self.conditional_renderer.set_context(context)

        outcomes: List[Dict[str, Any]] = []
        for cond_cfg in self.config.get('slide_conditions') or []:
            if 'condition' not in cond_cfg:
                continue
            condition = cond_cfg['condition']
            cond_result = self.conditional_renderer.evaluate_condition(condition)
            logger.info(f"条件规则: {condition} -> {'满足' if cond_result else '不满足'}")
            outcomes.append({'condition': condition, 'result': cond_result})

        # NOTE(review): outcomes are only logged/recorded; nothing in this class
        # shows or hides slides based on them yet -- confirm this is intended.
        self.results['conditions'] = outcomes
        return True

    def save(self, output_name: Optional[str] = None) -> str:
        """Write ``self.prs`` to ``<base_dir>/<output_dir>/<output_name>``.

        Requires ``load_template()`` to have run first (it sets ``self.prs``).

        Args:
            output_name: File name; defaults to ``report_<timestamp>.pptx``.
                ``config['output_dir']`` defaults to ``output`` and may be absolute.

        Returns:
            The saved file's path as a string.
        """
        self._report_progress("SAVE", "正在保存最终PPT...", 95)

        output_dir = self.base_dir / (self.config.get('output_dir') or 'output')
        output_dir.mkdir(parents=True, exist_ok=True)

        if not output_name:
            output_name = f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pptx"
        output_path = output_dir / output_name
        self.prs.save(str(output_path))

        self._report_progress("DONE", f"PPT已保存: {output_name}", 100)
        banner = "=" * 50
        logger.success(banner)
        logger.success(f"生成完成: {output_path}")
        logger.success(banner)

        return str(output_path)

    def run_full_pipeline(self, template_path: Optional[str] = None, params: Optional[Dict] = None) -> str:
        """Run every pipeline stage in order and return the saved PPT path."""
        self.load_template(template_path)
        self.run_plugins(params)
        self.update_native_charts()
        self.ai_generate_summary()
        self.process_conditions()
        return self.save()
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build an orchestrator from the default config and
    # run every pipeline stage end to end.
    Orchestrator().run_full_pipeline()
|