Files
long-screen-cut/main.py
xiaji 8600c0f576 feat: 初始提交 - 滚动截屏OCR工具
- 实现智能区域检测算法(灰度阈值 + 连续行判定)
- 支持Umi-OCR和自定义HTTP OCR服务
- 添加热键触发和鼠标框选区域功能
- 实现自动滚动和智能停止逻辑
- 添加完整的README文档
2026-03-06 15:07:51 +08:00

605 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
滚动截屏OCR工具
功能通过热键激活手动框选区域后自动滚动截屏并进行OCR识别
"""
import json
import time
import base64
import io
import tempfile
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Callable
from pathlib import Path
import cv2
import numpy as np
import requests
from PIL import Image
import pyautogui
import keyboard
import mouse
from loguru import logger
from umi_ocr_client import UmiOCRClient, check_and_wait_for_service
@dataclass
class DivRegion:
    """Rectangular "div" region described by its four edges.

    ``top``/``bottom`` bound the region vertically and ``left``/``right``
    bound it horizontally; ``text`` is an optional string attached to the
    region and defaults to the empty string.
    """

    top: int
    bottom: int
    left: int
    right: int
    text: str = ""

    @property
    def height(self) -> int:
        """Vertical extent of the region (``bottom - top``)."""
        vertical_span = self.bottom - self.top
        return vertical_span

    @property
    def width(self) -> int:
        """Horizontal extent of the region (``right - left``)."""
        horizontal_span = self.right - self.left
        return horizontal_span
@dataclass
class GapInfo:
    """Blank interval between regions, expressed as a pair of row indices."""

    start_row: int
    end_row: int

    @property
    def height(self) -> int:
        """Number of rows spanned by the gap (``end_row - start_row``)."""
        row_span = self.end_row - self.start_row
        return row_span
@dataclass
class AnalysisResult:
    """Container for the outcome of one image analysis pass."""

    # Detected div regions; every instance gets its own fresh list.
    divs: List[DivRegion] = field(default_factory=list)
    # Detected blank gaps; every instance gets its own fresh list.
    gaps: List[GapInfo] = field(default_factory=list)
class Config:
    """Application settings, exposed as class-level constants."""

    # ---- Hotkey ---------------------------------------------------------
    HOTKEY: str = "ctrl+f9"

    # ---- Image analysis -------------------------------------------------
    # Gray level above which a pixel counts as (near) white.
    GRAY_THRESHOLD: int = 240
    # Number of consecutive blank rows required to treat them as a gap.
    CONSECUTIVE_LINES: int = 3
    # Fraction of white pixels above which a row is considered blank.
    WHITE_PIXEL_RATIO: float = 0.9

    # ---- OCR ------------------------------------------------------------
    # Engine selector: "umi" uses Umi-OCR, "http" uses the HTTP endpoint.
    OCR_ENGINE: str = "umi"
    # HTTP OCR service address (used when OCR_ENGINE is "http").
    OCR_API_URL: str = "http://localhost:8000/ocr"
    # Timeout for OCR requests.
    OCR_TIMEOUT: int = 30

    # ---- Umi-OCR --------------------------------------------------------
    UMI_OCR_HOST: str = "127.0.0.1"
    UMI_OCR_PORT: int = 1224

    # ---- Scrolling ------------------------------------------------------
    # Seconds to wait after scrolling so the page can render.
    SCROLL_DELAY: float = 0.5
    # Upper bound on the number of scrolls, to prevent an endless loop.
    MAX_SCROLL_COUNT: int = 100

    # ---- Output ---------------------------------------------------------
    OUTPUT_DIR: str = "output"
class RegionSelector:
    """Region selector: lets the user drag out the screenshot area with the mouse."""

    def __init__(self):
        # Cursor position when the left button went down / was released.
        self.start_pos: Optional[Tuple[int, int]] = None
        self.end_pos: Optional[Tuple[int, int]] = None
        # True only while the left button is held during a selection.
        self.is_selecting = False

    def select_region(self) -> Tuple[int, int, int, int]:
        """Block until the user drags out a rectangle with the left mouse button.

        The press position and the release position are taken as two opposite
        corners, so the drag may go in any direction.

        Returns:
            ``(left, top, right, bottom)`` of the selected rectangle.

        Raises:
            ValueError: if the selection has zero width or zero height (for
                example a click without a drag); such a region cannot be
                captured or analyzed.
        """
        logger.info("请按住鼠标左键拖动选择区域...")
        print("\n>>> 请按住鼠标左键拖动选择截图区域,释放后确定 <<<")

        # Poll until the left button goes down.
        while not mouse.is_pressed(button='left'):
            time.sleep(0.01)
        self.start_pos = mouse.get_position()
        self.is_selecting = True
        logger.info(f"选择开始位置: {self.start_pos}")

        # Poll until the button is released; always clear the flag, even if
        # polling is interrupted by an exception.
        try:
            while mouse.is_pressed(button='left'):
                time.sleep(0.01)
            self.end_pos = mouse.get_position()
        finally:
            self.is_selecting = False
        logger.info(f"选择结束位置: {self.end_pos}")

        # Normalize the two corners into a bounding box.
        (x0, y0), (x1, y1) = self.start_pos, self.end_pos
        left, right = min(x0, x1), max(x0, x1)
        top, bottom = min(y0, y1), max(y0, y1)

        # Reject degenerate selections early instead of failing later during
        # capture or analysis.
        if right <= left or bottom <= top:
            raise ValueError(
                f"选择区域无效(宽度或高度为0): ({left}, {top}, {right}, {bottom})"
            )

        logger.info(f"选定区域: ({left}, {top}, {right}, {bottom}), 尺寸: {right-left}x{bottom-top}")
        print(f"已选择区域: 左上角({left}, {top}), 右下角({right}, {bottom})")
        return left, top, right, bottom
class ImageAnalyzer:
    """Image analyzer: splits a screenshot into content ("div") regions and blank gaps."""

    def __init__(self, config: Config):
        self.config = config

    def analyze(self, image: np.ndarray) -> AnalysisResult:
        """Locate div regions and blank gaps using a gray threshold on each row.

        A row is *blank* when the fraction of its pixels brighter than
        ``config.GRAY_THRESHOLD`` exceeds ``config.WHITE_PIXEL_RATIO``.
        A run of at least ``config.CONSECUTIVE_LINES`` blank rows is a *gap*;
        shorter blank runs stay inside the surrounding div. Leading and
        trailing blank rows are never part of a div.

        All intervals are half-open ``[start, end)`` row ranges, so
        ``image[div.top:div.bottom]`` covers exactly the rows of a div and
        ``height`` equals the number of rows.

        Args:
            image: 2-D grayscale array, or a 3-D array that is converted with
                ``cv2.COLOR_BGR2GRAY``.

        Returns:
            An ``AnalysisResult`` whose ``divs`` and ``gaps`` are each listed
            from top to bottom. It is empty for an image with zero rows or
            zero columns.
        """
        result = AnalysisResult()

        # Convert to grayscale when the image has a channel axis.
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        height, width = gray.shape
        logger.debug(f"分析图像尺寸: {width}x{height}")

        # Guard against empty input: the ratio below would divide by zero.
        if height == 0 or width == 0:
            logger.info(f"分析完成: 发现 {len(result.divs)} 个div, {len(result.gaps)} 个空白间隔")
            return result

        # Classify every row at once: count near-white pixels per row.
        white_counts = np.count_nonzero(gray > self.config.GRAY_THRESHOLD, axis=1)
        blank_rows = ((white_counts / width) > self.config.WHITE_PIXEL_RATIO).tolist()
        # A threshold below 1 would turn every content row into an empty gap.
        min_gap = max(1, self.config.CONSECUTIVE_LINES)

        div_top: Optional[int] = None  # first row of the div currently open
        blank_run = 0                  # length of the blank run ending at the previous row

        for row, is_blank in enumerate(blank_rows):
            if is_blank:
                blank_run += 1
                continue

            # Content row. If the blank run just before it is long enough,
            # it is a gap that also terminates the open div (if any).
            if blank_run >= min_gap:
                run_start = row - blank_run
                if div_top is not None:
                    self._add_div(result, div_top, run_start, width)
                    div_top = None
                self._add_gap(result, run_start, row)

            if div_top is None:
                div_top = row
            blank_run = 0

        # Close whatever is still open at the bottom edge. Trailing blank rows
        # are excluded from the last div regardless of their count.
        content_end = height - blank_run
        if div_top is not None:
            div = DivRegion(top=div_top, bottom=content_end, left=0, right=width)
            result.divs.append(div)
            logger.debug(f"发现末尾div区域: 行 {div.top}-{div.bottom}, 高度 {div.height}")
        if blank_run >= min_gap:
            self._add_gap(result, content_end, height)

        logger.info(f"分析完成: 发现 {len(result.divs)} 个div, {len(result.gaps)} 个空白间隔")
        return result

    @staticmethod
    def _add_div(result: AnalysisResult, top: int, bottom: int, width: int) -> None:
        """Append a full-width div covering rows ``[top, bottom)`` and log it."""
        div = DivRegion(top=top, bottom=bottom, left=0, right=width)
        result.divs.append(div)
        logger.debug(f"发现div区域: 行 {div.top}-{div.bottom}, 高度 {div.height}")

    @staticmethod
    def _add_gap(result: AnalysisResult, start_row: int, end_row: int) -> None:
        """Append a gap covering rows ``[start_row, end_row)`` and log it."""
        gap = GapInfo(start_row=start_row, end_row=end_row)
        result.gaps.append(gap)
        logger.debug(f"发现空白间隔: 行 {gap.start_row}-{gap.end_row}, 高度 {gap.height}")

    def calculate_scroll_distance(self, result: AnalysisResult) -> int:
        """Derive the scroll distance from an analysis result.

        Strategy: scroll to the top of the next div, i.e. by the height of the
        first div plus the first gap that starts at or below its bottom, minus
        a small overlap for continuity.

        Args:
            result: analysis of the current screenshot.

        Returns:
            The distance in pixels: 100 when no div was detected, otherwise at
            least 50.
        """
        if not result.divs:
            logger.warning("未检测到div使用默认滚动距离")
            return 100

        first_div = result.divs[0]
        scroll_distance = first_div.height

        # Add the first gap located after the first div, if there is one.
        following_gap = next(
            (gap for gap in result.gaps if gap.start_row >= first_div.bottom),
            None,
        )
        if following_gap is not None:
            scroll_distance += following_gap.height

        # Keep some overlap between consecutive captures, but never scroll
        # less than 50 pixels.
        overlap = min(20, first_div.height // 4)
        scroll_distance = max(scroll_distance - overlap, 50)
        logger.info(f"计算滚动距离: {scroll_distance} 像素")
        return int(scroll_distance)
class OCREngine:
    """OCR engine: sends images to an OCR service and returns the recognized text."""

    def __init__(self, config: Config):
        self.config = config
        # The Umi-OCR client is only created when that engine is selected.
        self.umi_client: Optional[UmiOCRClient] = None
        if config.OCR_ENGINE == "umi":
            self.umi_client = UmiOCRClient(
                host=config.UMI_OCR_HOST,
                port=config.UMI_OCR_PORT
            )

    @staticmethod
    def _to_pil(image: np.ndarray) -> "Image.Image":
        """Convert an array to a PIL image (3-D arrays are converted BGR -> RGB)."""
        if len(image.shape) == 3:
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return Image.fromarray(image)

    def _recognize_with_http(self, image: np.ndarray) -> List[str]:
        """Recognize text through the HTTP OCR endpoint.

        The image is PNG-encoded, base64-encoded and POSTed as
        ``{"image": <base64>}`` to ``config.OCR_API_URL``; the ``"texts"``
        entry of the JSON response is returned.

        Returns:
            The recognized texts, or an empty list on any failure (failures
            are logged, not raised).
        """
        try:
            buffered = io.BytesIO()
            self._to_pil(image).save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()

            response = requests.post(
                self.config.OCR_API_URL,
                json={"image": img_base64},
                timeout=self.config.OCR_TIMEOUT
            )
            response.raise_for_status()
            return response.json().get("texts", [])
        except requests.exceptions.ConnectionError:
            logger.error(f"无法连接到OCR服务: {self.config.OCR_API_URL}")
            return []
        except Exception as e:
            logger.error(f"HTTP OCR识别失败: {e}")
            return []

    def _recognize_with_umi(self, image: np.ndarray) -> List[str]:
        """Recognize text with Umi-OCR via a temporary PNG file.

        Returns:
            The non-empty, stripped lines of the recognized text, or an empty
            list when the client is missing, nothing was recognized, or an
            error occurred (errors are logged, not raised).
        """
        if not self.umi_client:
            logger.error("Umi-OCR客户端未初始化")
            return []

        tmp_path: Optional[str] = None
        try:
            pil_image = self._to_pil(image)

            # Reserve a temp file name, then close the handle before writing
            # to the path so the file is not open twice at the same time.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                tmp_path = tmp_file.name
            pil_image.save(tmp_path, format="PNG")

            text = self.umi_client.recognize_image(tmp_path, timeout=self.config.OCR_TIMEOUT)
            if not text:
                return []
            # Split into lines and drop the empty ones.
            return [line for line in map(str.strip, text.split('\n')) if line]
        except Exception as e:
            logger.error(f"Umi-OCR识别失败: {e}")
            return []
        finally:
            # Best-effort cleanup; also runs when saving the image failed, so
            # the temp file does not leak.
            if tmp_path is not None:
                try:
                    Path(tmp_path).unlink()
                except OSError as e:
                    logger.debug(f"临时文件删除失败: {tmp_path} ({e})")

    def recognize(self, image: np.ndarray) -> List[str]:
        """Run OCR on an image with the configured engine.

        Returns:
            The list of recognized text pieces (empty on failure).
        """
        if self.config.OCR_ENGINE == "umi":
            texts = self._recognize_with_umi(image)
        else:
            texts = self._recognize_with_http(image)
        logger.info(f"OCR识别完成识别到 {len(texts)} 段文字")
        return texts

    def recognize_divs(self, image: np.ndarray, divs: List[DivRegion]) -> List[str]:
        """Run OCR on each div region separately and concatenate the results.

        Args:
            image: the full screenshot.
            divs: regions to crop out of ``image``.

        Returns:
            All recognized texts, in the order of ``divs``.
        """
        all_texts: List[str] = []
        for number, div in enumerate(divs, 1):
            # Crop the div out of the screenshot.
            div_image = image[div.top:div.bottom, div.left:div.right]
            texts = self.recognize(div_image)
            all_texts.extend(texts)
            logger.debug(f"Div {number} OCR结果: {texts}")
        return all_texts

    def check_service(self) -> bool:
        """Check whether the configured OCR service is reachable.

        For the "umi" engine this asks the Umi-OCR client; otherwise it sends
        a GET to the health URL derived from ``config.OCR_API_URL`` and
        expects HTTP 200.
        """
        if self.config.OCR_ENGINE == "umi":
            if not self.umi_client:
                return False
            return self.umi_client.is_service_running()

        from urllib.parse import urlsplit, urlunsplit

        try:
            # Rewrite only the path component. Replacing on the whole URL
            # would also corrupt hosts or queries that contain "/ocr"
            # (e.g. "http://ocr-server/ocr").
            parts = urlsplit(self.config.OCR_API_URL)
            health_url = urlunsplit(
                parts._replace(path=parts.path.replace('/ocr', '/health'))
            )
            response = requests.get(health_url, timeout=2)
            return response.status_code == 200
        except Exception:
            return False
class ScrollCaptureOCR:
    """Main class: drives the capture -> analyze -> OCR -> scroll loop."""

    def __init__(self):
        self.config = Config()
        self.region_selector = RegionSelector()
        self.image_analyzer = ImageAnalyzer(self.config)
        self.ocr_engine = OCREngine(self.config)
        # (left, top, right, bottom) of the area to capture; set per run.
        self.capture_region: Optional[Tuple[int, int, int, int]] = None
        # OCR texts of the previous capture, used to detect the page bottom.
        self.previous_ocr_result: List[str] = []
        self.scroll_count = 0
        # One dict per capture, see save_result().
        self.all_results: List[dict] = []
        # Create the output directory (including missing parents).
        Path(self.config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    def capture_screen(self) -> np.ndarray:
        """Capture the selected screen region.

        Returns:
            The screenshot converted from RGB to BGR channel order.

        Raises:
            ValueError: if no capture region has been set.
        """
        if not self.capture_region:
            raise ValueError("未设置截图区域")
        left, top, right, bottom = self.capture_region
        screenshot = pyautogui.screenshot(region=(left, top, right - left, bottom - top))
        return cv2.cvtColor(np.array(screenshot), cv2.COLOR_RGB2BGR)

    def scroll_screen(self, distance: int):
        """Scroll down inside the capture region, then wait for rendering.

        Does nothing when no capture region is set.
        """
        if not self.capture_region:
            return
        # Move the pointer to the center of the region so the scroll is
        # delivered there.
        left, top, right, bottom = self.capture_region
        pyautogui.moveTo((left + right) // 2, (top + bottom) // 2)
        time.sleep(0.1)
        # NOTE(review): `distance` is computed in pixels, but pyautogui.scroll()
        # takes wheel "clicks" whose on-screen size is platform dependent —
        # confirm the mapping on the target platform.
        pyautogui.scroll(-distance)
        logger.info(f"向下滚动 {distance} 像素")
        # Give the page time to render.
        time.sleep(self.config.SCROLL_DELAY)

    def check_duplicate(self, current_texts: List[str]) -> bool:
        """Tell whether the current OCR result equals the previous one.

        Used to decide that the bottom has been reached. Always False while
        the previous result is empty.
        """
        if not self.previous_ocr_result:
            return False
        # Simple comparison: identical text lists count as a duplicate.
        is_duplicate = current_texts == self.previous_ocr_result
        if is_duplicate:
            logger.info("检测到OCR结果重复可能已到达底部")
        return is_duplicate

    def save_result(self, scroll_index: int, image: np.ndarray, texts: List[str]):
        """Write the screenshot to the output directory and record its OCR result."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        image_path = Path(self.config.OUTPUT_DIR) / f"capture_{timestamp}_{scroll_index:03d}.png"
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(image_path), image):
            logger.warning(f"截图保存失败: {image_path}")
        self.all_results.append({
            "index": scroll_index,
            "timestamp": timestamp,
            "image_path": str(image_path),
            "texts": texts
        })
        logger.info(f"保存结果: {image_path}, 识别文字数: {len(texts)}")

    def save_final_result(self):
        """Dump all recorded results to a timestamped JSON file in the output directory."""
        output_path = Path(self.config.OUTPUT_DIR) / f"all_results_{time.strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(self.all_results, f, ensure_ascii=False, indent=2)
        logger.info(f"所有结果已保存到: {output_path}")
        print(f"\n所有结果已保存到: {output_path}")

    def process_once(self) -> bool:
        """Run one capture/analyze/OCR/scroll cycle.

        Returns:
            False when the loop should stop (no div detected, duplicate OCR
            result, or scroll limit reached), True otherwise.
        """
        logger.info(f"=== 第 {self.scroll_count + 1} 次截屏 ===")
        print(f"\n>>> 第 {self.scroll_count + 1} 次截屏处理中...")

        # 1. Capture the current screen content.
        image = self.capture_screen()
        logger.info(f"截图完成,尺寸: {image.shape[1]}x{image.shape[0]}")

        # 2. Analyze the image to locate div boundaries.
        analysis = self.image_analyzer.analyze(image)
        if not analysis.divs:
            logger.warning("未检测到任何div区域可能已到达底部或区域选择有误")
            print("警告: 未检测到内容区域")
            return False

        # 3. Extract text with OCR and print a short preview.
        current_texts = self.ocr_engine.recognize_divs(image, analysis.divs)
        print(f"识别到 {len(current_texts)} 段文字")
        for i, text in enumerate(current_texts[:3], 1):
            preview = text[:50] + "..." if len(text) > 50 else text
            print(f" [{i}] {preview}")
        if len(current_texts) > 3:
            print(f" ... 还有 {len(current_texts) - 3} 段文字")

        # 4. Save the capture and its texts.
        self.save_result(self.scroll_count, image, current_texts)

        # 5. A repeated OCR result means the bottom has been reached.
        if self.check_duplicate(current_texts):
            print("\n>>> 检测到内容重复,已到达底部,处理完成 <<<")
            return False
        self.previous_ocr_result = current_texts

        # 6./7. Compute the scroll distance and scroll.
        scroll_distance = self.image_analyzer.calculate_scroll_distance(analysis)
        self.scroll_screen(scroll_distance)
        self.scroll_count += 1

        # Enforce the scroll limit.
        if self.scroll_count >= self.config.MAX_SCROLL_COUNT:
            logger.warning(f"达到最大滚动次数限制 ({self.config.MAX_SCROLL_COUNT})")
            print("\n>>> 达到最大滚动次数限制,处理完成 <<<")
            return False
        return True

    def run(self):
        """Print usage, register the hotkey and idle until interrupted with Ctrl+C."""
        print("=" * 60)
        print("滚动截屏OCR工具")
        print("=" * 60)
        print("\n使用说明:")
        print(f"1. 按下热键 {self.config.HOTKEY} 启动")
        print("2. 按住鼠标左键拖动选择截图区域")
        print("3. 程序将自动滚动截屏并进行OCR识别")
        print("4. 当检测到重复内容时自动停止")
        print(f"5. 结果将保存在 '{self.config.OUTPUT_DIR}' 目录")
        print("\n" + "=" * 60)
        logger.info("程序启动,等待热键触发...")
        print(f"\n>>> 等待热键 {self.config.HOTKEY} 启动... <<<")

        # Register the hotkey; the work happens in the callback.
        keyboard.add_hotkey(self.config.HOTKEY, self._on_hotkey)

        # Keep the program alive.
        try:
            while True:
                time.sleep(0.1)
        except KeyboardInterrupt:
            logger.info("程序被用户中断")
            print("\n>>> 程序已停止 <<<")

    def _print_service_help(self):
        """Print instructions for the case that the OCR service is not running."""
        if self.config.OCR_ENGINE == "umi":
            print("✗ Umi-OCR服务未运行")
            print("请先启动Umi-OCR软件并开启HTTP服务")
            print(" 1. 打开Umi-OCR")
            print(" 2. 进入 设置 -> HTTP接口")
            print(" 3. 勾选 '启用HTTP服务'")
            print(f" 4. 确保端口为 {self.config.UMI_OCR_PORT}")
        else:
            print(f"✗ OCR服务未运行: {self.config.OCR_API_URL}")

    def _on_hotkey(self):
        """Hotkey callback: check the service, select a region, run the loop, save."""
        logger.info("热键触发,开始处理")
        print(f"\n{'='*60}")
        print("热键已触发!")

        # Check the OCR service before asking the user for anything.
        print("\n>>> 检查OCR服务... <<<")
        if not self.ocr_engine.check_service():
            self._print_service_help()
            return
        print("✓ OCR服务运行中")

        # Let the user select the capture region.
        try:
            self.capture_region = self.region_selector.select_region()
        except Exception as e:
            logger.error(f"区域选择失败: {e}")
            print(f"区域选择失败: {e}")
            return

        # Reset per-run state.
        self.previous_ocr_result = []
        self.scroll_count = 0
        self.all_results = []
        print("\n>>> 开始自动滚动截屏和OCR识别... <<<")

        # Processing loop.
        try:
            while self.process_once():
                pass
        except Exception as e:
            # loguru has no `exc_info` keyword (extra kwargs are used to format
            # the message); logger.exception() records the traceback.
            logger.exception(f"处理过程中出错: {e}")
            print(f"\n错误: {e}")

        # Save whatever was collected, even after an error.
        if self.all_results:
            self.save_final_result()
            print(f"\n共处理 {len(self.all_results)} 次截屏")
            print(f"结果保存在: {Path(self.config.OUTPUT_DIR).absolute()}")
        print(f"\n{'='*60}")
        print(">>> 等待下一次热键触发... <<<")
        logger.info("处理完成,等待下一次热键触发")
def main():
    """Entry point: build the application object and start it."""
    ScrollCaptureOCR().run()


if __name__ == "__main__":
    main()