@@ -28,14 +28,7 @@


 class DocVQAExample(object):
-
-    def __init__(self,
-                 question,
-                 doc_tokens,
-                 doc_boxes=[],
-                 answer=None,
-                 labels=None,
-                 image=None):
+    def __init__(self, question, doc_tokens, doc_boxes=[], answer=None, labels=None, image=None):
         self.question = question
         self.doc_tokens = doc_tokens
         self.doc_boxes = doc_boxes
@@ -47,13 +40,7 @@ def __init__(self,
 class DocVQAFeatures(object):
     """A single set of features of data."""

-    def __init__(self,
-                 example_index,
-                 input_ids,
-                 input_mask,
-                 segment_ids,
-                 boxes=None,
-                 label=None):
+    def __init__(self, example_index, input_ids, input_mask, segment_ids, boxes=None, label=None):
         self.example_index = example_index
         self.input_ids = input_ids
         self.input_mask = input_mask
@@ -63,15 +50,9 @@ def __init__(self,


 class DocVQA(Dataset):
-
-    def __init__(self,
-                 args,
-                 tokenizer,
-                 label2id_map,
-                 max_seq_len=512,
-                 max_query_length=20,
-                 max_doc_length=512,
-                 max_span_num=1):
+    def __init__(
+        self, args, tokenizer, label2id_map, max_seq_len=512, max_query_length=20, max_doc_length=512, max_span_num=1
+    ):
         super(DocVQA, self).__init__()
         self.tokenizer = tokenizer
         self.label2id_map = label2id_map
@@ -113,17 +94,16 @@ def check_is_max_context(self, doc_spans, cur_span_index, position):
                 continue
             num_left_context = position - doc_span.start
             num_right_context = end - position
-            score = min(num_left_context,
-                        num_right_context) + 0.01 * doc_span.length
+            score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
             if best_score is None or score > best_score:
                 best_score = score
                 best_span_index = span_index

         return cur_span_index == best_span_index

-    def convert_examples_to_features(self, examples, tokenizer, label_map,
-                                     max_seq_length, max_span_num,
-                                     max_doc_length, max_query_length):
+    def convert_examples_to_features(
+        self, examples, tokenizer, label_map, max_seq_length, max_span_num, max_doc_length, max_query_length
+    ):

         if "[CLS]" in self.tokenizer.get_vocab():
             start_token = "[CLS]"
@@ -188,8 +168,7 @@ def convert_examples_to_features(self, examples, tokenizer, label_map,
                 segment_ids.append(0)
                 for i in range(doc_span.length):
                     split_token_index = doc_span.start + i
-                    is_max_context = self.check_is_max_context(
-                        doc_spans, doc_span_index, split_token_index)
+                    is_max_context = self.check_is_max_context(doc_spans, doc_span_index, split_token_index)
                     token_is_max_context[len(tokens)] = is_max_context
                     tokens.append(all_doc_tokens[split_token_index])
                     boxes_tokens.append(all_doc_boxes_tokens[split_token_index])
@@ -292,12 +271,10 @@ def create_examples(self, data, is_test=False):
             question = sample["question"]
             doc_tokens = sample["document"]
             doc_boxes = sample["document_bbox"]
-            labels = sample['labels'] if not is_test else []
+            labels = sample["labels"] if not is_test else []

-            x_min, y_min = min(doc_boxes, key=lambda x: x[0])[0], min(
-                doc_boxes, key=lambda x: x[2])[2]
-            x_max, y_max = max(doc_boxes, key=lambda x: x[1])[1], max(
-                doc_boxes, key=lambda x: x[3])[3]
+            x_min, y_min = min(doc_boxes, key=lambda x: x[0])[0], min(doc_boxes, key=lambda x: x[2])[2]
+            x_max, y_max = max(doc_boxes, key=lambda x: x[1])[1], max(doc_boxes, key=lambda x: x[3])[3]
             width = x_max - x_min
             height = y_max - y_min

@@ -308,12 +285,15 @@ def create_examples(self, data, is_test=False):
             scale_x = 1000 / max(width, height)
             scale_y = 1000 / max(width, height)

-            scaled_doc_boxes = [[
-                round((b[0] - x_min) * scale_x),
-                round((b[2] - y_min) * scale_y),
-                round((b[1] - x_min) * scale_x),
-                round((b[3] - y_min) * scale_y)
-            ] for b in doc_boxes]
+            scaled_doc_boxes = [
+                [
+                    round((b[0] - x_min) * scale_x),
+                    round((b[2] - y_min) * scale_y),
+                    round((b[1] - x_min) * scale_x),
+                    round((b[3] - y_min) * scale_y),
+                ]
+                for b in doc_boxes
+            ]

             for box, oribox in zip(scaled_doc_boxes, doc_boxes):
                 if box[0] < 0:
@@ -326,10 +306,9 @@ def create_examples(self, data, is_test=False):
                     if pos > 1000:
                         print(width, height, box, oribox)

-            example = DocVQAExample(question=question,
-                                    doc_tokens=doc_tokens,
-                                    doc_boxes=scaled_doc_boxes,
-                                    labels=labels)
+            example = DocVQAExample(
+                question=question, doc_tokens=doc_tokens, doc_boxes=scaled_doc_boxes, labels=labels
+            )
             examples.append(example)
         return examples

@@ -339,7 +318,7 @@ def docvqa_input(self):
             dataset = self.args.train_file
         elif self.args.do_test:
             dataset = self.args.test_file
-        with open(dataset, 'r', encoding='utf8') as f:
+        with open(dataset, "r", encoding="utf8") as f:
             for index, line in enumerate(f):
                 data.append(json.loads(line.strip()))

@@ -353,30 +332,32 @@ def docvqa_input(self):
             max_seq_length=self.max_seq_len,
             max_doc_length=self.max_doc_length,
             max_span_num=self.max_span_num,
-            max_query_length=self.max_query_length)
-
-        all_input_ids = paddle.to_tensor([f.input_ids for f in features],
-                                         dtype="int64")
-        all_input_mask = paddle.to_tensor([f.input_mask for f in features],
-                                          dtype="int64")
-        all_segment_ids = paddle.to_tensor([f.segment_ids for f in features],
-                                           dtype="int64")
-        all_bboxes = paddle.to_tensor([f.boxes for f in features],
-                                      dtype="int64")
-        all_labels = paddle.to_tensor([f.label for f in features],
-                                      dtype="int64")
+            max_query_length=self.max_query_length,
+        )
+
+        all_input_ids = paddle.to_tensor([f.input_ids for f in features], dtype="int64")
+        all_input_mask = paddle.to_tensor([f.input_mask for f in features], dtype="int64")
+        all_segment_ids = paddle.to_tensor([f.segment_ids for f in features], dtype="int64")
+        all_bboxes = paddle.to_tensor([f.boxes for f in features], dtype="int64")
+        all_labels = paddle.to_tensor([f.label for f in features], dtype="int64")
         self.sample_list = [
             np.array(all_input_ids),
             np.array(all_input_mask),
             np.array(all_segment_ids),
             np.array(all_bboxes),
-            np.array(all_labels)
+            np.array(all_labels),
         ]

     def __getitem__(self, idx):
-        return self.sample_list[0][idx], self.sample_list[1][
-            idx], self.sample_list[2][idx], self.sample_list[3][
-                idx], self.sample_list[4][idx]
-
-    def __len__(self, ):
+        return (
+            self.sample_list[0][idx],
+            self.sample_list[1][idx],
+            self.sample_list[2][idx],
+            self.sample_list[3][idx],
+            self.sample_list[4][idx],
+        )
+
+    def __len__(
+        self,
+    ):
         return self.sample_list[0].shape[0]
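
The `create_examples` hunks above normalize every word box into LayoutLM-style 0-1000 coordinates, scaling by the longer page side so aspect ratio is preserved. A standalone sketch of that arithmetic, assuming input boxes arrive as `[x1, x2, y1, y2]` (which the `b[0]/b[2]/b[1]/b[3]` indexing implies) and come out as `[x1, y1, x2, y2]`:

```python
# Standalone sketch of the 0-1000 box normalization in create_examples.
# Assumes input boxes are [x1, x2, y1, y2]; outputs are [x1, y1, x2, y2].
doc_boxes = [[10, 110, 20, 220], [400, 510, 900, 1020]]

x_min = min(doc_boxes, key=lambda b: b[0])[0]  # 10
y_min = min(doc_boxes, key=lambda b: b[2])[2]  # 20
x_max = max(doc_boxes, key=lambda b: b[1])[1]  # 510
y_max = max(doc_boxes, key=lambda b: b[3])[3]  # 1020

width, height = x_max - x_min, y_max - y_min  # 500, 1000
scale = 1000 / max(width, height)             # 1.0: longest side maps to 1000

scaled = [
    [
        round((b[0] - x_min) * scale),
        round((b[2] - y_min) * scale),
        round((b[1] - x_min) * scale),
        round((b[3] - y_min) * scale),
    ]
    for b in doc_boxes
]
print(scaled)  # [[0, 0, 100, 200], [390, 880, 500, 1000]]
```

Because both scale factors use `max(width, height)`, no coordinate can legitimately exceed 1000, which is why the loop in the diff only prints a diagnostic when `pos > 1000`.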
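For orientation, a minimal driver sketch, not taken from this commit: it assumes the `DocVQA` class above is in scope, the `args` fields mirror those read in `docvqa_input`, and the tokenizer and label set are placeholders (any PaddleNLP tokenizer exposing `get_vocab()` should satisfy the `[CLS]`/`[SEP]` probing in `convert_examples_to_features`).

```python
# Hypothetical driver for the DocVQA dataset above; names below are assumptions.
from types import SimpleNamespace

import paddle
from paddlenlp.transformers import LayoutXLMTokenizer  # any tokenizer with get_vocab()

tokenizer = LayoutXLMTokenizer.from_pretrained("layoutxlm-base-uncased")
label2id_map = {"O": 0, "B-ANS": 1, "I-ANS": 2}  # hypothetical label set

# docvqa_input() reads train_file when do_train is set, test_file when do_test is set.
args = SimpleNamespace(do_train=True, do_test=False, train_file="train.json", test_file=None)

dataset = DocVQA(args, tokenizer, label2id_map, max_seq_len=512)
dataset.docvqa_input()  # builds sample_list; skip if the constructor already calls it

# __getitem__ yields (input_ids, input_mask, segment_ids, bboxes, labels) rows.
loader = paddle.io.DataLoader(dataset, batch_size=8, shuffle=True)
```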