# Source file: long-screen-cut/main.py
# (repository viewer listing: 762 lines, 27 KiB, Python)

"""
滚动截屏OCR工具
功能通过热键激活手动框选区域后自动滚动截屏并进行OCR识别
"""
import json
import time
import base64
import io
import tempfile
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Callable, Dict
from pathlib import Path
import cv2
import numpy as np
import requests
from PIL import Image
import pyautogui
import keyboard
import mouse
from loguru import logger
from umi_ocr_client import UmiOCRClient, check_and_wait_for_service
@dataclass
class DivRegion:
    """Rectangular content region ("div") described by its four edges.

    ``text`` defaults to an empty string; ``width`` and ``height`` are derived
    from the edge coordinates.
    """

    top: int
    bottom: int
    left: int
    right: int
    text: str = ""

    @property
    def width(self) -> int:
        """Horizontal extent, ``right - left``."""
        return self.right - self.left

    @property
    def height(self) -> int:
        """Vertical extent, ``bottom - top``."""
        return self.bottom - self.top
@dataclass
class GapInfo:
    """Blank band between content regions, described by its row range."""

    start_row: int
    end_row: int

    @property
    def height(self) -> int:
        """Number of rows spanned, ``end_row - start_row``."""
        span = self.end_row - self.start_row
        return span
@dataclass
class AnalysisResult:
    """Outcome of one image analysis pass.

    Both lists start empty and are independent per instance.
    """

    # Content regions found in the image.
    divs: List[DivRegion] = field(default_factory=list)
    # Blank bands found in the image.
    gaps: List[GapInfo] = field(default_factory=list)
class Config:
    """Static settings shared by all components of the tool."""

    # --- Hotkey -----------------------------------------------------------
    HOTKEY = "ctrl+f9"

    # --- Image analysis ---------------------------------------------------
    # Grey level above which a pixel counts as (near) white.
    GRAY_THRESHOLD = 240
    # Number of consecutive blank rows required to form a gap.
    CONSECUTIVE_LINES = 3
    # Fraction of white pixels a row must exceed to count as blank.
    WHITE_PIXEL_RATIO = 0.9

    # --- OCR --------------------------------------------------------------
    # Engine selector: "umi" uses Umi-OCR, "http" uses the HTTP endpoint.
    OCR_ENGINE = "umi"
    # HTTP OCR service address (used when OCR_ENGINE is "http").
    OCR_API_URL = "http://localhost:8000/ocr"
    # Timeout for a single OCR request.
    OCR_TIMEOUT = 30

    # --- Umi-OCR ----------------------------------------------------------
    UMI_OCR_HOST = "127.0.0.1"
    UMI_OCR_PORT = 1224

    # --- Scrolling --------------------------------------------------------
    # Seconds to wait after scrolling so the page can re-render.
    SCROLL_DELAY = 0.5
    # Upper bound on scroll iterations, guarding against endless loops.
    MAX_SCROLL_COUNT = 100

    # --- Output -----------------------------------------------------------
    OUTPUT_DIR = "output"
class RegionSelector:
    """Interactive selector that lets the user drag out the capture rectangle."""

    def __init__(self):
        self.start_pos: Optional[Tuple[int, int]] = None
        self.end_pos: Optional[Tuple[int, int]] = None
        self.is_selecting = False

    def select_region(self) -> Tuple[int, int, int, int]:
        """Block until the user drags a rectangle with the left mouse button.

        The position at button press and the position at button release are
        taken as two opposite corners, so the drag may go in any direction.

        Returns:
            ``(left, top, right, bottom)`` in the coordinates reported by
            ``mouse.get_position()``.

        Raises:
            ValueError: if the press and release positions share an x or y
                coordinate, i.e. the selected rectangle has no area.
        """
        logger.info("请按住鼠标左键拖动选择区域...")
        print("\n>>> 请按住鼠标左键拖动选择截图区域,释放后确定 <<<")

        # Poll until the left button goes down.
        while not mouse.is_pressed(button='left'):
            time.sleep(0.01)
        self.start_pos = mouse.get_position()
        self.is_selecting = True
        logger.info(f"选择开始位置: {self.start_pos}")

        # Poll until the left button is released again.
        while mouse.is_pressed(button='left'):
            time.sleep(0.01)
        self.end_pos = mouse.get_position()
        self.is_selecting = False
        logger.info(f"选择结束位置: {self.end_pos}")

        # Normalise the two corners into left/top/right/bottom order.
        (x0, y0), (x1, y1) = self.start_pos, self.end_pos
        left, right = min(x0, x1), max(x0, x1)
        top, bottom = min(y0, y1), max(y0, y1)
        logger.info(f"选定区域: ({left}, {top}, {right}, {bottom}), 尺寸: {right-left}x{bottom-top}")

        # A plain click (no drag) yields an empty rectangle; fail here with a
        # clear message instead of letting the capture step fail later.
        if right <= left or bottom <= top:
            raise ValueError(f"选择的区域无效: 宽度或高度为0 ({right - left}x{bottom - top})")

        print(f"已选择区域: 左上角({left}, {top}), 右下角({right}, {bottom})")
        return left, top, right, bottom
class ImageAnalyzer:
    """Locate content regions ("divs") and the blank bands between them.

    A row is *blank* when more than ``WHITE_PIXEL_RATIO`` of its pixels are
    brighter than ``GRAY_THRESHOLD``.  A run of at least ``CONSECUTIVE_LINES``
    blank rows is a gap; the non-blank stretches between gaps are divs.
    All row ranges produced here are half-open: ``[top, bottom)`` for divs
    and ``[start_row, end_row)`` for gaps, so they can be used directly as
    slice bounds.
    """

    def __init__(self, config: Config):
        self.config = config

    def analyze(self, image: np.ndarray) -> AnalysisResult:
        """Split *image* vertically into divs and gaps.

        Args:
            image: 2-D greyscale array, or 3-D array that is converted with
                ``cv2.COLOR_BGR2GRAY``.

        Returns:
            An ``AnalysisResult`` whose ``divs`` and ``gaps`` are ordered from
            top to bottom.  Divs span the full image width.  Blank runs
            shorter than ``CONSECUTIVE_LINES`` stay inside the surrounding
            div, except at the very top/bottom of a div where they are
            trimmed.  A qualifying blank run touching the top or bottom edge
            of the image is reported as a gap as well.  An image with zero
            rows or columns yields an empty result.
        """
        result = AnalysisResult()

        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        height, width = gray.shape
        logger.debug(f"分析图像尺寸: {width}x{height}")

        if height > 0 and width > 0:
            blank_rows = self._blank_row_mask(gray)
            # Guard against a non-positive setting: a gap needs at least one row.
            min_gap = max(1, int(self.config.CONSECUTIVE_LINES))

            cursor = 0  # first row not yet assigned to a div or gap
            for start, end in self._blank_runs(blank_rows):
                if end - start < min_gap:
                    continue  # too short to separate two divs
                div = self._content_region(blank_rows, cursor, start, width)
                if div is not None:
                    result.divs.append(div)
                    logger.debug(f"发现div区域: 行 {div.top}-{div.bottom}, 高度 {div.height}")
                gap = GapInfo(start_row=start, end_row=end)
                result.gaps.append(gap)
                logger.debug(f"发现空白间隔: 行 {gap.start_row}-{gap.end_row}, 高度 {gap.height}")
                cursor = end

            # Whatever follows the last gap is the trailing div (if any content).
            tail = self._content_region(blank_rows, cursor, height, width)
            if tail is not None:
                result.divs.append(tail)
                logger.debug(f"发现末尾div区域: 行 {tail.top}-{tail.bottom}, 高度 {tail.height}")

        logger.info(f"分析完成: 发现 {len(result.divs)} 个div, {len(result.gaps)} 个空白间隔")
        return result

    def _blank_row_mask(self, gray: np.ndarray) -> np.ndarray:
        """Return a boolean vector with one entry per row, True where the row is blank."""
        white_per_row = np.count_nonzero(gray > self.config.GRAY_THRESHOLD, axis=1)
        return (white_per_row / gray.shape[1]) > self.config.WHITE_PIXEL_RATIO

    @staticmethod
    def _blank_runs(blank_rows: np.ndarray) -> List[Tuple[int, int]]:
        """Return ``(start, end)`` half-open row ranges of consecutive True entries."""
        # Pad with False on both sides so every run has a rising and a falling edge.
        padded = np.concatenate(([0], blank_rows.astype(np.int8), [0]))
        edges = np.flatnonzero(np.diff(padded)).tolist()
        return list(zip(edges[0::2], edges[1::2]))

    @staticmethod
    def _content_region(blank_rows: np.ndarray, lo: int, hi: int, width: int) -> Optional[DivRegion]:
        """Build the div covering the non-blank rows inside ``[lo, hi)``, or None if there are none."""
        content = np.flatnonzero(~blank_rows[lo:hi])
        if content.size == 0:
            return None
        return DivRegion(
            top=lo + int(content[0]),
            bottom=lo + int(content[-1]) + 1,
            left=0,
            right=width,
        )

    def calculate_scroll_distance(self, result: AnalysisResult) -> int:
        """Compute a scroll distance that brings the next div to the top.

        The distance is the height of the first div plus the height of the
        first gap that starts at or below that div's bottom, minus a small
        overlap (at most 20 rows, at most a quarter of the div height), and
        never less than 50.  Without any div a fixed 100 is returned.
        """
        if not result.divs:
            logger.warning("未检测到div使用默认滚动距离")
            return 100

        first_div = result.divs[0]
        following_gap = next(
            (gap.height for gap in result.gaps if gap.start_row >= first_div.bottom),
            0,
        )
        # Keep some overlap between consecutive captures for continuity.
        overlap = min(20, first_div.height // 4)
        scroll_distance = max(first_div.height + following_gap - overlap, 50)
        logger.info(f"计算滚动距离: {scroll_distance} 像素")
        return int(scroll_distance)
class OCREngine:
    """OCR front end that dispatches to Umi-OCR or to a generic HTTP OCR service.

    The backend is chosen by ``config.OCR_ENGINE``: ``"umi"`` uses a
    ``UmiOCRClient``; any other value uses the HTTP endpoint at
    ``config.OCR_API_URL``.
    """

    def __init__(self, config: Config):
        self.config = config
        self.umi_client: Optional[UmiOCRClient] = None
        if config.OCR_ENGINE == "umi":
            self.umi_client = UmiOCRClient(
                host=config.UMI_OCR_HOST,
                port=config.UMI_OCR_PORT
            )

    @staticmethod
    def _to_pil(image: np.ndarray):
        """Convert an array to a PIL image; 3-D input is converted with ``cv2.COLOR_BGR2RGB``."""
        if len(image.shape) == 3:
            return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        return Image.fromarray(image)

    def _recognize_with_http(self, image: np.ndarray) -> List[str]:
        """Run OCR through the HTTP endpoint.

        The image is PNG-encoded, base64-wrapped and posted as
        ``{"image": ...}``; the ``"texts"`` entry of the JSON reply is
        returned.  Any failure is logged and yields an empty list.
        """
        try:
            buffered = io.BytesIO()
            self._to_pil(image).save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode()

            response = requests.post(
                self.config.OCR_API_URL,
                json={"image": img_base64},
                timeout=self.config.OCR_TIMEOUT
            )
            response.raise_for_status()
            return response.json().get("texts", [])
        except requests.exceptions.ConnectionError:
            logger.error(f"无法连接到OCR服务: {self.config.OCR_API_URL}")
            return []
        except Exception as e:
            logger.error(f"HTTP OCR识别失败: {e}")
            return []

    def _recognize_with_umi(self, image: np.ndarray) -> List[str]:
        """Run OCR through Umi-OCR via a temporary PNG file.

        Returns the non-empty, stripped lines of the recognised text.  Any
        failure is logged and yields an empty list.
        """
        if not self.umi_client:
            logger.error("Umi-OCR客户端未初始化")
            return []
        try:
            pil_image = self._to_pil(image)
            # Reserve a path and close the handle before writing to it.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
                tmp_path = tmp_file.name
            try:
                # Saving happens inside the try so the file is removed even
                # when writing the PNG fails.
                pil_image.save(tmp_path, format="PNG")
                text = self.umi_client.recognize_image(tmp_path, timeout=self.config.OCR_TIMEOUT)
                if not text:
                    return []
                return [line.strip() for line in text.split('\n') if line.strip()]
            finally:
                # Best-effort cleanup: a leftover temp file must not fail the OCR call.
                try:
                    Path(tmp_path).unlink()
                except OSError as cleanup_error:
                    logger.debug(f"临时文件删除失败: {tmp_path}: {cleanup_error}")
        except Exception as e:
            logger.error(f"Umi-OCR识别失败: {e}")
            return []

    def recognize(self, image: np.ndarray) -> List[str]:
        """Run OCR on *image* with the configured backend and return the recognised text pieces."""
        if self.config.OCR_ENGINE == "umi":
            texts = self._recognize_with_umi(image)
        else:
            texts = self._recognize_with_http(image)
        logger.info(f"OCR识别完成识别到 {len(texts)} 段文字")
        return texts

    def recognize_divs(self, image: np.ndarray, divs: List[DivRegion]) -> List[str]:
        """Run OCR on each div region of *image* separately and concatenate the results in order."""
        all_texts = []
        for i, div in enumerate(divs):
            div_image = image[div.top:div.bottom, div.left:div.right]
            texts = self.recognize(div_image)
            all_texts.extend(texts)
            logger.debug(f"Div {i+1} OCR结果: {texts}")
        return all_texts

    def check_service(self) -> bool:
        """Return True when the configured OCR backend appears to be reachable.

        For Umi-OCR this delegates to the client's ``is_service_running()``.
        For the HTTP backend a GET is sent to the OCR URL with its last
        ``/ocr`` replaced by ``/health`` (the URL is used unchanged when it
        contains no ``/ocr``); only status 200 counts as available.
        """
        if self.config.OCR_ENGINE == "umi":
            if not self.umi_client:
                return False
            return self.umi_client.is_service_running()

        url = self.config.OCR_API_URL
        # Replace only the final "/ocr": a plain str.replace would also
        # rewrite a host name such as "//ocr-server".
        prefix, marker, suffix = url.rpartition('/ocr')
        health_url = f"{prefix}/health{suffix}" if marker else url
        try:
            response = requests.get(health_url, timeout=2)
        except requests.RequestException:
            return False
        return response.status_code == 200
class ScrollCaptureOCR:
    """Main controller: select a region, then capture, analyse, OCR and scroll until done."""

    def __init__(self):
        self.config = Config()
        self.region_selector = RegionSelector()
        self.image_analyzer = ImageAnalyzer(self.config)
        self.ocr_engine = OCREngine(self.config)
        self.capture_region: Optional[Tuple[int, int, int, int]] = None
        self._reset_state()
        # Create the output directory up front.
        Path(self.config.OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

    def _reset_state(self) -> None:
        """Reset every per-run field so consecutive hotkey runs do not influence each other."""
        self.previous_ocr_result: List[str] = []
        self.scroll_count = 0
        self.all_results: List[dict] = []
        # Records of every div processed so far in this run.
        self.processed_divs: List[Dict] = []
        # Signature of the most recently processed div.
        self.last_div_signature: Optional[str] = None
        # Sum of all scroll distances requested in this run.
        self.total_scroll_distance: int = 0
        # True until the first capture of a run has been processed.
        self.is_first_capture: bool = True
        # Per-div result records collected by save_div_result.
        self._div_results: List[dict] = []

    def capture_screen(self) -> np.ndarray:
        """Grab the selected screen region and return it as a BGR array.

        Raises:
            ValueError: if no capture region has been set.
        """
        if not self.capture_region:
            raise ValueError("未设置截图区域")
        left, top, right, bottom = self.capture_region
        screenshot = pyautogui.screenshot(region=(left, top, right - left, bottom - top))
        return cv2.cvtColor(np.array(screenshot), cv2.COLOR_RGB2BGR)

    def get_div_signature(self, div: DivRegion, image: np.ndarray) -> str:
        """Return a 16-hex-character signature of the div's pixels.

        The div is cropped from *image*, shrunk to 32x32, converted to grey
        and thresholded at its mean; the MD5 of the resulting bit string is
        truncated to 16 characters.  Equal signatures are treated as "same
        div" by the caller.
        """
        import hashlib
        div_image = image[div.top:div.bottom, div.left:div.right]
        # Shrink first so the hash is cheap to compute.
        small = cv2.resize(div_image, (32, 32))
        gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY)
        avg = gray.mean()
        hash_str = ''.join('1' if p > avg else '0' for p in gray.flatten())
        return hashlib.md5(hash_str.encode()).hexdigest()[:16]

    def scroll_screen(self, distance: int):
        """Move the pointer to the centre of the capture region and scroll down by *distance*.

        Does nothing when no capture region is set.
        """
        if not self.capture_region:
            return
        left, top, right, bottom = self.capture_region
        pyautogui.moveTo((left + right) // 2, (top + bottom) // 2)
        time.sleep(0.1)
        # NOTE(review): pyautogui.scroll takes wheel "clicks"; whether one unit
        # equals one pixel depends on platform and application - confirm.
        pyautogui.scroll(-distance)
        logger.info(f"向下滚动 {distance} 像素")
        # Give the page time to re-render before the next capture.
        time.sleep(self.config.SCROLL_DELAY)

    def check_duplicate(self, current_texts: List[str]) -> bool:
        """Return True when *current_texts* equals the previous capture's OCR result.

        Used as an end-of-page signal.  Always False while there is no
        previous result.
        """
        if not self.previous_ocr_result:
            return False
        is_duplicate = current_texts == self.previous_ocr_result
        if is_duplicate:
            logger.info("检测到OCR结果重复可能已到达底部")
        return is_duplicate

    def save_result(self, scroll_index: int, image: np.ndarray, texts: List[str]):
        """Write the full capture as a PNG and append its record to ``all_results``."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        image_path = Path(self.config.OUTPUT_DIR) / f"capture_{timestamp}_{scroll_index:03d}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(image_path), image):
            logger.warning(f"图片保存失败: {image_path}")
        self.all_results.append({
            "index": scroll_index,
            "timestamp": timestamp,
            "image_path": str(image_path),
            "texts": texts
        })
        logger.info(f"保存结果: {image_path}, 识别文字数: {len(texts)}")

    def save_final_result(self):
        """Dump ``all_results`` to a timestamped JSON file in the output directory."""
        output_path = Path(self.config.OUTPUT_DIR) / f"all_results_{time.strftime('%Y%m%d_%H%M%S')}.json"
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(self.all_results, f, ensure_ascii=False, indent=2)
        logger.info(f"所有结果已保存到: {output_path}")
        print(f"\n所有结果已保存到: {output_path}")

    def process_once(self) -> bool:
        """Run one capture / analyse / OCR / scroll cycle.

        On the first capture of a run every detected div is processed; on
        later captures divs whose signature was already seen are skipped.
        The scroll distance is derived from the last newly processed div.

        Returns:
            True to continue with another cycle, False when processing should
            stop (no divs, no new divs, repeated OCR text, or the scroll
            limit was reached).
        """
        logger.info(f"=== 第 {self.scroll_count + 1} 次截屏 ===")
        print(f"\n>>> 第 {self.scroll_count + 1} 次截屏处理中...")

        # 1. Capture the current screen content.
        image = self.capture_screen()
        height, width = image.shape[:2]
        logger.info(f"截图完成,尺寸: {width}x{height}")

        # 2. Locate div boundaries.
        analysis = self.image_analyzer.analyze(image)
        if not analysis.divs:
            logger.warning("未检测到任何div区域可能已到达底部或区域选择有误")
            print("警告: 未检测到内容区域")
            return False
        logger.info(f"检测到 {len(analysis.divs)} 个div区域")
        print(f"检测到 {len(analysis.divs)} 个内容区域")

        # 3. OCR the divs that have not been processed yet.
        # On the first capture nothing is skipped, so the lookup set stays empty.
        if self.is_first_capture:
            known_signatures = set()
        else:
            known_signatures = {p.get('signature') for p in self.processed_divs}
        new_divs = []
        current_texts = []
        for i, div in enumerate(analysis.divs):
            div_signature = self.get_div_signature(div, image)
            if div_signature in known_signatures:
                logger.info(f"跳过已处理的div {i+1}")
                continue

            new_divs.append({
                'div': div,
                'signature': div_signature,
                'index': i
            })
            logger.info(f"处理新的div {i+1},位置: {div.top}-{div.bottom}")
            print(f" 处理新区域 {i+1}/{len(analysis.divs)}...")

            div_image = image[div.top:div.bottom, div.left:div.right]
            texts = self.ocr_engine.recognize(div_image)
            div.text = "\n".join(texts)
            current_texts.extend(texts)
            self.save_div_result(self.scroll_count, i, div_image, texts, div)

            self.processed_divs.append({
                'signature': div_signature,
                'text': div.text,
                'scroll_count': self.scroll_count,
                'div_index': i
            })
            if not self.is_first_capture:
                # Later divs of the same capture are compared against this one too.
                known_signatures.add(div_signature)
            self.last_div_signature = div_signature
            logger.info(f" 识别到 {len(texts)} 段文字")

        # Nothing new means the end of the content has been reached.
        if not new_divs:
            logger.info("没有新的div需要处理可能已到达底部")
            print("✓ 没有新的内容需要处理")
            return False

        print(f"✓ 本次处理 {len(new_divs)} 个新区域,共识别 {len(current_texts)} 段文字")
        for i, text in enumerate(current_texts[:3], 1):
            preview = text[:50] + "..." if len(text) > 50 else text
            print(f" [{i}] {preview}")
        if len(current_texts) > 3:
            print(f" ... 还有 {len(current_texts) - 3} 段文字")

        # 4. Save the full capture.
        self.save_result(self.scroll_count, image, current_texts)

        # 5. Stop when the OCR text repeats the previous capture.
        if self.check_duplicate(current_texts):
            print("\n>>> 检测到内容重复,已到达底部,处理完成 <<<")
            return False
        self.previous_ocr_result = current_texts

        # 6. Scroll distance is based on the last new div (new_divs is non-empty here).
        scroll_distance = self.calculate_scroll_based_on_last_div(
            new_divs[-1]['div'], analysis.gaps, height
        )

        # 7. Scroll and update counters.
        self.scroll_screen(scroll_distance)
        self.total_scroll_distance += scroll_distance
        self.scroll_count += 1
        self.is_first_capture = False

        if self.scroll_count >= self.config.MAX_SCROLL_COUNT:
            logger.warning(f"达到最大滚动次数限制 ({self.config.MAX_SCROLL_COUNT})")
            print("\n>>> 达到最大滚动次数限制,处理完成 <<<")
            return False
        return True

    def calculate_scroll_based_on_last_div(self, last_div: DivRegion, gaps: List[GapInfo], image_height: int) -> int:
        """Compute the scroll distance from the last processed div.

        The distance is the div's bottom row plus the height of the first gap
        starting at or below it, raised to at least 50 and then capped at 90%
        of *image_height* so consecutive captures overlap.
        """
        last_div_bottom = last_div.bottom
        last_gap_height = next(
            (gap.height for gap in gaps if gap.start_row >= last_div.bottom),
            0,
        )
        scroll_distance = last_div_bottom + last_gap_height
        # Always move by a minimum amount so the loop cannot get stuck.
        scroll_distance = max(scroll_distance, 50)
        # Never scroll a full page; keep some overlap.
        scroll_distance = min(scroll_distance, int(image_height * 0.9))
        logger.info(f"滚动距离计算: 最后div底部={last_div_bottom}, "
                    f"空白间隔={last_gap_height}, 滚动距离={scroll_distance}")
        print(f" 滚动距离: {scroll_distance} 像素 (基于最后一个内容区域)")
        return int(scroll_distance)

    def save_div_result(self, scroll_index: int, div_index: int, image: np.ndarray, texts: List[str], div: DivRegion):
        """Write one div's image to ``scroll_NNN/`` and remember its record in ``_div_results``."""
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        div_dir = Path(self.config.OUTPUT_DIR) / f"scroll_{scroll_index:03d}"
        div_dir.mkdir(parents=True, exist_ok=True)
        image_path = div_dir / f"div_{div_index:02d}_{timestamp}.png"
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(image_path), image):
            logger.warning(f"图片保存失败: {image_path}")
        self._div_results.append({
            "scroll_index": scroll_index,
            "div_index": div_index,
            "timestamp": timestamp,
            "div_position": {"top": div.top, "bottom": div.bottom, "left": div.left, "right": div.right},
            "image_path": str(image_path),
            "texts": texts
        })
        logger.debug(f"保存div结果: scroll={scroll_index}, div={div_index}, 文字数={len(texts)}")

    def run(self):
        """Print usage, register the hotkey and idle until interrupted with Ctrl+C."""
        print("=" * 60)
        print("滚动截屏OCR工具")
        print("=" * 60)
        print("\n使用说明:")
        print(f"1. 按下热键 {self.config.HOTKEY} 启动")
        print("2. 按住鼠标左键拖动选择截图区域")
        print("3. 程序将自动滚动截屏并进行OCR识别")
        print("4. 当检测到重复内容时自动停止")
        print(f"5. 结果将保存在 '{self.config.OUTPUT_DIR}' 目录")
        print("\n" + "=" * 60)
        logger.info("程序启动,等待热键触发...")
        print(f"\n>>> 等待热键 {self.config.HOTKEY} 启动... <<<")

        keyboard.add_hotkey(self.config.HOTKEY, self._on_hotkey)

        # Keep the main thread alive; the work happens in the hotkey callback.
        try:
            while True:
                time.sleep(0.1)
        except KeyboardInterrupt:
            logger.info("程序被用户中断")
            print("\n>>> 程序已停止 <<<")

    def _on_hotkey(self):
        """Hotkey callback: check the OCR service, select a region and run the capture loop."""
        logger.info("热键触发,开始处理")
        print(f"\n{'='*60}")
        print("热键已触发!")

        print("\n>>> 检查OCR服务... <<<")
        if not self.ocr_engine.check_service():
            if self.config.OCR_ENGINE == "umi":
                print("✗ Umi-OCR服务未运行")
                print("请先启动Umi-OCR软件并开启HTTP服务")
                print(" 1. 打开Umi-OCR")
                print(" 2. 进入 设置 -> HTTP接口")
                print(" 3. 勾选 '启用HTTP服务'")
                print(f" 4. 确保端口为 {self.config.UMI_OCR_PORT}")
            else:
                print(f"✗ OCR服务未运行: {self.config.OCR_API_URL}")
            return
        print("✓ OCR服务运行中")

        try:
            self.capture_region = self.region_selector.select_region()
        except Exception as e:
            logger.error(f"区域选择失败: {e}")
            print(f"区域选择失败: {e}")
            return

        # Start from a clean slate, including the processed-div history;
        # otherwise a second run would skip divs seen in the previous one.
        self._reset_state()
        print("\n>>> 开始自动滚动截屏和OCR识别... <<<")

        try:
            while self.process_once():
                pass
        except Exception as e:
            # logger.exception records the traceback with the message.
            logger.exception(f"处理过程中出错: {e}")
            print(f"\n错误: {e}")

        if self.all_results:
            self.save_final_result()
            print(f"\n共处理 {len(self.all_results)} 次截屏")
            print(f"结果保存在: {Path(self.config.OUTPUT_DIR).absolute()}")
        print(f"\n{'='*60}")
        print(">>> 等待下一次热键触发... <<<")
        logger.info("处理完成,等待下一次热键触发")
def main():
    """Program entry point: build the application object and start it."""
    ScrollCaptureOCR().run()


if __name__ == "__main__":
    main()