-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconvert_collection_to_memmap.py
66 lines (56 loc) · 2.42 KB
/
convert_collection_to_memmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import json
import argparse
import numpy as np
from tqdm import tqdm
def cvt_collection_to_memmap(args):
    """Convert a tokenized JSONL passage collection into numpy memmap files.

    Each input line is a JSON object with an "id" (expected to equal its
    0-based line index) and a list of token "ids".  Writes three files into
    ``args.output_dir``:

      token_ids.memmap : (N, 512) int32 — token ids, zero-padded on the right
      pids.memmap      : (N,)     int32 — passage id (== line index)
      lengths.memmap   : (N,)     int32 — number of tokens actually stored
                                          (capped at 512)

    Raises:
        ValueError: if a line's "id" does not match its line index.
    """
    max_seq_length = 512
    # First pass only counts lines so the memmaps can be sized up front.
    with open(args.tokenized_collection) as f:
        collection_size = sum(1 for _ in f)
    token_ids = np.memmap(f"{args.output_dir}/token_ids.memmap", dtype='int32',
                          mode='w+', shape=(collection_size, max_seq_length))
    pids = np.memmap(f"{args.output_dir}/pids.memmap", dtype='int32',
                     mode='w+', shape=(collection_size,))
    lengths = np.memmap(f"{args.output_dir}/lengths.memmap", dtype='int32',
                        mode='w+', shape=(collection_size,))
    with open(args.tokenized_collection) as f:
        for idx, line in enumerate(tqdm(f, desc="collection",
                                        total=collection_size)):
            data = json.loads(line)
            # Explicit check instead of `assert` (asserts vanish under -O).
            if int(data['id']) != idx:
                raise ValueError(
                    f"line {idx}: expected id {idx}, got {data['id']}")
            ids = data['ids'][:max_seq_length]
            pids[idx] = idx
            # Store the *truncated* length so downstream code slicing
            # token_ids[i, :lengths[i]] never reads past the written tokens.
            lengths[idx] = len(ids)
            token_ids[idx, :len(ids)] = ids
    # Make sure everything is on disk before the process exits.
    token_ids.flush()
    pids.flush()
    lengths.flush()
def cvt_collection_to_memmap_docrank(args):
    """Convert a tokenized JSONL document collection into numpy memmap files.

    Document-ranking ids carry a one-character prefix (e.g. "D123"), so the
    numeric pid is parsed from ``data['id'][1:]`` and line order is NOT
    required to match the ids.  Writes three files into ``args.output_dir``:

      token_ids.memmap : (N, 512) int32 — token ids, zero-padded on the right
      pids.memmap      : (N,)     int32 — numeric document id (prefix stripped)
      lengths.memmap   : (N,)     int32 — number of tokens actually stored
                                          (capped at 512)
    """
    max_seq_length = 512
    # First pass only counts lines so the memmaps can be sized up front.
    with open(args.tokenized_collection) as f:
        collection_size = sum(1 for _ in f)
    token_ids = np.memmap(f"{args.output_dir}/token_ids.memmap", dtype='int32',
                          mode='w+', shape=(collection_size, max_seq_length))
    pids = np.memmap(f"{args.output_dir}/pids.memmap", dtype='int32',
                     mode='w+', shape=(collection_size,))
    lengths = np.memmap(f"{args.output_dir}/lengths.memmap", dtype='int32',
                        mode='w+', shape=(collection_size,))
    with open(args.tokenized_collection) as f:
        for idx, line in enumerate(tqdm(f, desc="collection",
                                        total=collection_size)):
            data = json.loads(line)
            # Strip the leading letter ("D123" -> 123).
            pids[idx] = int(data['id'][1:])
            ids = data['ids'][:max_seq_length]
            # Store the *truncated* length so downstream code slicing
            # token_ids[i, :lengths[i]] never reads past the written tokens.
            lengths[idx] = len(ids)
            token_ids[idx, :len(ids)] = ids
    # Make sure everything is on disk before the process exits.
    token_ids.flush()
    pids.flush()
    lengths.flush()
if __name__ == "__main__":
    # Command-line entry point: pick the converter that matches the dataset
    # flavour and run it over the tokenized collection.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenized_collection", type=str,
        default="./data/tokenize/collection.tokenize.json")
    parser.add_argument(
        "--output_dir", type=str, default="./data/collection_memmap")
    parser.add_argument(
        "--dataset_type", type=str, default="document_ranking")
    args = parser.parse_args()

    # Ensure the destination directory exists before any memmap is created.
    os.makedirs(args.output_dir, exist_ok=True)

    # Document ids carry a letter prefix and need the docrank converter;
    # anything else falls back to the plain passage converter.
    convert = (cvt_collection_to_memmap_docrank
               if args.dataset_type == 'document_ranking'
               else cvt_collection_to_memmap)
    convert(args)