Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 194 additions & 52 deletions pero_ocr/core/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class ALTOVersion(Enum):
ALTO_v2_x = 1
ALTO_v4_4 = 2

class WordAlignmentOrigin(Enum):
    """How word polygons of a TextLine were obtained.

    FROM_LOGITS: words were aligned against the OCR logits (precise positions).
    MEAN_WIDTH:  logits were unavailable, words were laid out with a uniform
                 average width as a fallback.
    """
    FROM_LOGITS = 'from_logits'
    MEAN_WIDTH = 'mean_width'

def log_softmax(x):
    """Return the log-softmax of x computed along axis 1.

    Uses logaddexp for numerical stability instead of exponentiating directly.
    """
    normalizer = np.logaddexp.reduce(x, axis=1, keepdims=True)
    return x - normalizer
Expand All @@ -42,6 +46,75 @@ def export_id(id, validate_change_id):
return 'id_' + id if validate_change_id else id


class Word(object):
    """A single word of a text line: polygon geometry plus optional transcription."""

    def __init__(self, id: str = None,
                 polygon: Optional[np.ndarray] = None,
                 transcription: Optional[str] = None,
                 transcription_confidence: Optional[Num] = None):
        """
        :param id: word identifier (unique within the page)
        :param polygon: word outline as an (n, 2) array of (x, y) points
        :param transcription: recognized word text
        :param transcription_confidence: OCR confidence for the transcription
        """
        self.id = id
        # polygon is Optional, so validate the shape only when one is actually
        # given (the original asserted unconditionally and crashed on the
        # default None).  Raise instead of assert so the check is not stripped
        # under `python -O`.
        if polygon is not None and (polygon.ndim != 2 or polygon.shape[1] != 2):
            raise ValueError(f'Polygon has wrong shape: {polygon.shape}, expected (n, 2)')
        self.polygon = polygon
        self.transcription = transcription
        self.transcription_confidence = transcription_confidence

    def to_pagexml(self, line_element: ET.SubElement, fallback_id: int, validate_id: bool = False):
        """Serialize this word as a PAGE XML <Word> element under line_element.

        :param line_element: parent <TextLine> element
        :param fallback_id: kept for interface symmetry with other to_pagexml methods (unused here)
        :param validate_id: prefix the id with 'id_' to make it a valid XML id
        """
        word_element = ET.SubElement(line_element, "Word")
        word_element.set("id", export_id(self.id, validate_id))

        coords = ET.SubElement(word_element, "Coords")
        coords.set("points", coords_to_pagexml_points(self.polygon))

        text_element = ET.SubElement(word_element, "TextEquiv")
        if self.transcription_confidence is not None:
            text_element.set("conf", f"{self.transcription_confidence:.3f}")

        unicode_element = ET.SubElement(text_element, "Unicode")
        unicode_element.text = self.transcription

    @classmethod
    def from_pagexml(cls, word_element: ET.SubElement, schema):
        """Build a Word from a PAGE XML <Word> element.

        :param word_element: the <Word> element to parse
        :param schema: XML namespace prefix used by the document, e.g. '{http://...}'
        :return: new Word instance
        """
        coords = word_element.find(schema + 'Coords')
        polygon = get_coords_from_pagexml(coords, schema)

        word = cls(id=word_element.attrib['id'], polygon=polygon)

        transcription = word_element.find(schema + 'TextEquiv')
        if transcription is not None:
            word.transcription = transcription.find(schema + 'Unicode').text
            if word.transcription is None:
                # An empty <Unicode/> element parses as None; normalize to ''.
                word.transcription = ''

            conf = transcription.get('conf', None)
            word.transcription_confidence = float(conf) if conf is not None else None

        return word

    def to_altoxml(self, text_line, arabic_helper, add_space: bool):
        """Serialize this word as an ALTO <String> element under text_line.

        :param text_line: parent ALTO <TextLine> element
        :param arabic_helper: helper used to convert Arabic label form to display form
        :param add_space: also append an <SP> (inter-word space) element after the word
        """
        string = ET.SubElement(text_line, "String")

        if arabic_helper.is_arabic_line(self.transcription):
            string.set("CONTENT", arabic_helper.label_form_to_string(self.transcription))
        else:
            string.set("CONTENT", self.transcription)

        h, w, vpos, hpos = get_hwvh(self.polygon)
        string.set("HEIGHT", str(int(h)))
        string.set("WIDTH", str(int(w)))
        string.set("VPOS", str(int(vpos)))
        string.set("HPOS", str(int(hpos)))

        if self.transcription_confidence is not None:
            string.set("WC", str(round(self.transcription_confidence, 2)))

        if add_space:
            space = ET.SubElement(text_line, "SP")
            space.set("WIDTH", str(4))
            space.set("VPOS", str(int(vpos)))
            space.set("HPOS", str(int(hpos + w)))
class TextLine(object):
def __init__(self, id: str = None,
baseline: Optional[np.ndarray] = None,
Expand All @@ -68,6 +141,10 @@ def __init__(self, id: str = None,
self.transcription_confidence = transcription_confidence
self.category = category

self.custom = {}
self.words: List[Word] = [] # words are not required for text line, but can be added using align_words()
self.word_alignment_origin = None

def get_dense_logits(self, zero_logit_value: int = -80):
dense_logits = self.logits.toarray()
dense_logits[dense_logits == 0] = zero_logit_value
Expand All @@ -91,6 +168,14 @@ def to_pagexml(self, region_element: ET.SubElement, fallback_id: int, validate_i
custom['heights'] = list(np.round(heights_out, decimals=1))
if self.category is not None:
custom['category'] = self.category
if self.word_alignment_origin is not None:
custom['word_alignment_origin'] = self.word_alignment_origin.value
if self.custom is not None:
new_custom = self.custom.copy()
# overwrite possible existing keys in self.custom with the textline attributes,
# as those have most likely been updated during computation
new_custom.update(custom)
custom = new_custom
if len(custom) > 0:
text_line.set("custom", json.dumps(custom))

Expand All @@ -103,6 +188,10 @@ def to_pagexml(self, region_element: ET.SubElement, fallback_id: int, validate_i
baseline_element = ET.SubElement(text_line, "Baseline")
baseline_element.set("points", coords_to_pagexml_points(self.baseline))

if self.words:
for word in self.words:
word.to_pagexml(text_line, fallback_id=fallback_id, validate_id=validate_id)

if self.transcription is not None:
text_element = ET.SubElement(text_line, "TextEquiv")
if self.transcription_confidence is not None:
Expand All @@ -111,7 +200,7 @@ def to_pagexml(self, region_element: ET.SubElement, fallback_id: int, validate_i
text_element.text = self.transcription

@classmethod
def from_pagexml(cls, line_element: ET.SubElement, schema, fallback_index: int):
def from_pagexml(cls, line_element: ET.SubElement, schema, fallback_index: int, fake_baseline_if_needed: bool = False):
new_textline = cls(id=line_element.attrib['id'])
if 'custom' in line_element.attrib:
new_textline.from_pagexml_parse_custom(line_element.attrib['custom'])
Expand All @@ -125,19 +214,22 @@ def from_pagexml(cls, line_element: ET.SubElement, schema, fallback_index: int):
if new_textline.index is None:
new_textline.index = fallback_index

textline = line_element.find(schema + 'Coords')
if textline is not None:
new_textline.polygon = get_coords_from_pagexml(textline, schema)

baseline = line_element.find(schema + 'Baseline')
if baseline is not None:
new_textline.baseline = get_coords_from_pagexml(baseline, schema)
else:
logger.warning(f'Warning: Baseline is missing in TextLine. '
f'Skipping this line during import. Line ID: {new_textline.id}')
return None

textline = line_element.find(schema + 'Coords')
if textline is not None:
new_textline.polygon = get_coords_from_pagexml(textline, schema)
if fake_baseline_if_needed:
new_textline.baseline = fake_baseline(new_textline.polygon)
else:
logger.warning(f'Warning: Baseline is missing in TextLine. '
f'Skipping this line during import. Line ID: {new_textline.id}')
return None

if not new_textline.heights:
if not new_textline.heights: # and load_baseline
guess_line_heights_from_polygon(new_textline, use_center=False, n=len(new_textline.baseline))

transcription = line_element.find(schema + 'TextEquiv')
Expand All @@ -148,13 +240,23 @@ def from_pagexml(cls, line_element: ET.SubElement, schema, fallback_index: int):
new_textline.transcription = t_unicode
conf = transcription.get('conf', None)
new_textline.transcription_confidence = float(conf) if conf is not None else None

for word in line_element.iter(schema + 'Word'):
new_word = Word.from_pagexml(word, schema)
new_textline.words.append(new_word)

return new_textline

def from_pagexml_parse_custom(self, custom_str):
try:
custom = json.loads(custom_str)
self.category = custom.get('category', None)
self.heights = custom.get('heights', None)
self.category = custom.pop('category', None)
self.heights = custom.pop('heights', None)
word_alignment_origin = custom.pop('word_alignment_origin', None)
if word_alignment_origin is not None:
self.word_alignment_origin = WordAlignmentOrigin(word_alignment_origin)
if len(custom) > 0:
self.custom = custom
except json.decoder.JSONDecodeError:
if 'heights_v2' in custom_str:
for word in custom_str.split():
Expand Down Expand Up @@ -192,9 +294,14 @@ def to_altoxml(self, text_block, arabic_helper, min_line_confidence, version: AL
text_line.set("WIDTH", str(int(text_line_width)))

if self.category in (None, 'text'):
self.to_altoxml_text(text_line, arabic_helper,
text_line_height, text_line_width, text_line_vpos, text_line_hpos)
if not self.words:
self.align_words()

for i, word in enumerate(self.words):
add_space = i < len(self.words) - 1
word.to_altoxml(text_line, arabic_helper=arabic_helper, add_space=add_space)
else:
# export text line of other categories as a single string
string = ET.SubElement(text_line, "String")
string.set("CONTENT", self.transcription)

Expand Down Expand Up @@ -223,12 +330,11 @@ def get_labels(self):
labels.append(0)
return np.array(labels)

def to_altoxml_text(self, text_line, arabic_helper,
text_line_height, text_line_width, text_line_vpos, text_line_hpos):
arabic_line = False
if arabic_helper.is_arabic_line(self.transcription):
arabic_line = True
def align_words(self, force_new: bool = False):
if self.words and not force_new:
return

self.words = []
logits = None
logprobs = None
aligned_letters = None
Expand All @@ -240,7 +346,7 @@ def to_altoxml_text(self, text_line, arabic_helper,
logprobs = self.get_full_logprobs()[self.logit_coords[0]:self.logit_coords[1]]
aligned_letters = align_text(-logprobs, np.array(label), blank_idx)
except (ValueError, IndexError, TypeError) as e:
logger.warning(f'Error: Alto export, unable to align line {self.id} due to exception: {e}.')
logger.warning(f'Error: Alto export, unable to align line {self.id} due to exception: {e}. (Fallback: every word has the same width.)')

if logits is not None and logits.shape[0] > 0:
max_val = np.max(logits, axis=1)
Expand All @@ -252,15 +358,21 @@ def to_altoxml_text(self, text_line, arabic_helper,
else:
self.transcription_confidence = 0.0

average_word_width = (text_line_hpos + text_line_width) / len(self.transcription.split())
for w, word in enumerate(self.transcription.split()):
string = ET.SubElement(text_line, "String")
string.set("CONTENT", word)
line_h, line_w, line_vpos, line_hpos = get_hwvh(self.polygon)

if len(self.transcription.split()) == 0:
return
average_word_width = (line_w / len(self.transcription.split()))

string.set("HEIGHT", str(int(text_line_height)))
string.set("WIDTH", str(int(average_word_width)))
string.set("VPOS", str(int(text_line_vpos)))
string.set("HPOS", str(int(text_line_hpos + (w * average_word_width))))
for w, word_transciption in enumerate(self.transcription.split()):
# create new word with average width
new_word_hpos = line_hpos + (w * average_word_width)
new_word_polygon = hwvh_to_polygon(line_h, average_word_width, line_vpos, new_word_hpos)

new_word = Word(id=f'{self.id}_w{w:03d}', transcription=word_transciption,
polygon=new_word_polygon)
self.words.append(new_word)
self.word_alignment_origin = WordAlignmentOrigin.MEAN_WIDTH
else:
crop_engine = EngineLineCropper(poly=2)
line_coords = crop_engine.get_crop_inputs(self.baseline, self.heights, 16)
Expand Down Expand Up @@ -305,28 +417,18 @@ def to_altoxml_text(self, text_line, arabic_helper,
word_confidence = np.quantile(
confidences[letter_counter:letter_counter + len(splitted_transcription[w])], .50)

string = ET.SubElement(text_line, "String")

if arabic_line:
string.set("CONTENT", arabic_helper.label_form_to_string(splitted_transcription[w]))
else:
string.set("CONTENT", splitted_transcription[w])

string.set("HEIGHT", str(int((np.max(all_y) - np.min(all_y)))))
string.set("WIDTH", str(int((np.max(all_x) - np.min(all_x)))))
string.set("VPOS", str(int(np.min(all_y))))
string.set("HPOS", str(int(np.min(all_x))))
new_word_hpos = np.min(all_x)
new_word_vpos = np.min(all_y)
new_word_height = np.max(all_y) - np.min(all_y)
new_word_width = np.max(all_x) - np.min(all_x)
new_word_polygon = hwvh_to_polygon(new_word_height, new_word_width, new_word_vpos, new_word_hpos)

if word_confidence is not None:
string.set("WC", str(round(word_confidence, 2)))
new_word = Word(id=f'{self.id}_w{w:03d}', transcription=splitted_transcription[w],
polygon=new_word_polygon, transcription_confidence=word_confidence)

if w != (len(self.transcription.split()) - 1):
space = ET.SubElement(text_line, "SP")

space.set("WIDTH", str(4))
space.set("VPOS", str(int(np.min(all_y))))
space.set("HPOS", str(int(np.max(all_x))))
self.words.append(new_word)
letter_counter += len(splitted_transcription[w]) + 1
self.word_alignment_origin = WordAlignmentOrigin.FROM_LOGITS

def to_altoxml_baseline(self, version: ALTOVersion) -> str:
if version == ALTOVersion.ALTO_v2_x:
Expand Down Expand Up @@ -923,17 +1025,32 @@ def load_logits(self, file: str):
line.logit_coords = logit_coords[line.id]

def render_to_image(self, image, thickness: int = 2, circles: bool = True,
render_order: bool = False, render_category: bool = False):
render_order: bool = False, render_category: bool = False,
render_baseline: bool = True, render_words: bool = True
) -> np.ndarray:
"""Render layout into image.
:param image: image to render layout into
:param thickness: thickness of lines
:param circles: render circles at polygon vertices
:param render_order: render region order number given by enumerate(regions) to the middle of given region
:param render_region_id: render region id to the upper left corner of given region
:param render_category: render region category above the upper left corner of given region
:param render_baseline: render each text line's baseline as a red line
:param render_words: render each word's polygon as a black outline
:return: image with rendered layout as numpy array
"""
for region_layout in self.regions:
image = draw_lines(
image,
[line.baseline for line in region_layout.lines if line.baseline is not None], color=(0, 0, 255),
circles=(circles, circles, False), thickness=thickness)
if render_words:
for line in region_layout.lines:
for word in line.words:
image = draw_lines(
image,
[word.polygon], color=(0, 0, 0), close=True, thickness=thickness // 2)

if render_baseline:
image = draw_lines(
image,
[line.baseline for line in region_layout.lines if line.baseline is not None], color=(0, 0, 255),
circles=(circles, circles, False), thickness=thickness)
image = draw_lines(
image,
[line.polygon for line in region_layout.lines if line.polygon is not None], color=(0, 255, 0),
Expand Down Expand Up @@ -976,6 +1093,13 @@ def lines_iterator(self, categories: list = None):
if not categories or line.category in categories:
yield line

def words_iterator(self, categories: list = None):
    """Yield every Word in the layout, in region/line order.

    :param categories: if given, only lines whose category is in this list
                       contribute their words; otherwise all lines are used
    """
    for region in self.regions:
        for text_line in region.lines:
            if categories and text_line.category not in categories:
                continue
            yield from text_line.words

def get_quality(self, x: int = None, y: int = None, width: int = None, height: int = None, power: int = 6):
bbox_confidences = []
for b, block in enumerate(self.regions):
Expand Down Expand Up @@ -1120,6 +1244,24 @@ def get_hwvh(polygon):

return height, width, vpos, hpos

def hwvh_to_polygon(height: int, width: int, vpos: int, hpos: int) -> np.ndarray:
    """Build an axis-aligned rectangular polygon from ALTO-style box coordinates.

    Corners are returned clockwise starting at the top-left point.
    """
    top, bottom = vpos, vpos + height
    left, right = hpos, hpos + width
    corners = [(left, top), (right, top), (right, bottom), (left, bottom)]
    return np.array(corners)

def fake_baseline(polygon: np.ndarray, ratio: float = 4) -> np.ndarray:
    """Synthesize a horizontal baseline for a polygon that has none.

    The baseline spans the polygon's full width and is placed 1/ratio of the
    polygon height above its bottom edge (one quarter by default).
    """
    height, width, vpos, hpos = get_hwvh(polygon)
    baseline_y = vpos + height - height // ratio

    points = [[hpos, baseline_y], [hpos + width, baseline_y]]
    return np.array(points, dtype=np.int32)

def create_ocr_processing_element(id: str = "IdOcr",
software_creator_str: str = "Project PERO",
Expand Down
5 changes: 5 additions & 0 deletions pero_ocr/document_ocr/page_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ def __init__(self, config, device, config_path=''):
self.categories = config_get_list(config, key='CATEGORIES', fallback=[])
self.substitute_output = config.getboolean('SUBSTITUTE_OUTPUT', fallback=True)
self.substitute_output_atomic = config.getboolean('SUBSTITUTE_OUTPUT_ATOMIC', fallback=True)
self.align_words = config.getboolean('ALIGN_WORDS', fallback=False)
self.update_transcription_by_confidence = config.getboolean(
'UPDATE_TRANSCRIPTION_BY_CONFIDENCE', fallback=False)

Expand Down Expand Up @@ -567,6 +568,10 @@ def process_page(self, img, page_layout: PageLayout):
if self.substitute_output and self.ocr_engine.output_substitution is not None:
self.substitute_transcriptions(lines_to_process)

if self.align_words:
for line in page_layout.lines_iterator():
line.align_words()

return page_layout

def substitute_transcriptions(self, lines_to_process: List[TextLine]):
Expand Down