# (file viewer metadata: 110 lines, 3.9 KiB, Python)
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, List, Optional
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from loguru import logger
|
|
|
|
class BaseGenerator(ABC):
    """Abstract base class for generator plugins.

    Subclasses fill in the class-level metadata below and implement
    ``fetch_data`` and ``render``. Plugin discovery only registers
    subclasses whose ``generator_id`` is not ``None``.
    """

    # Unique plugin identifier; subclasses that leave it as None are not
    # registered by plugin discovery (useful for intermediate base classes).
    generator_id: Optional[str] = None
    # Human-readable display name.
    generator_name: Optional[str] = None
    # Short description exposed in plugin listings.
    description: Optional[str] = None
    version: str = "1.0.0"

    # Maps parameter name -> config dict; a config containing
    # ``'required': True`` marks the parameter as mandatory (checked by
    # validate_params). This default dict is shared by every class that does
    # not override it, so subclasses should assign their own dict instead of
    # mutating this one.
    params_schema: Dict[str, Any] = {}

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Store the generator parameters and initialise per-instance state.

        Args:
            params: Generator parameters; ``None`` (or any falsy value) is
                replaced by an empty dict.
        """
        self.params = params or {}
        # Bind the plugin id into the logger context for this instance.
        self.logger = logger.bind(generator=self.generator_id)
        # Result slots returned by get_data / get_result; presumably filled in
        # by subclass implementations of fetch_data / render.
        self._data = None
        self._chart_data = None

    @abstractmethod
    def fetch_data(self, params: Optional[Dict[str, Any]] = None) -> bool:
        """Load the generator's source data.

        Args:
            params: Optional per-call parameters.

        Returns:
            A bool; presumably whether the fetch succeeded (confirm against
            concrete implementations).
        """
        pass

    @abstractmethod
    def render(self) -> Dict[str, Any]:
        """Build and return the generator's output as a dict."""
        pass

    def validate_params(self, params: Optional[Dict[str, Any]] = None) -> bool:
        """Check that every parameter marked required in ``params_schema`` is present.

        Each missing required parameter is logged as a warning.

        Args:
            params: Parameters to check. ``None`` falls back to ``self.params``.

        Returns:
            ``True`` if all required parameters are present, otherwise ``False``.
        """
        if params is None:
            params = self.params
        missing = [
            param_name
            for param_name, param_config in self.params_schema.items()
            if param_config.get('required', False) and param_name not in params
        ]
        for param_name in missing:
            self.logger.warning(f"缺少必需参数: {param_name}")
        # Previously this always returned True, so validation could never fail.
        return not missing

    def get_data(self):
        """Return the raw data stored on this instance (``None`` until set)."""
        return self._data

    def get_result(self) -> Dict[str, Any]:
        """Return a summary dict of the generator's id, name, data and chart data.

        Note: ``'success'`` is always ``True`` here; failures are not reflected.
        """
        return {
            'generator_id': self.generator_id,
            'generator_name': self.generator_name,
            'data': self._data,
            'chart_data': self._chart_data,
            'success': True
        }
|
|
|
|
class GeneratorPluginManager:
    """Discover, register and instantiate ``BaseGenerator`` plugins from a directory."""

    def __init__(self, plugins_dir: Optional[Path] = None):
        """Create an empty registry.

        Args:
            plugins_dir: Directory scanned for plugin modules. Defaults to the
                ``generators`` directory next to this file.
        """
        # Maps generator_id -> BaseGenerator *subclass* (not an instance);
        # get_generator instantiates on demand.
        self.plugins: Dict[str, type] = {}
        self.plugins_dir = plugins_dir or Path(__file__).parent / 'generators'
        self._discovered = False

    def discover_plugins(self) -> int:
        """Import every public ``*.py`` module in ``plugins_dir`` and register its generators.

        Files whose names start with ``_`` are skipped. A module that fails to
        import is logged and skipped. Discovery runs once per manager; later
        calls return the cached count without rescanning.

        Returns:
            The number of registered generators (0 if the directory is missing).
        """
        import sys

        if self._discovered:
            return len(self.plugins)

        generators_dir = self.plugins_dir
        if not generators_dir.exists():
            logger.warning(f"插件目录不存在: {generators_dir}")
            return 0

        # Make the plugin directory's parent importable so plugins can import
        # their siblings, without piling up duplicate sys.path entries when
        # several managers point at the same directory.
        parent_path = str(generators_dir.parent)
        if parent_path not in sys.path:
            sys.path.insert(0, parent_path)

        # Derive the package prefix from the scanned directory rather than
        # hard-coding "plugins.generators", so a custom plugins_dir gets its
        # own module namespace.
        package_prefix = f"{generators_dir.parent.name}.{generators_dir.name}"

        # Sorted for a deterministic load order (raw glob order depends on the
        # filesystem), which also makes duplicate-id resolution predictable.
        for py_file in sorted(generators_dir.glob("*.py")):
            if py_file.name.startswith("_"):
                continue
            try:
                module = self._load_module(f"{package_prefix}.{py_file.stem}", py_file)
                self._register_generators(module)
            except Exception as e:
                logger.exception(f"加载插件失败 {py_file}: {e}")

        self._discovered = True
        logger.info(f"插件扫描完成,共加载 {len(self.plugins)} 个生成器")
        return len(self.plugins)

    @staticmethod
    def _load_module(module_name: str, py_file: Path):
        """Import ``py_file`` as a module named ``module_name`` and return it.

        The module is registered in ``sys.modules`` before execution (as in the
        importlib "import a source file directly" recipe) so code inside the
        plugin that looks itself up there (e.g. pickle, typing.get_type_hints)
        works; the entry is removed again if execution fails.

        Raises:
            ImportError: if no import spec/loader can be built for the file.
            Exception: anything raised while executing the module.
        """
        import importlib.util
        import sys

        spec = importlib.util.spec_from_file_location(module_name, str(py_file))
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot build an import spec for {py_file}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Don't leave a half-initialised module behind for later imports.
            sys.modules.pop(module_name, None)
            raise
        return module

    def _register_generators(self, module) -> None:
        """Register every concrete BaseGenerator subclass with a generator_id found in ``module``."""
        import inspect

        for obj in vars(module).values():
            if not (isinstance(obj, type)
                    and issubclass(obj, BaseGenerator)
                    and obj is not BaseGenerator
                    and obj.generator_id is not None):
                continue
            # An abstract subclass could not be instantiated by get_generator,
            # so it is not exposed as a plugin.
            if inspect.isabstract(obj):
                continue
            existing = self.plugins.get(obj.generator_id)
            if existing is obj:
                # Same class seen again (e.g. imported into another plugin module).
                continue
            if existing is not None:
                # Later modules (sorted filename order) win; make the clash visible.
                logger.warning(
                    f"插件ID重复 [{obj.generator_id}]: {existing.__name__} 被 {obj.__name__} 覆盖"
                )
            self.plugins[obj.generator_id] = obj
            logger.success(f"加载插件 [{obj.generator_id}]: {obj.generator_name}")

    def get_generator(self, generator_id: str, params: Optional[Dict[str, Any]] = None) -> Optional[BaseGenerator]:
        """Instantiate the generator registered under ``generator_id``.

        Does not trigger discovery; call ``discover_plugins`` first.

        Args:
            generator_id: Registered plugin id.
            params: Passed straight to the generator's constructor.

        Returns:
            A new generator instance, or ``None`` (with a logged warning) if no
            plugin is registered under that id.
        """
        generator_cls = self.plugins.get(generator_id)
        if generator_cls is None:
            logger.warning(f"找不到生成器插件: {generator_id}")
            return None
        return generator_cls(params)

    def list_generators(self) -> List[Dict[str, Any]]:
        """Return one metadata dict per registered generator, in registration order.

        Note: ``params_schema`` is the class attribute itself, not a copy.
        """
        return [
            {
                'id': gid,
                'name': cls.generator_name,
                'description': cls.description,
                'version': cls.version,
                'params_schema': cls.params_schema
            }
            for gid, cls in self.plugins.items()
        ]
|