Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 95 additions & 30 deletions pero_ocr/core/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def __init__(self, id: str = None,
self.transcription_confidence = transcription_confidence
self.category = category

self.embeddings = []
self.metadata = {}
self.graphical_metadata = None

def get_dense_logits(self, zero_logit_value: int = -80):
dense_logits = self.logits.toarray()
dense_logits[dense_logits == 0] = zero_logit_value
Expand Down Expand Up @@ -203,11 +207,8 @@ def from_pagexml_parse_custom(self, custom_str):
heights = heights_array
self.heights = heights.tolist()

def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: ALTOVersion, next_line=None,
previous_line=None, word_splitters=["-"]):
if self.character_confidences is None or self.transcription_confidence is None:
self.calculate_confidences()

def to_altoxml(self, text_block, tags, mods_namespace, arabic_helper, min_line_confidence, version: ALTOVersion,
next_line=None, previous_line=None, word_splitters=["-"]):
if self.transcription_confidence is not None and self.transcription_confidence < min_line_confidence:
return

Expand Down Expand Up @@ -238,6 +239,10 @@ def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: AL
if self.transcription_confidence is not None:
string.set("WC", str(round(self.transcription_confidence, 2)))

if self.graphical_metadata is not None:
tag_references = [metadata.tag_id for metadata in self.graphical_metadata]
text_line.set("TAGREFS", ' '.join(tag_references))

def get_labels(self):
chars = [i for i in range(len(self.characters))]
char_to_num = dict(zip(self.characters, chars))
Expand Down Expand Up @@ -465,6 +470,10 @@ def __init__(self, id: str,
self.transcription = None
self.detection_confidence = detection_confidence

self.embeddings = []
self.metadata = {}
self.graphical_metadata = None

def get_lines_of_category(self, categories: Union[str, list]):
if isinstance(categories, str):
categories = [categories]
Expand All @@ -481,10 +490,10 @@ def get_polygon_bounding_box(self) -> Tuple[int, int, int, int]:
"""Get bounding box of region polygon which includes all polygon points.
:return: Tuple[int, int, int, int]: (x_min, y_min, x_max, y_max)
"""
x_min = min(self.polygon[:, 0])
x_max = max(self.polygon[:, 0])
y_min = min(self.polygon[:, 1])
y_max = max(self.polygon[:, 1])
x_min = round(min(self.polygon[:, 0]))
x_max = round(max(self.polygon[:, 0]))
y_min = round(min(self.polygon[:, 1]))
y_max = round(max(self.polygon[:, 1]))

return x_min, y_min, x_max, y_max

Expand Down Expand Up @@ -550,33 +559,49 @@ def from_pagexml(cls, region_element: ET.SubElement, schema):

return layout_region

def to_altoxml(self, print_space, arabic_helper, min_line_confidence, print_space_coords: Tuple[int, int, int, int],
version: ALTOVersion, word_splitters=["-"]) -> Tuple[int, int, int, int]:
def to_altoxml(self, print_space, tags, mods_namespace, arabic_helper, min_line_confidence,
print_space_coords: Tuple[int, int, int, int], version: ALTOVersion, word_splitters=["-"]) -> Tuple[int, int, int, int]:
print_space_height, print_space_width, print_space_vpos, print_space_hpos = print_space_coords

text_block = ET.SubElement(print_space, "TextBlock")
text_block.set("ID", 'block_{}'.format(self.id))
if self.category is None or self.category == 'text':
block = ET.SubElement(print_space, "TextBlock")

text_block_height, text_block_width, text_block_vpos, text_block_hpos = get_hwvh(self.polygon)
text_block.set("HEIGHT", str(int(text_block_height)))
text_block.set("WIDTH", str(int(text_block_width)))
text_block.set("VPOS", str(int(text_block_vpos)))
text_block.set("HPOS", str(int(text_block_hpos)))
if self.category is None or self.category == 'text':
block.set("ID", 'block_{}'.format(self.id))
else:
block.set("ID", self.id)

print_space_height = max([print_space_vpos + print_space_height, text_block_vpos + text_block_height])
print_space_width = max([print_space_hpos + print_space_width, text_block_hpos + text_block_width])
print_space_vpos = min([print_space_vpos, text_block_vpos])
print_space_hpos = min([print_space_hpos, text_block_hpos])
else:
from anno_page.core.layout import region_to_altoxml
block = region_to_altoxml(self, print_space)

block_height, block_width, block_vpos, block_hpos = get_hwvh(self.polygon)
block.set("HEIGHT", str(int(block_height)))
block.set("WIDTH", str(int(block_width)))
block.set("VPOS", str(int(block_vpos)))
block.set("HPOS", str(int(block_hpos)))

print_space_height = max([print_space_vpos + print_space_height, block_vpos + block_height])
print_space_width = max([print_space_hpos + print_space_width, block_hpos + block_width])
print_space_vpos = min([print_space_vpos, block_vpos])
print_space_hpos = min([print_space_hpos, block_hpos])
print_space_height = print_space_height - print_space_vpos
print_space_width = print_space_width - print_space_hpos

if self.graphical_metadata is not None:
self.graphical_metadata.to_altoxml(tags,
category=self.category,
bounding_box=self.get_polygon_bounding_box(),
confidence=self.detection_confidence,
mods_namespace=mods_namespace)

for i, line in enumerate(self.lines):
if not line.transcription or line.transcription.strip() == "":
continue

previous_line = self.lines[i - 1] if i > 0 else None
next_line = self.lines[i + 1] if i + 1 < len(self.lines) else None
line.to_altoxml(text_block, arabic_helper, min_line_confidence, version, next_line=next_line,
line.to_altoxml(block, tags, mods_namespace, arabic_helper, min_line_confidence, version, next_line=next_line,
previous_line=previous_line, word_splitters=word_splitters)
return print_space_height, print_space_width, print_space_vpos, print_space_hpos

Expand All @@ -597,6 +622,14 @@ def from_altoxml(cls, text_block: ET.SubElement, schema):

return region_layout

def get_all_embeddings(self):
    """Collect the embeddings of every line, followed by this object's own.

    :return: new list holding the items of each ``line.embeddings`` (in the
             order of ``self.lines``) and then the items of ``self.embeddings``.
    """
    collected = [embedding for line in self.lines for embedding in line.embeddings]
    collected.extend(self.embeddings)
    return collected


def get_coords_from_pagexml(coords_element, schema):
if 'points' in coords_element.attrib:
Expand Down Expand Up @@ -726,6 +759,9 @@ def __init__(self, id: str = None, page_size: Tuple[int, int] = (0, 0), file: st
self.reading_order = None
self.confidence = None

self.embeddings = []
self.metadata = {}

if file is not None:
self.from_pagexml(file)

Expand Down Expand Up @@ -803,7 +839,11 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u
min_line_confidence: float = 0, version: ALTOVersion = ALTOVersion.ALTO_v2_x,
word_splitters=["-"]):
arabic_helper = ArabicHelper()

mods_namespace_url = "http://www.loc.gov/mods/v3"

NSMAP = {"xlink": 'http://www.w3.org/1999/xlink',
"mods": mods_namespace_url,
"xsi": 'http://www.w3.org/2001/XMLSchema-instance'}
root = ET.Element("alto", nsmap=NSMAP)

Expand All @@ -821,8 +861,9 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u
if ocr_processing_element is not None:
description.append(ocr_processing_element)
else:
ocr_processing_element = create_ocr_processing_element()
ocr_processing_element = create_ocr_processing_element(alto_version=version)
description.append(ocr_processing_element)
tags = ET.SubElement(root, "Tags")
layout = ET.SubElement(root, "Layout")
page = ET.SubElement(layout, "Page")
if page_uuid is not None:
Expand All @@ -845,9 +886,19 @@ def to_altoxml_string(self, ocr_processing_element: ET.SubElement = None, page_u
print_space_hpos = self.page_size[1]
print_space_coords = (print_space_height, print_space_width, print_space_vpos, print_space_hpos)

for block in self.regions:
print_space_coords = block.to_altoxml(print_space, arabic_helper, min_line_confidence, print_space_coords,
version, word_splitters=word_splitters)
text_regions = []
nontext_regions = []
for region in self.regions:
if region.category is None or region.category == 'text':
text_regions.append(region)
else:
nontext_regions.append(region)

for region in nontext_regions:
print_space_coords = region.to_altoxml(print_space, tags, mods_namespace_url, arabic_helper, min_line_confidence, print_space_coords, version)

for region in text_regions:
print_space_coords = region.to_altoxml(print_space, tags, mods_namespace_url, arabic_helper, min_line_confidence, print_space_coords, version, word_splitters=word_splitters)

print_space_height, print_space_width, print_space_vpos, print_space_hpos = print_space_coords

Expand Down Expand Up @@ -1119,6 +1170,14 @@ def rename_region_id(self, old_id, new_id):
else:
raise ValueError(f'Region with id {old_id} not found.')

def get_all_embeddings(self):
    """Gather embeddings from all regions, followed by this object's own.

    :return: new list with the items of each ``region.get_all_embeddings()``
             (in the order of ``self.regions``) and then the items of
             ``self.embeddings``.
    """
    gathered = []
    for region in self.regions:
        gathered.extend(region.get_all_embeddings())
    gathered.extend(self.embeddings)
    return gathered


def draw_lines(img, lines, color=(255, 0, 0), circles=(False, False, False), close=False, thickness=2):
"""Draw a line into image.
Expand Down Expand Up @@ -1189,10 +1248,16 @@ def create_ocr_processing_element(id: str = "IdOcr",
software_creator_str: str = "Project PERO",
software_name_str: str = "PERO OCR",
software_version_str: str = "v0.1.0",
processing_datetime=None):
ocr_processing = ET.Element("OCRProcessing")
processing_datetime=None,
alto_version: ALTOVersion = ALTOVersion.ALTO_v2_x):
if alto_version == ALTOVersion.ALTO_v4_4:
ocr_processing = ET.Element("Processing")
ocr_processing_step = ocr_processing
else:
ocr_processing = ET.Element("OCRProcessing")
ocr_processing_step = ET.SubElement(ocr_processing, "ocrProcessingStep")

ocr_processing.set("ID", id)
ocr_processing_step = ET.SubElement(ocr_processing, "ocrProcessingStep")
processing_date_time = ET.SubElement(ocr_processing_step, "processingDateTime")
if processing_datetime is not None:
processing_date_time.text = processing_datetime
Expand Down
42 changes: 38 additions & 4 deletions pero_ocr/document_ocr/page_parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np

import logging
from multiprocessing import Pool
import math
import time
import re
from typing import Union, Tuple, List
from collections import defaultdict

import torch.cuda

Expand Down Expand Up @@ -80,6 +80,28 @@ def page_decoder_factory(config, device, config_path=''):
categories=categories)


def operations_factory(config, device, config_path=''):
    """Build the operation engine selected by ``config['METHOD']``.

    The engine class is imported lazily, only for the selected method, so the
    ``anno_page`` modules of the other engines are never loaded.

    :param config: mapping-like object providing the 'METHOD' key; it is also
                   handed to the engine constructor
    :param device: passed through to the engine constructor
    :param config_path: passed through to the engine constructor
    :return: the constructed engine instance
    :raises ValueError: if ``config['METHOD']`` is not a known method name
    """
    method = config['METHOD']

    if method == 'CHATGPT_IMAGE_CAPTIONING':
        from anno_page.engines.captioning import ChatGPTImageCaptioning as engine_class
    elif method == 'CLIP_EMBEDDING':
        from anno_page.engines.embedding import ClipEmbeddingEngine as engine_class
    elif method == 'CAPTION_YOLO_NEAREST':
        from anno_page.engines.captioning import CaptionYoloNearestEngine as engine_class
    elif method == 'CAPTION_YOLO_ORGANIZER':
        from anno_page.engines.captioning import CaptionYoloOrganizerEngine as engine_class
    elif method == 'CAPTION_YOLO_KEYPOINTS':
        from anno_page.engines.captioning import CaptionYoloKeypointsEngine as engine_class
    else:
        raise ValueError(f"Unknown operation method: {config['METHOD']}")

    # Every engine takes the same constructor arguments.
    return engine_class(config, device, config_path=config_path)


class MissingLogits(Exception):
    """Exception subclass that adds no behavior of its own.

    NOTE(review): presumably signals that required logits are unavailable;
    the raise sites are not visible in this excerpt -- confirm.
    """

Expand Down Expand Up @@ -341,12 +363,13 @@ def process_page(self, img, page_layout: PageLayout):
page_layout.regions = []

result = self.engine.detect(img)
start_id = self.get_start_id([region.id for region in page_layout_text.regions])

category_counts = defaultdict(int)
for region in page_layout_text.regions:
category_counts[region.category] += 1

boxes = result.boxes.data.cpu()
for box_id, box in enumerate(boxes):
id_str = 'r{:03d}'.format(start_id + box_id)

x_min, y_min, x_max, y_max, conf, class_id = box.tolist()
polygon = np.array([[x_min, y_min], [x_min, y_max], [x_max, y_max], [x_max, y_min], [x_min, y_min]])
baseline_y = y_min + (y_max - y_min) / 2
Expand All @@ -357,6 +380,9 @@ def process_page(self, img, page_layout: PageLayout):
if self.categories and category not in self.categories:
continue

category_counts[category] += 1

id_str = f'{helpers.normalize_category_name(category)}_{category_counts[category]:03d}'.lower()
region = RegionLayout(id_str, polygon, category=category, detection_confidence=conf)

if category in self.line_categories:
Expand Down Expand Up @@ -628,6 +654,7 @@ def __init__(self, config, device=None, config_path='', ):
self.run_line_cropper = config['PAGE_PARSER'].getboolean('RUN_LINE_CROPPER', fallback=False)
self.run_ocr = config['PAGE_PARSER'].getboolean('RUN_OCR', fallback=False)
self.run_decoder = config['PAGE_PARSER'].getboolean('RUN_DECODER', fallback=False)
self.run_operations = config['PAGE_PARSER'].getboolean('RUN_OPERATIONS', fallback=False)
self.filter_confident_lines_threshold = config['PAGE_PARSER'].getfloat('FILTER_CONFIDENT_LINES_THRESHOLD',
fallback=-1)

Expand All @@ -637,6 +664,7 @@ def __init__(self, config, device=None, config_path='', ):
self.line_croppers = {}
self.ocrs = {}
self.decoder = None
self.operations = {}

if self.run_layout_parser:
self.layout_parsers = self.init_config_sections(config, config_path, 'LAYOUT_PARSER', layout_parser_factory)
Expand All @@ -646,6 +674,8 @@ def __init__(self, config, device=None, config_path='', ):
self.ocrs = self.init_config_sections(config, config_path, 'OCR', ocr_factory)
if self.run_decoder:
self.decoder = page_decoder_factory(config, self.device, config_path=config_path)
if self.run_operations:
self.operations = self.init_config_sections(config, config_path, 'OPERATION', operations_factory)

@property
def provides_ctc_logits(self):
Expand Down Expand Up @@ -681,6 +711,10 @@ def process_page(self, image, page_layout):
if self.filter_confident_lines_threshold > 0:
page_layout = self.filter_confident_lines(page_layout)

if self.run_operations:
for _, operation_engine in sorted(self.operations.items()):
page_layout = operation_engine.process_page(image, page_layout)

page_layout.calculate_confidence()

return page_layout
Expand Down
23 changes: 23 additions & 0 deletions pero_ocr/layout_engines/layout_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,26 @@ def merge_page_layouts(page_layout_positive: PageLayout, page_layout_negative: P
page_layout_positive.regions.append(region)

return page_layout_positive


def insert_line_to_page_layout(page_layout: PageLayout, region: RegionLayout, line: TextLine) -> PageLayout:
    """Add ``line`` to the region of ``page_layout`` whose ID equals ``region.id``.

    When the page has no region with that ID yet, ``region`` itself is appended
    to ``page_layout.regions`` and its ``lines`` are replaced by ``[line]``.

    :return: the same ``page_layout`` object, modified in place
    """
    target = find_region_by_id(page_layout, region.id)

    if target is None:
        # First line coming from this region: register the region itself.
        region.lines = [line]
        page_layout.regions.append(region)
        return page_layout

    target.lines.append(line)
    return page_layout


def find_region_by_id(page_layout: PageLayout, region_id: str) -> Optional[RegionLayout]:
    """Return the first region in ``page_layout.regions`` whose ``id`` equals ``region_id``.

    :return: the matching region, or ``None`` when no region has that ID
    """
    matches = (candidate for candidate in page_layout.regions if candidate.id == region_id)
    return next(matches, None)


def normalize_category_name(category):
    """Return ``category`` lower-cased, with spaces turned into underscores and commas removed."""
    # One pass over the lower-cased text: ' ' -> '_', ',' -> dropped.
    replacements = str.maketrans({' ': '_', ',': None})
    return category.lower().translate(replacements)
11 changes: 7 additions & 4 deletions pero_ocr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
except Exception:
logging.warning('cannot import numba, creating dummy jit definition')

def jit(function):
def wrapper(*args, **kwargs):
return function(*args, **kwargs)
return wrapper

def jit(*jit_args, **jit_kwargs):
    """Fallback stand-in for ``jit``, defined when numba cannot be imported.

    Usable both as a bare decorator (``@jit``) and as a decorator factory
    (``@jit(...)`` with any options). In either form the decorated function
    runs unchanged, without compilation; all options are ignored.

    :return: the wrapped function for bare use, otherwise a decorator
    """
    import functools  # local import: the file's import block is not visible here

    def decorator(function):
        # Preserve the wrapped function's name and docstring.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            return function(*args, **kwargs)
        return wrapper

    # Bare ``@jit``: the only argument is the function being decorated, so
    # decorate it directly instead of handing back the decorator itself.
    if len(jit_args) == 1 and not jit_kwargs and callable(jit_args[0]):
        return decorator(jit_args[0])
    return decorator


def compose_path(file_path, reference_path):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers=[
]

dependencies = [
"numpy~=1.24.2",
"numpy",
"opencv-python",
"lxml",
"lmdb",
Expand Down
Loading