Files
work-secretfile-selfcheck/UmiOCR-data/py_src/mission/mission_ocr.py

167 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ===============================================
# =============== OCR - 任务管理器 ===============
# ===============================================
"""
一种任务管理器为全局单例,不同标签页要执行同一种任务,要访问对应的任务管理器。
任务管理器中有一个引擎API实例所有任务均使用该API。
标签页可以向任务管理器提交一组任务队列,其中包含了每一项任务的信息,及总体的参数和回调。
"""
import os
from umi_log import logger
from .mission import Mission
from ..ocr.tbpu import getParser, IgnoreArea
from ..ocr.api import getApiOcr, getLocalOptions
from ..utils.utils import argdIntConvert
# 合法文件后缀
ImageSuf = [
".jpg",
".jpe",
".jpeg",
".jfif",
".png",
".webp",
".bmp",
".tif",
".tiff",
]
class __MissionOcrClass(Mission):
def __init__(self):
super().__init__()
self._apiKey = "" # 当前api类型
self._api = None # 当前引擎api对象
# ========================= 【重载】 =========================
# msnInfo: { 回调函数"onXX", 参数"argd":{"tbpu.xx", "ocr.xx"} }
# msnList: [ { "path", "bytes", "base64" } ]
def addMissionList(self, msnInfo, msnList): # 添加任务列表
# 实例化 tbpu 文本后处理模块
msnInfo["tbpu"] = []
argd = msnInfo["argd"]
# 忽略区域
if "tbpu.ignoreArea" in argd:
iArea = argd["tbpu.ignoreArea"]
if isinstance(iArea, list) and len(iArea) > 0:
msnInfo["tbpu"].append(IgnoreArea(iArea))
# 获取排版解析器对象
if "tbpu.parser" in argd:
msnInfo["tbpu"].append(getParser(argd["tbpu.parser"]))
# 检查任务合法性
for i in range(len(msnList) - 1, -1, -1):
if "path" in msnList[i]:
p = msnList[i]["path"]
if os.path.splitext(p)[-1].lower() not in ImageSuf:
logger.warning(f"添加OCR任务时{i}项的路径path不是图片{p}")
del msnList[i]
elif "bytes" not in msnList[i] and "base64" not in msnList[i]:
logger.warning(f"添加OCR任务时{i}项不含 path、bytes、base64")
del msnList[i]
return super().addMissionList(msnInfo, msnList)
def msnPreTask(self, msnInfo): # 用于更新api和参数
# 检查API对象
if not self._api:
return "[Error] MissionOCR: API object is None."
# 检查参数更新
startInfo = self._dictShortKey(msnInfo["argd"])
# 恢复int类型
argdIntConvert(startInfo)
msg = self._api.start(startInfo)
if msg.startswith("[Error]"):
logger.error(f"OCR引擎启动失败 {msg}")
return msg # 更新失败,结束该队列
else:
return "" # 更新成功 TODO: continue
def msnTask(self, msnInfo, msn): # 执行msn
if "path" in msn:
res = self._api.runPath(msn["path"])
res["path"] = msn["path"] # 结果字典中补充参数
elif "bytes" in msn:
res = self._api.runBytes(msn["bytes"])
elif "base64" in msn:
res = self._api.runBase64(msn["base64"])
else:
res = {
"code": 901,
"data": f"[Error] Unknown task type.\n【异常】未知的任务类型。\n{str(msn)[:100]}",
}
# 任务成功时的后处理
if res["code"] == 100:
# 计算平均置信度
score, num = 0, 0
for r in res["data"]:
score += r["score"]
num += 1
if num > 0:
score /= num
res["score"] = score
# 执行 tbpu
if msnInfo["tbpu"]:
for tbpu in msnInfo["tbpu"]:
res["data"] = tbpu.run(res["data"])
# 如果忽略区域等处理将所有文本删除则结束tbpu
if not res["data"]:
res["code"] = 101
res["data"] = ""
break
return res
# ========================= 【qml接口】 =========================
def getStatus(self): # 返回当前状态
return {
"apiKey": self._apiKey,
"missionListsLength": self.getMissionListsLength(),
}
def setApi(self, apiKey, info): # 设置api
# 成功返回 [Success] ,失败返回 [Error] 开头的字符串
self._apiKey = apiKey
info = self._dictShortKey(info)
# 如果api对象已启动则先停止
if self._api:
self._api.stop()
# 获取新api对象
res = getApiOcr(apiKey, info)
# 失败
if isinstance(res, str):
self._apiKey = ""
self._api = None
return res
# 成功
else:
self._api = res
return "[Success]"
# 将字典中配置项的长key转为短key
# 如: ocr.win32_PaddleOCR-json.path → path
def _dictShortKey(self, d):
newD = {}
key1 = "ocr."
key2 = key1 + self._apiKey + "."
for k in d:
if k.startswith(key2):
newD[k[len(key2) :]] = d[k]
elif k.startswith(key1):
newD[k[len(key1) :]] = d[k]
return newD
# ========================= 【qml接口】 =========================
def getLocalOptions(self):
if self._apiKey:
return getLocalOptions(self._apiKey)
else:
return {}
# 全局 OCR任务管理器
MissionOCR = __MissionOcrClass()