forked from ZzZZCHS/Chat-Scene
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_sqa3d_annos.py
87 lines (76 loc) · 2.8 KB
/
prepare_sqa3d_annos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import numpy as np
import os
import nltk
import random
from tqdm import tqdm
# Directory containing the SQA3D question and annotation JSON files that are
# read below (v1_balanced_questions_*.json / v1_balanced_sqa_annotations_*.json).
anno_dir = 'annotations/sqa3d'
def convert_person_view(sentence):
    """Rewrite a first-person sentence in the second person.

    The sentence is tokenized with ``nltk.wordpunct_tokenize``. Every token
    whose lower-cased form is a known first-person word is replaced by its
    second-person counterpart; all other tokens are kept unchanged. The
    tokens are then joined with single spaces (so punctuation ends up
    space-separated) and the result is passed through ``str.capitalize``,
    which upper-cases the first character and lower-cases the rest.
    """
    first_to_second = {
        'i': 'you',
        'me': 'you',
        'my': 'your',
        'mine': 'yours',
        'am': 'are',
    }
    tokens = nltk.wordpunct_tokenize(sentence)
    # Lookup is case-insensitive; unmatched tokens pass through untouched.
    swapped = [first_to_second.get(token.lower(), token) for token in tokens]
    return ' '.join(swapped).capitalize()
def get_sqa_question_type(question):
    """Return an integer category based on how the question begins.

    Leading whitespace is ignored and the comparison is case-insensitive.
    Only the leading characters are compared (no word-boundary check), and
    the first matching prefix in the order below wins:

        0: 'what', 1: 'is', 2: 'how', 3: 'can', 4: 'which', 5: anything else
    """
    prefix_codes = (
        ('what', 0),
        ('is', 1),
        ('how', 2),
        ('can', 3),
        ('which', 4),
    )
    stripped = question.lstrip()
    for prefix, code in prefix_codes:
        # Slice first, then lower-case, exactly as many characters as the prefix.
        if stripped[:len(prefix)].lower() == prefix:
            return code
    return 5  # others
def _format_train_answer(answer):
    """Capitalize a training answer and ensure it ends with a period."""
    answer = answer.capitalize()
    # endswith() is safe for an empty answer, where answer[-1] would raise
    # IndexError.
    if not answer.endswith("."):
        answer += "."
    return answer


# Convert the SQA3D question/annotation files of each split into a single
# flat list of prompt records written to annotations/sqa3d_{split}.json.
for split in ['train', 'val']:
    # Build question_id -> {'s': candidate situations, 'q': question text}.
    question_file = os.path.join(anno_dir, f'v1_balanced_questions_{split}_scannetv2.json')
    with open(question_file, 'r', encoding='utf-8') as f:
        question_data = json.load(f)['questions']
    question_map = {
        item['question_id']: {
            's': [item['situation']] + item['alternative_situation'],  # list of str
            'q': item['question'],  # str
        }
        for item in question_data
    }

    anno_file = os.path.join(anno_dir, f'v1_balanced_sqa_annotations_{split}_scannetv2.json')
    with open(anno_file, 'r', encoding='utf-8') as f:
        anno_data = json.load(f)['annotations']

    sqa_annos = []
    for item in tqdm(anno_data):
        entry = question_map[item['question_id']]
        # One situation description is sampled at random per annotation.
        situation = random.choice(entry['s'])
        question = entry['q']
        answers = [meta['answer'] for meta in item['answers']]

        anno = {
            'scene_id': item['scene_id'],
            'obj_id': 0,  # always written as 0 for SQA3D records
            'prompt': situation + ' ' + question + " Answer the question using a single word or phrase.",
        }
        if split == 'train':
            # Training records carry one randomly chosen, normalized answer.
            anno['caption'] = _format_train_answer(random.choice(answers))
        else:
            # Non-training records keep every reference answer unmodified.
            anno['ref_captions'] = answers
        anno['sqa_type'] = get_sqa_question_type(question)
        sqa_annos.append(anno)

    with open(f"annotations/sqa3d_{split}.json", "w", encoding='utf-8') as f:
        json.dump(sqa_annos, f, indent=4)