Files
desktop-transfer/translator.py
xiaji 2659fdd6ac feat(translator): 添加transformers作为备选模型加载方式
支持使用transformers库作为llama-cpp-python的备选方案加载模型
新增模型加载失败时的自动回退机制
更新requirements.txt添加transformers和torch依赖
2026-01-16 11:08:34 +08:00

187 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from utils.logger import logger
# Optional backend 1: llama-cpp-python (GGUF models).
# If the import fails, Llama is bound to None and the flag below records that
# this backend cannot be used.
try:
    from llama_cpp import Llama
    llama_cpp_available = True
except ImportError:
    logger.warning("llama-cpp-python库未找到将尝试使用transformers库")
    Llama = None
    llama_cpp_available = False

# Optional backend 2: Hugging Face transformers.
# The imported names are bound to None on failure, mirroring how Llama is
# handled above, so they are always defined at module level.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    transformers_available = True
except ImportError:
    logger.warning("transformers库未找到")
    AutoModelForCausalLM = None
    AutoTokenizer = None
    transformers_available = False
class Translator:
    """Chinese-to-English translator driven by a locally loaded language model.

    Two backends are supported. llama-cpp-python is tried first when it is
    importable and the model path exists on disk; Hugging Face transformers
    is used as the fallback.

    Attributes:
        model: The loaded backend model object, or None when nothing is loaded.
        tokenizer: The transformers tokenizer (None for the llama.cpp backend).
        model_path: Path (or, for transformers, possibly a model identifier)
            passed by the caller.
        is_ready: True once a model has been loaded successfully.
        model_name: Name recorded at load time, used by get_model_info().
        use_transformers: True when the active model came from transformers.
    """

    # Markers after which generated text is cut off. The first three are the
    # original ASCII-colon markers; the last two match the full-width colons
    # that _build_prompt() actually emits ("原文:" / "译文:").
    _STOP_SEQUENCES = ("\n原文:", "\n译文:", "\n###", "\n原文:", "\n译文:")

    def __init__(self, model_path=None):
        """Create an empty translator; no model is loaded here.

        Args:
            model_path: Optional default path used by load_model().
        """
        self.model = None
        self.tokenizer = None
        self.model_path = model_path
        self.is_ready = False
        self.model_name = ""
        self.llama_cpp_available = llama_cpp_available
        self.transformers_available = transformers_available
        self.use_transformers = False

    def load_model(self, model_path=None):
        """Load a model, trying llama-cpp-python first and transformers second.

        Args:
            model_path: Overrides the stored model path when given.

        Returns:
            True if a model was loaded and activated, False otherwise. When a
            backend fails to load, the previously active model/tokenizer (if
            any) are left in place.
        """
        if model_path:
            self.model_path = model_path
        if not self.model_path:
            logger.error("未提供模型路径")
            return False
        try:
            logger.info(f"开始加载模型: {self.model_path}")
            path_exists = os.path.exists(self.model_path)
            if self.llama_cpp_available:
                if path_exists:
                    if self._load_with_llama_cpp():
                        return True
                else:
                    logger.warning(f"模型文件不存在: {self.model_path}")
            if self.transformers_available:
                # An existing path is loaded strictly from disk; otherwise the
                # value is handed to transformers without local_files_only.
                return self._load_with_transformers(local_only=path_exists)
            logger.error("没有可用的模型加载方式")
            return False
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            self.is_ready = False
            return False

    def _load_with_llama_cpp(self):
        """Load self.model_path with llama-cpp-python; return True on success."""
        try:
            model = Llama(
                model_path=self.model_path,
                n_ctx=2048,
                n_threads=4,
                n_gpu_layers=100,
            )
        except Exception as e:
            logger.warning(f"使用llama-cpp-python加载模型失败: {e}")
            return False
        self._set_active_model(
            model,
            tokenizer=None,
            use_transformers=False,
            model_name=os.path.basename(self.model_path),
        )
        logger.info(f"模型加载成功: {self.model_name}")
        return True

    def _load_with_transformers(self, local_only):
        """Load self.model_path with transformers; return True on success.

        Args:
            local_only: When True, pass local_files_only=True so nothing is
                fetched from outside the given path.
        """
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model to be executed. Kept as in the original; confirm this is
        # acceptable for every model source this application accepts.
        common_kwargs = {"trust_remote_code": True}
        if local_only:
            common_kwargs["local_files_only"] = True
        try:
            # Load into locals first so a tokenizer failure cannot leave a
            # half-switched backend (new model, stale flags) behind.
            model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype="auto",
                **common_kwargs,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                **common_kwargs,
            )
        except Exception as e:
            logger.error(f"使用transformers加载模型失败: {e}")
            return False
        self._set_active_model(
            model,
            tokenizer=tokenizer,
            use_transformers=True,
            model_name=self.model_path,
        )
        logger.info(f"使用transformers加载模型成功: {self.model_name}")
        return True

    def _set_active_model(self, model, tokenizer, use_transformers, model_name):
        """Commit a fully loaded backend as the active one in a single step."""
        self.model = model
        self.tokenizer = tokenizer
        self.use_transformers = use_transformers
        self.model_name = model_name
        self.is_ready = True

    def translate(self, text, context="", terms=None):
        """Translate text with the loaded model.

        Args:
            text: Source text to translate.
            context: Optional background/scene description added to the prompt.
            terms: Optional terminology entries added to the prompt.

        Returns:
            The translated text, or "" when no model is ready, the input is
            empty, or generation fails (the error is logged).
        """
        if not self.is_ready or self.model is None:
            logger.error("模型未就绪,无法执行翻译")
            return ""
        try:
            if not text or not str(text).strip():
                # Nothing to translate; avoid asking the model to invent output.
                logger.warning("输入文本为空,跳过翻译")
                return ""
            prompt = self._build_prompt(text, context, terms)
            logger.info(f"开始翻译,输入长度: {len(text)} 字符")
            if self.use_transformers:
                raw_text = self._generate_with_transformers(prompt)
                # generate() has no stop strings configured, so cut the
                # output at the same markers the llama.cpp path stops on.
                raw_text = self._truncate_at_stop(raw_text)
            else:
                output = self.model(
                    prompt,
                    max_tokens=2048,
                    temperature=0.7,
                    top_p=0.95,
                    stop=list(self._STOP_SEQUENCES),
                )
                raw_text = output["choices"][0]["text"]
            translated_text = raw_text.strip()
            logger.info(f"翻译完成,输出长度: {len(translated_text)} 字符")
            return translated_text
        except Exception as e:
            logger.error(f"翻译失败: {e}")
            return ""

    def _generate_with_transformers(self, prompt):
        """Run sampling generation with the transformers model and tokenizer.

        Returns only the newly generated text (the prompt tokens are sliced
        off before decoding).
        """
        import torch

        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
            )
        prompt_length = inputs.input_ids.shape[1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )

    @classmethod
    def _truncate_at_stop(cls, text):
        """Return text cut at the earliest occurrence of any stop sequence."""
        cut = len(text)
        for marker in cls._STOP_SEQUENCES:
            index = text.find(marker)
            if index != -1:
                cut = min(cut, index)
        return text[:cut]

    def _build_prompt(self, text, context="", terms=None):
        """Build the translation prompt.

        Args:
            text: Source text placed after the "原文:" marker.
            context: Optional background text; omitted from the prompt if empty.
            terms: Optional iterable of terminology entries, one per line. A
                single string is split into its non-blank lines instead of
                being iterated character by character.

        Returns:
            The prompt string, ending with the "译文:" marker.
        """
        parts = ["你是一个专业的翻译助手,根据以下要求将中文翻译成英文:\n"]
        if context:
            parts.append(f"\n文本背景/场景介绍:{context}\n")
        if isinstance(terms, str):
            terms = [line for line in terms.splitlines() if line.strip()]
        if terms:
            parts.append("\n术语定义:\n")
            parts.extend(f"{term}\n" for term in terms)
        parts.append(f"\n原文:{text}\n译文:")
        return "".join(parts)

    def unload_model(self):
        """Drop the loaded model and reset all load-related state.

        Returns:
            True when nothing was loaded or unloading succeeded, False if an
            error occurred while unloading (the error is logged).
        """
        if self.model is None:
            return True
        try:
            import gc

            self.model = None
            self.tokenizer = None
            self.is_ready = False
            self.use_transformers = False
            self.model_name = ""
            # Collect reference cycles so the model's memory is released now
            # rather than at some later garbage-collection pass.
            gc.collect()
            logger.info("模型已卸载")
            return True
        except Exception as e:
            logger.error(f"模型卸载失败: {e}")
            return False

    def get_model_info(self):
        """Return a display name for the loaded model.

        Returns:
            "未加载模型" when no model is ready. For the llama.cpp backend, the
            file name without its final extension. For the transformers
            backend, the recorded name unchanged, because dots there belong to
            the model name or path rather than to a file extension.
        """
        if not self.is_ready:
            return "未加载模型"
        if self.use_transformers:
            return self.model_name
        stem, _extension = os.path.splitext(self.model_name)
        return stem or self.model_name