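# Change summary (original commit message):
# - 实现智能区域检测算法(灰度阈值 + 连续行判定)  (region detection: gray threshold + consecutive-row test)
# - 支持Umi-OCR和自定义HTTP OCR服务            (Umi-OCR and custom HTTP OCR backends)
# - 添加热键触发和鼠标框选区域功能              (hotkey trigger and mouse-drag region selection)
# - 实现自动滚动和智能停止逻辑                  (automatic scrolling and stop logic)
# - 添加完整的README文档                        (README documentation)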
"""
Scrolling screenshot OCR tool.

Workflow: press the global hotkey, drag out a screen region with the mouse,
and the tool then repeatedly captures that region, runs OCR on it and
scrolls the page down automatically.
"""
import base64
import io
import json
import tempfile
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, List, Optional, Tuple
from urllib.parse import urlsplit, urlunsplit

import cv2
import keyboard
import mouse
import numpy as np
import pyautogui
import requests
from loguru import logger
from PIL import Image

from umi_ocr_client import UmiOCRClient, check_and_wait_for_service
@dataclass
class DivRegion:
    """A horizontal content block found in a captured image.

    Coordinates are pixel offsets inside the captured image. ``bottom`` and
    ``right`` are used as exclusive slice ends when the block is cropped.
    """

    top: int
    bottom: int
    left: int
    right: int
    # Recognised text for this block (defaults to empty).
    text: str = field(default="")

    @property
    def height(self) -> int:
        """Vertical extent in pixels (``bottom - top``)."""
        top, bottom = self.top, self.bottom
        return bottom - top

    @property
    def width(self) -> int:
        """Horizontal extent in pixels (``right - left``)."""
        left, right = self.left, self.right
        return right - left
@dataclass
class GapInfo:
    """A run of blank rows separating two content blocks.

    ``end_row`` is exclusive, so the gap covers rows ``start_row .. end_row - 1``.
    """

    start_row: int
    end_row: int

    @property
    def height(self) -> int:
        """Number of rows in the gap."""
        first, stop = self.start_row, self.end_row
        return stop - first
@dataclass
class AnalysisResult:
    """Output of image analysis: content blocks plus the blank gaps between them."""

    # Content blocks in the order they were detected.
    divs: List[DivRegion] = field(default_factory=lambda: [])
    # Blank gaps in the order they were detected.
    gaps: List[GapInfo] = field(default_factory=lambda: [])
class Config:
    """Tunable settings for the scrolling-capture OCR tool.

    Every setting is a plain class attribute; the application reads them
    through a ``Config()`` instance.
    """

    # ---- Hotkey ---------------------------------------------------------
    HOTKEY = "ctrl+f9"  # global hotkey that starts one capture session

    # ---- Image analysis (blank-row / gap detection) ----------------------
    GRAY_THRESHOLD = 240     # grey level above which a pixel counts as white
    CONSECUTIVE_LINES = 3    # minimum run of blank rows treated as a gap
    WHITE_PIXEL_RATIO = 0.9  # a row is blank when its white-pixel fraction exceeds this

    # ---- OCR backend -------------------------------------------------------
    OCR_ENGINE = "umi"  # "umi" -> Umi-OCR; any other value -> the HTTP endpoint below
    OCR_API_URL = "http://localhost:8000/ocr"  # HTTP OCR endpoint (when OCR_ENGINE != "umi")
    OCR_TIMEOUT = 30    # OCR request timeout in seconds

    # ---- Umi-OCR HTTP service ---------------------------------------------
    UMI_OCR_HOST = "127.0.0.1"
    UMI_OCR_PORT = 1224

    # ---- Scrolling ---------------------------------------------------------
    SCROLL_DELAY = 0.5       # seconds to wait after each scroll for the page to re-render
    MAX_SCROLL_COUNT = 100   # hard cap on scrolls per session (guards against endless loops)

    # ---- Output ------------------------------------------------------------
    OUTPUT_DIR = "output"    # directory for screenshots and the JSON summary
class RegionSelector:
    """Let the user pick a screen rectangle by dragging with the left mouse button."""

    # Smallest accepted selection width/height in pixels. A plain click (no
    # drag) would otherwise produce a zero-area region that cannot be captured
    # or analysed.
    MIN_SIZE = 1

    def __init__(self):
        # Cursor positions at button press / release of the most recent selection.
        self.start_pos: Optional[Tuple[int, int]] = None
        self.end_pos: Optional[Tuple[int, int]] = None
        # True while the left button is held during a selection.
        self.is_selecting = False

    @staticmethod
    def _bounds_from_points(
        start: Tuple[int, int],
        end: Tuple[int, int],
        min_size: int = 1,
    ) -> Tuple[int, int, int, int]:
        """Turn two opposite corners into ``(left, top, right, bottom)``.

        Works for any drag direction.

        Raises:
            ValueError: if the resulting width or height is below ``min_size``.
        """
        left, right = sorted((int(start[0]), int(end[0])))
        top, bottom = sorted((int(start[1]), int(end[1])))
        if right - left < min_size or bottom - top < min_size:
            raise ValueError(f"选择区域过小: {right - left}x{bottom - top}")
        return left, top, right, bottom

    def select_region(self) -> Tuple[int, int, int, int]:
        """Block until the user drags out a region; return ``(left, top, right, bottom)``.

        Polls the left mouse button every 10 ms. The press position is one
        corner and the release position is the opposite corner.

        Raises:
            ValueError: if the selection is narrower or shorter than ``MIN_SIZE``.
        """
        logger.info("请按住鼠标左键拖动选择区域...")
        print("\n>>> 请按住鼠标左键拖动选择截图区域,释放后确定 <<<")

        # Wait for the button to go down.
        while not mouse.is_pressed(button='left'):
            time.sleep(0.01)

        self.start_pos = mouse.get_position()
        self.is_selecting = True
        logger.info(f"选择开始位置: {self.start_pos}")

        # Wait for the button to be released.
        while mouse.is_pressed(button='left'):
            time.sleep(0.01)

        self.end_pos = mouse.get_position()
        self.is_selecting = False
        logger.info(f"选择结束位置: {self.end_pos}")

        left, top, right, bottom = self._bounds_from_points(
            self.start_pos, self.end_pos, self.MIN_SIZE
        )

        logger.info(f"选定区域: ({left}, {top}, {right}, {bottom}), 尺寸: {right-left}x{bottom-top}")
        print(f"已选择区域: 左上角({left}, {top}), 右下角({right}, {bottom})")

        return left, top, right, bottom
class ImageAnalyzer:
    """Split a screenshot into horizontal content blocks ("divs") separated by blank gaps.

    A row is *blank* when more than ``WHITE_PIXEL_RATIO`` of its pixels have a
    grey level above ``GRAY_THRESHOLD``. A run of at least
    ``CONSECUTIVE_LINES`` blank rows is a *gap*. Everything between two gaps
    (or a gap and an image edge), trimmed of blank rows at its edges, is a div.
    """

    def __init__(self, config: Config):
        self.config = config

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Return a single-channel version of ``image`` (BGR, BGRA or already grey)."""
        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                return image[:, :, 0]
            if channels == 4:
                return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
            return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return image

    @staticmethod
    def _find_runs(mask: np.ndarray, min_length: int) -> List[Tuple[int, int]]:
        """Return ``[start, end)`` index pairs of True-runs in ``mask`` at least ``min_length`` long."""
        # Pad with False on both ends so every run has one rising (+1) and one
        # falling (-1) edge in the first difference.
        padded = np.concatenate(([False], mask, [False])).astype(np.int8)
        edges = np.diff(padded)
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1)
        return [(int(s), int(e)) for s, e in zip(starts, ends) if e - s >= min_length]

    def analyze(self, image: np.ndarray) -> AnalysisResult:
        """Locate content blocks and the blank gaps between them.

        Uses a grey-level threshold per pixel plus a consecutive-blank-row test.
        Div ``bottom`` and gap ``end_row`` are exclusive, so
        ``image[div.top:div.bottom]`` is exactly the block.

        Returns:
            An ``AnalysisResult``. It is empty when the image has no pixels or
            contains no non-blank rows.
        """
        result = AnalysisResult()

        gray = self._to_gray(image)
        if gray.size == 0:
            logger.warning("图像为空,跳过分析")
            return result

        height, width = gray.shape
        logger.debug(f"分析图像尺寸: {width}x{height}")

        # Fraction of near-white pixels in each row, computed for all rows at once.
        white_ratio = np.count_nonzero(gray > self.config.GRAY_THRESHOLD, axis=1) / width
        blank_rows = white_ratio > self.config.WHITE_PIXEL_RATIO

        min_gap = max(1, int(self.config.CONSECUTIVE_LINES))
        for start, end in self._find_runs(blank_rows, min_gap):
            gap = GapInfo(start_row=start, end_row=end)
            result.gaps.append(gap)
            logger.debug(f"发现空白间隔: 行 {gap.start_row}-{gap.end_row}, 高度 {gap.height}")

        # Content spans lie between consecutive gaps and the image edges.
        spans = []
        cursor = 0
        for gap in result.gaps:
            spans.append((cursor, gap.start_row))
            cursor = gap.end_row
        spans.append((cursor, height))

        for span_start, span_end in spans:
            # Trim short blank runs (< min_gap) at the span edges; skip spans
            # with no content at all (e.g. the empty span before a leading gap).
            content_rows = np.flatnonzero(~blank_rows[span_start:span_end])
            if content_rows.size == 0:
                continue
            div = DivRegion(
                top=span_start + int(content_rows[0]),
                bottom=span_start + int(content_rows[-1]) + 1,
                left=0,
                right=width,
            )
            result.divs.append(div)
            logger.debug(f"发现div区域: 行 {div.top}-{div.bottom}, 高度 {div.height}")

        logger.info(f"分析完成: 发现 {len(result.divs)} 个div, {len(result.gaps)} 个空白间隔")
        return result

    def calculate_scroll_distance(self, result: AnalysisResult) -> int:
        """Compute how far to scroll after a capture.

        Strategy: advance past the first div plus the gap that follows it
        (aiming at the top of the next div), minus a small overlap. Returns
        100 when no div was found, and never less than 50 otherwise.
        """
        if not result.divs:
            logger.warning("未检测到div,使用默认滚动距离")
            return 100

        first_div = result.divs[0]

        # The first gap at or below the first div's bottom edge, if any.
        following_gap = next(
            (gap for gap in result.gaps if gap.start_row >= first_div.bottom), None
        )
        scroll_distance = first_div.height + (following_gap.height if following_gap else 0)

        # Keep some overlap between consecutive captures so no line is skipped.
        overlap = min(20, first_div.height // 4)
        scroll_distance = max(scroll_distance - overlap, 50)

        logger.info(f"计算滚动距离: {scroll_distance} 像素")
        return int(scroll_distance)
class OCREngine:
    """Text recogniser backed by a local Umi-OCR service or a generic HTTP OCR endpoint.

    ``config.OCR_ENGINE == "umi"`` selects Umi-OCR. Any other value selects
    the HTTP endpoint at ``config.OCR_API_URL``.
    """

    def __init__(self, config: Config):
        self.config = config
        # Only created for the Umi-OCR backend; stays None for HTTP.
        self.umi_client: Optional[UmiOCRClient] = None

        if config.OCR_ENGINE == "umi":
            self.umi_client = UmiOCRClient(
                host=config.UMI_OCR_HOST,
                port=config.UMI_OCR_PORT,
            )

    @staticmethod
    def _to_pil(image: np.ndarray) -> Image.Image:
        """Convert an OpenCV image (BGR colour or single-channel) to a PIL image."""
        if len(image.shape) == 3:
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return Image.fromarray(image)

    @staticmethod
    def _health_url(ocr_url: str) -> str:
        """Derive the health-check URL by replacing ``/ocr`` with ``/health`` in the path.

        Only the URL *path* is rewritten, so a host such as ``ocr.example.com``
        is left intact.
        """
        parts = urlsplit(ocr_url)
        return urlunsplit(parts._replace(path=parts.path.replace('/ocr', '/health')))

    def _recognize_with_http(self, image: np.ndarray) -> List[str]:
        """POST the image as base64 PNG (``{"image": ...}``) and return the reply's ``"texts"``.

        Returns an empty list on any error (logged) or when the reply has no texts.
        """
        try:
            buffered = io.BytesIO()
            self._to_pil(image).save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()

            response = requests.post(
                self.config.OCR_API_URL,
                json={"image": img_base64},
                timeout=self.config.OCR_TIMEOUT,
            )
            response.raise_for_status()

            data = response.json()
            # A null/missing "texts" must not reach len() in recognize().
            return data.get("texts") or []

        except requests.exceptions.ConnectionError:
            logger.error(f"无法连接到OCR服务: {self.config.OCR_API_URL}")
            return []
        except Exception as e:
            logger.error(f"HTTP OCR识别失败: {e}")
            return []

    def _recognize_with_umi(self, image: np.ndarray) -> List[str]:
        """Recognise ``image`` with Umi-OCR via a temporary PNG file.

        Returns non-empty, stripped text lines, or an empty list on error (logged).
        """
        if not self.umi_client:
            logger.error("Umi-OCR客户端未初始化")
            return []

        tmp_path: Optional[str] = None
        try:
            pil_image = self._to_pil(image)

            # delete=False + close before use: per the tempfile docs the name
            # cannot be reopened on Windows while the file is still open.
            # The file is removed in ``finally``, even if saving fails.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                tmp_path = tmp_file.name
            pil_image.save(tmp_path, format="PNG")

            text = self.umi_client.recognize_image(tmp_path, timeout=self.config.OCR_TIMEOUT)
            if not text:
                return []
            return [stripped for stripped in (line.strip() for line in text.split('\n')) if stripped]

        except Exception as e:
            logger.error(f"Umi-OCR识别失败: {e}")
            return []
        finally:
            if tmp_path is not None:
                try:
                    Path(tmp_path).unlink()
                except OSError as e:
                    # Best-effort cleanup: report, but never fail the OCR call.
                    logger.warning(f"临时文件删除失败: {tmp_path} ({e})")

    def recognize(self, image: np.ndarray) -> List[str]:
        """Run OCR on ``image`` with the configured backend and return the recognised texts."""
        if self.config.OCR_ENGINE == "umi":
            texts = self._recognize_with_umi(image)
        else:
            texts = self._recognize_with_http(image)

        logger.info(f"OCR识别完成,识别到 {len(texts)} 段文字")
        return texts

    def recognize_divs(self, image: np.ndarray, divs: List[DivRegion]) -> List[str]:
        """Run OCR on each div crop separately and concatenate the results in div order."""
        all_texts = []
        for i, div in enumerate(divs):
            div_image = image[div.top:div.bottom, div.left:div.right]
            texts = self.recognize(div_image)
            all_texts.extend(texts)
            logger.debug(f"Div {i+1} OCR结果: {texts}")
        return all_texts

    def check_service(self) -> bool:
        """Return True if the configured OCR backend answers its health check."""
        if self.config.OCR_ENGINE == "umi":
            if not self.umi_client:
                return False
            return self.umi_client.is_service_running()

        try:
            response = requests.get(self._health_url(self.config.OCR_API_URL), timeout=2)
            return response.status_code == 200
        except Exception as e:
            logger.debug(f"OCR健康检查失败: {e}")
            return False
class ScrollCaptureOCR:
    """Main application: hotkey -> region selection -> capture/OCR/scroll loop -> saved results."""

    def __init__(self):
        self.config = Config()
        self.region_selector = RegionSelector()
        self.image_analyzer = ImageAnalyzer(self.config)
        self.ocr_engine = OCREngine(self.config)

        # Screen rectangle (left, top, right, bottom) chosen by the user.
        self.capture_region: Optional[Tuple[int, int, int, int]] = None
        # OCR lines of the previous capture; used to detect the bottom of the page.
        self.previous_ocr_result: List[str] = []
        self.scroll_count = 0
        # One dict per saved capture; dumped to JSON at the end of a session.
        self.all_results: List[dict] = []

        # Create the output directory up front so the save_* methods can write into it.
        Path(self.config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    def capture_screen(self) -> np.ndarray:
        """Grab the selected screen region and return it as a BGR image.

        Raises:
            ValueError: if no region has been selected yet.
        """
        if not self.capture_region:
            raise ValueError("未设置截图区域")

        left, top, right, bottom = self.capture_region
        screenshot = pyautogui.screenshot(region=(left, top, right - left, bottom - top))
        # The screenshot is RGB; the rest of the pipeline uses OpenCV's BGR order.
        return cv2.cvtColor(np.array(screenshot), cv2.COLOR_RGB2BGR)

    def scroll_screen(self, distance: int):
        """Move the cursor to the centre of the capture region and scroll down.

        NOTE(review): ``pyautogui.scroll`` takes platform-dependent wheel units,
        not pixels, so ``distance`` is only nominally a pixel count. Calibrate
        per platform if scrolling overshoots or undershoots.
        """
        if not self.capture_region:
            return

        # The wheel event goes to the window under the cursor, so hover over the region first.
        left, top, right, bottom = self.capture_region
        center_x = (left + right) // 2
        center_y = (top + bottom) // 2

        pyautogui.moveTo(center_x, center_y)
        time.sleep(0.1)

        pyautogui.scroll(-distance)
        logger.info(f"向下滚动 {distance} 像素")

        # Give the page time to re-render before the next capture.
        time.sleep(self.config.SCROLL_DELAY)

    def check_duplicate(self, current_texts: List[str]) -> bool:
        """Return True if ``current_texts`` equals the previous non-empty OCR result.

        An unchanged result is taken to mean the last scroll did not move the page.
        """
        if not self.previous_ocr_result:
            return False

        is_duplicate = current_texts == self.previous_ocr_result
        if is_duplicate:
            logger.info("检测到OCR结果重复,可能已到达底部")
        return is_duplicate

    def save_result(self, scroll_index: int, image: np.ndarray, texts: List[str]):
        """Write the capture as PNG and record its OCR texts in ``all_results``."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")

        image_path = Path(self.config.OUTPUT_DIR) / f"capture_{timestamp}_{scroll_index:03d}.png"
        cv2.imwrite(str(image_path), image)

        result = {
            "index": scroll_index,
            "timestamp": timestamp,
            "image_path": str(image_path),
            "texts": texts,
        }
        self.all_results.append(result)

        logger.info(f"保存结果: {image_path}, 识别文字数: {len(texts)}")

    def save_final_result(self):
        """Dump all recorded results of this session to a timestamped JSON file."""
        output_path = Path(self.config.OUTPUT_DIR) / f"all_results_{time.strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(self.all_results, f, ensure_ascii=False, indent=2)
        logger.info(f"所有结果已保存到: {output_path}")
        print(f"\n所有结果已保存到: {output_path}")

    @staticmethod
    def _print_preview(texts: List[str], limit: int = 3) -> None:
        """Print the number of recognised texts and a truncated preview of the first ``limit``."""
        print(f"识别到 {len(texts)} 段文字")
        for i, text in enumerate(texts[:limit], 1):
            preview = text[:50] + "..." if len(text) > 50 else text
            print(f" [{i}] {preview}")
        if len(texts) > limit:
            print(f" ... 还有 {len(texts) - limit} 段文字")

    def process_once(self) -> bool:
        """Run one capture -> analyse -> OCR -> save -> scroll iteration.

        Returns:
            True to continue, False to stop (no content found, content repeated,
            or the scroll limit reached).
        """
        logger.info(f"=== 第 {self.scroll_count + 1} 次截屏 ===")
        print(f"\n>>> 第 {self.scroll_count + 1} 次截屏处理中...")

        # 1. Capture the current screen region.
        image = self.capture_screen()
        logger.info(f"截图完成,尺寸: {image.shape[1]}x{image.shape[0]}")

        # 2. Locate the content blocks.
        analysis = self.image_analyzer.analyze(image)
        if not analysis.divs:
            logger.warning("未检测到任何div区域,可能已到达底部或区域选择有误")
            print("警告: 未检测到内容区域")
            return False

        # 3. OCR each block.
        current_texts = self.ocr_engine.recognize_divs(image, analysis.divs)
        self._print_preview(current_texts)

        # 4. Bottom reached? Checked BEFORE saving so the repeated capture is
        #    not written a second time into the output and the JSON summary.
        if self.check_duplicate(current_texts):
            print("\n>>> 检测到内容重复,已到达底部,处理完成 <<<")
            return False

        # 5. Save this capture and remember its text for the next comparison.
        self.save_result(self.scroll_count, image, current_texts)
        self.previous_ocr_result = current_texts

        # 6. Scroll to the next block.
        scroll_distance = self.image_analyzer.calculate_scroll_distance(analysis)
        self.scroll_screen(scroll_distance)
        self.scroll_count += 1

        if self.scroll_count >= self.config.MAX_SCROLL_COUNT:
            logger.warning(f"达到最大滚动次数限制 ({self.config.MAX_SCROLL_COUNT})")
            print("\n>>> 达到最大滚动次数限制,处理完成 <<<")
            return False

        return True

    def run(self):
        """Print usage, register the hotkey and block until Ctrl+C."""
        print("=" * 60)
        print("滚动截屏OCR工具")
        print("=" * 60)
        print("\n使用说明:")
        print(f"1. 按下热键 {self.config.HOTKEY} 启动")
        print("2. 按住鼠标左键拖动选择截图区域")
        print("3. 程序将自动滚动截屏并进行OCR识别")
        print("4. 当检测到重复内容时自动停止")
        print(f"5. 结果将保存在 '{self.config.OUTPUT_DIR}' 目录")
        print("\n" + "=" * 60)

        logger.info("程序启动,等待热键触发...")
        print(f"\n>>> 等待热键 {self.config.HOTKEY} 启动... <<<")

        keyboard.add_hotkey(self.config.HOTKEY, self._on_hotkey)

        # Keep the main thread alive; the hotkey callback does the work.
        try:
            while True:
                time.sleep(0.1)
        except KeyboardInterrupt:
            logger.info("程序被用户中断")
            print("\n>>> 程序已停止 <<<")

    def _print_service_unavailable(self) -> None:
        """Tell the user how to start the configured OCR backend."""
        if self.config.OCR_ENGINE == "umi":
            print("✗ Umi-OCR服务未运行")
            print("请先启动Umi-OCR软件并开启HTTP服务:")
            print(" 1. 打开Umi-OCR")
            print(" 2. 进入 设置 -> HTTP接口")
            print(" 3. 勾选 '启用HTTP服务'")
            print(f" 4. 确保端口为 {self.config.UMI_OCR_PORT}")
        else:
            print(f"✗ OCR服务未运行: {self.config.OCR_API_URL}")

    def _on_hotkey(self):
        """Hotkey callback: check the OCR service, select a region, run the loop, save results."""
        logger.info("热键触发,开始处理")
        print(f"\n{'='*60}")
        print("热键已触发!")

        print("\n>>> 检查OCR服务... <<<")
        if not self.ocr_engine.check_service():
            self._print_service_unavailable()
            return
        print("✓ OCR服务运行中")

        try:
            self.capture_region = self.region_selector.select_region()
        except Exception as e:
            logger.error(f"区域选择失败: {e}")
            print(f"区域选择失败: {e}")
            return

        # Reset per-session state so a previous hotkey run does not leak in.
        self.previous_ocr_result = []
        self.scroll_count = 0
        self.all_results = []

        print("\n>>> 开始自动滚动截屏和OCR识别... <<<")

        try:
            while self.process_once():
                pass
        except Exception as e:
            # loguru has no ``exc_info`` parameter: extra kwargs are passed to
            # str.format() on the message (breaking on braces in ``e``) and no
            # traceback is recorded. logger.exception records the traceback;
            # "{}" substitution keeps the text of ``e`` literal.
            logger.exception("处理过程中出错: {}", e)
            print(f"\n错误: {e}")

        # Save whatever was collected, even after an error.
        if self.all_results:
            self.save_final_result()
            print(f"\n共处理 {len(self.all_results)} 次截屏")
            print(f"结果保存在: {Path(self.config.OUTPUT_DIR).absolute()}")

        print(f"\n{'='*60}")
        print(">>> 等待下一次热键触发... <<<")
        logger.info("处理完成,等待下一次热键触发")
def main():
    """Program entry point: build the application and hand control to its run loop."""
    ScrollCaptureOCR().run()


if __name__ == "__main__":
    main()