Commit 545403e

Author: Rajarshi Das
Commit message: initial commit, sharing implementation
0 parents · commit 545403e

30 files changed: +57041 -0 lines changed

README.md

+12
# joint-text-and-kb-inference-semantic-parsing

## Data Processing
```
cd code/
sh run_data.sh ./config_data.sh
```
## Training
```
cd code/
sh run.sh ./config.sh
```

code/KBQA.py

+401
Large diffs are not rendered by default.

code/TextQA.py

Whitespace-only changes.

code/__init__.py

Whitespace-only changes.

code/baseline_eval.py

+39
import json


def augment_with_baseline_answers(baseline_answer_file, input_file, output_file):
    """Copies each json line of input_file to output_file, adding the baseline prediction and whether it is correct."""
    out = open(output_file, 'w')
    with open(baseline_answer_file) as input, open(input_file) as data_file:
        for baseline_answer_line, line in zip(input, data_file):
            baseline_answer_line = baseline_answer_line.strip()
            # each baseline line is tab-separated: sentence, gold answer, predicted answer
            sentence, correct_answer, predicted_answer = baseline_answer_line.split('\t')
            correct = 1 if correct_answer == predicted_answer else 0
            data = json.loads(line)
            data['baseline_answer'] = predicted_answer
            data['is_correct'] = correct
            out.write(json.dumps(data) + '\n')
    out.close()


def get_baseline_accuracy(input_file, min_num_mem, max_num_mem):
    """Computes baseline accuracy over questions whose number of facts lies in [min_num_mem, max_num_mem]."""
    num_correct = 0
    num_data = 0
    with open(input_file) as input:
        for line in input:
            line = line.strip()
            data = json.loads(line)
            num_facts = data['num_facts']
            if num_facts < min_num_mem or num_facts > max_num_mem:
                continue
            num_data += 1
            num_correct += data['is_correct']

    print('Num data {0:10d}, Num correct {1:10d}, %correct {2:10.4f}'.format(num_data, num_correct,
                                                                             1.0 * num_correct / num_data))


if __name__ == '__main__':

    baseline_answer_file = "/home/rajarshi/canvas/data/TextKBQA/dev_answers.txt"
    input_file = "/home/rajarshi/canvas/data/TextKBQA/dev_with_facts.json"
    output_file = "/home/rajarshi/canvas/data/TextKBQA/dev_with_baseline_answers.json"

    # augment_with_baseline_answers(baseline_answer_file, input_file, output_file)
    get_baseline_accuracy(output_file, 0, 25000)
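
A minimal usage sketch of the two-step workflow suggested by the `__main__` block above: first augment the evaluation JSON with the baseline predictions, then score it over a range of memory sizes. The file names here are placeholders; the real paths are hard-coded above.

```
from baseline_eval import augment_with_baseline_answers, get_baseline_accuracy

# Hypothetical paths; substitute the ones from the __main__ block.
answers_tsv = "dev_answers.txt"            # <sentence>\t<gold answer>\t<predicted answer>
facts_json = "dev_with_facts.json"         # one JSON object per line, each with a 'num_facts' field
augmented_json = "dev_with_baseline_answers.json"

# Step 1: attach 'baseline_answer' and 'is_correct' to every question.
augment_with_baseline_answers(answers_tsv, facts_json, augmented_json)
# Step 2: accuracy over questions with 0 to 25000 retrieved facts.
get_baseline_accuracy(augmented_json, 0, 25000)
```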

code/data_utils.py

+524
Large diffs are not rendered by default.

code/feed_data.py

+245
from data_utils import KB, Text, TextKb
import numpy as np
from tqdm import tqdm


class Batcher(object):
    def __init__(self, input_file, kb_file, text_kb_file, batch_size, vocab_dir, return_one_epoch=False, shuffle=True,
                 min_num_mem_slots=100,
                 max_num_mem_slots=500,
                 min_num_text_mem_slots=0,
                 max_num_text_mem_slots=1000,
                 use_kb_mem=True,
                 use_text_mem=False):
        self.batch_size = batch_size
        self.input_file = input_file
        self.kb_file = kb_file
        self.text_kb_file = text_kb_file
        self.shuffle = shuffle
        self.max_num_mem_slots = max_num_mem_slots
        self.min_num_mem_slots = min_num_mem_slots
        self.max_num_text_mem_slots = max_num_text_mem_slots
        self.min_num_text_mem_slots = min_num_text_mem_slots
        self.vocab_dir = vocab_dir
        self.return_one_epoch = return_one_epoch
        self.use_kb_mem = use_kb_mem
        self.use_text_mem = use_text_mem
        self.questions, self.q_lengths, self.answers, \
            self.kb_memory_slots, self.kb_num_memories, \
            self.text_key_mem, self.text_key_len, \
            self.text_val_mem, self.num_text_mems = self.read_files()
        self.max_key_len = None

        if self.use_text_mem and self.use_kb_mem:
            assert self.text_key_mem is not None and self.kb_memory_slots is not None
        elif self.use_kb_mem:
            assert self.text_key_mem is None and self.kb_memory_slots is not None
        else:
            assert self.text_key_mem is not None and self.kb_memory_slots is None

        self.num_questions = len(self.questions)
        print('Num questions {}'.format(self.num_questions))
        self.start_index = 0
        if self.shuffle:
            self.shuffle_data()

    def get_next_batch(self):
        """
        Returns the next batch.
        TODO(rajarshd): move the if-check outside the loop so that the condition is not checked on every
        iteration; the flags are supposed to be immutable.
        """
        while True:
            if self.start_index >= self.num_questions:
                if self.return_one_epoch:
                    return  # stop after returning one epoch
                self.start_index = 0
                if self.shuffle:
                    self.shuffle_data()
            else:
                num_data_returned = min(self.batch_size, self.num_questions - self.start_index)
                assert num_data_returned > 0
                end_index = self.start_index + num_data_returned
                if self.use_kb_mem and self.use_text_mem:
                    yield self.questions[self.start_index:end_index], self.q_lengths[self.start_index:end_index], \
                          self.answers[self.start_index:end_index], self.kb_memory_slots[self.start_index:end_index], \
                          self.kb_num_memories[self.start_index:end_index], self.text_key_mem[self.start_index:end_index], \
                          self.text_key_len[self.start_index:end_index], self.text_val_mem[self.start_index:end_index], \
                          self.num_text_mems[self.start_index:end_index]
                elif self.use_kb_mem:
                    yield self.questions[self.start_index:end_index], self.q_lengths[self.start_index:end_index], \
                          self.answers[self.start_index:end_index], self.kb_memory_slots[self.start_index:end_index], \
                          self.kb_num_memories[self.start_index:end_index]
                else:
                    yield self.questions[self.start_index:end_index], self.q_lengths[self.start_index:end_index], \
                          self.answers[self.start_index:end_index], self.text_key_mem[self.start_index:end_index], \
                          self.text_key_len[self.start_index:end_index], self.text_val_mem[self.start_index:end_index], \
                          self.num_text_mems[self.start_index:end_index]
                self.start_index = end_index

    def shuffle_data(self):
        """
        Shuffles the data, keeping the question, answer and memory arrays aligned with each other.
        """
        perm = np.random.permutation(self.num_questions)  # perm of index in range(0, num_questions)
        assert len(perm) == self.num_questions
        if self.use_kb_mem and self.use_text_mem:
            self.questions, self.q_lengths, self.answers, self.kb_memory_slots, self.kb_num_memories, self.text_key_mem, \
                self.text_key_len, self.text_val_mem, self.num_text_mems = \
                self.questions[perm], self.q_lengths[perm], self.answers[perm], self.kb_memory_slots[perm], \
                self.kb_num_memories[perm], self.text_key_mem[perm], self.text_key_len[perm], self.text_val_mem[perm], \
                self.num_text_mems[perm]
        elif self.use_kb_mem:
            self.questions, self.q_lengths, self.answers, self.kb_memory_slots, self.kb_num_memories = \
                self.questions[perm], self.q_lengths[perm], self.answers[perm], self.kb_memory_slots[perm], \
                self.kb_num_memories[perm]
        else:
            self.questions, self.q_lengths, self.answers, self.text_key_mem, self.text_key_len, self.text_val_mem, \
                self.num_text_mems = self.questions[perm], self.q_lengths[perm], self.answers[perm], self.text_key_mem[perm], \
                self.text_key_len[perm], self.text_val_mem[perm], self.num_text_mems[perm]

    def reset(self):
        self.start_index = 0

    def read_files(self):
        """Reads the kb and text files and creates the numpy arrays after padding."""
        # read the KB file
        kb = KB(self.kb_file, vocab_dir=self.vocab_dir) if self.use_kb_mem else None
        # read the text kb file
        text_kb = TextKb(self.text_kb_file, vocab_dir=self.vocab_dir) if self.use_text_mem else None
        self.max_key_len = text_kb.max_key_length if self.use_text_mem else None
        # question file
        questions = Text(self.input_file,
                         max_num_facts=self.max_num_mem_slots,
                         min_num_facts=self.min_num_mem_slots,
                         min_num_text_facts=self.min_num_text_mem_slots,
                         max_num_text_facts=self.max_num_text_mem_slots)
        max_q_length, max_num_kb_facts, max_num_text_kb_facts, question_list = questions.max_q_length, \
            questions.max_num_kb_facts, \
            questions.max_num_text_kb_facts, \
            questions.question_list
        entity_vocab = kb.entity_vocab if self.use_kb_mem else text_kb.entity_vocab
        relation_vocab = kb.relation_vocab if self.use_kb_mem else text_kb.relation_vocab
        num_questions = len(question_list)
        question_lengths = np.ones([num_questions]) * -1
        questions = np.ones([num_questions, max_q_length]) * entity_vocab['PAD']
        answers = np.ones_like(question_lengths) * entity_vocab['UNK']
        all_kb_memories = None
        num_kb_memories = None
        text_key_memories = None
        text_key_lengths = None
        text_val_memories = None
        num_text_memories = None

        if self.use_kb_mem:
            print('Make data tensors for kb')
            all_kb_memories = np.ones([num_questions, max_num_kb_facts, 3])
            all_kb_memories[:, :, 0].fill(entity_vocab['DUMMY_MEM'])
            all_kb_memories[:, :, 2].fill(entity_vocab['DUMMY_MEM'])
            all_kb_memories[:, :, 1].fill(relation_vocab['DUMMY_MEM'])
            num_kb_memories = np.ones_like(question_lengths) * -1
            for q_counter, q in enumerate(tqdm(question_list)):
                question_str = q.parsed_question['question']
                question_entities = q.parsed_question['entities']
                question_indices = q.parsed_question['indices']
                q_answers = q.parsed_question['answers']
                # num_kb_memories.append(q.parsed_question['num_facts'])
                num_kb_memories[q_counter] = q.parsed_question['num_facts']
                q_start_indices = np.asarray(q.parsed_question['start_indices'])
                q_fact_lengths = np.asarray(
                    q.parsed_question['fact_lengths'])  # for each entity in question retrieve the fact
                sorted_index = np.argsort(q_fact_lengths)
                q_fact_lengths = q_fact_lengths[sorted_index]
                q_start_indices = q_start_indices[sorted_index]
                question_words_list = question_str.split(' ')
                for counter, index in enumerate(question_indices):  # replace the entities with their ids
                    question_words_list[index] = question_entities[counter]
                question_int = [entity_vocab[w_q] if w_q.strip() in entity_vocab else entity_vocab['UNK'] for w_q in
                                question_words_list]
                question_len = len(question_int)
                questions[q_counter, 0:question_len] = question_int
                question_lengths[q_counter] = question_len
                answer_int = [entity_vocab[a] if a in entity_vocab else entity_vocab['UNK'] for a in q_answers]
                answers[q_counter] = answer_int[0]

                # memories
                kb_facts = kb.facts
                mem_counter = 0
                for counter, start_index in enumerate(q_start_indices):
                    num_facts = q_fact_lengths[counter]
                    if mem_counter < self.max_num_mem_slots:
                        for mem_index in xrange(start_index, start_index + num_facts):
                            mem = kb_facts[mem_index]
                            e1_int = entity_vocab[mem['e1']] if mem['e1'] in entity_vocab else entity_vocab['UNK']
                            e2_int = entity_vocab[mem['e2']] if mem['e2'] in entity_vocab else entity_vocab['UNK']
                            r_int = relation_vocab[mem['r']] if mem['r'] in relation_vocab else relation_vocab['UNK']
                            all_kb_memories[q_counter][mem_counter][0] = e1_int
                            all_kb_memories[q_counter][mem_counter][1] = r_int
                            all_kb_memories[q_counter][mem_counter][2] = e2_int
                            mem_counter += 1
                            if mem_counter == self.max_num_mem_slots:  # will use the first max_num_mem_slots slots
                                break
        if self.use_text_mem:

            print('Make data tensors for text kb')
            max_key_len = text_kb.max_key_length
            text_key_memories = np.ones([num_questions, max_num_text_kb_facts, max_key_len]) * entity_vocab['DUMMY_MEM']
            text_key_lengths = np.zeros([num_questions, max_num_text_kb_facts])
            text_val_memories = np.ones([num_questions, max_num_text_kb_facts]) * entity_vocab['DUMMY_MEM']
            num_text_memories = np.ones_like(question_lengths) * -1
            for q_counter, q in enumerate(tqdm(question_list)):
                # TODO (rajarshd): Move the repeated piece of code into a method.
                question_str = q.parsed_question['question']
                question_entities = q.parsed_question['entities']
                question_indices = q.parsed_question['indices']
                q_answers = q.parsed_question['answers']
                question_words_list = question_str.split(' ')
                for counter, index in enumerate(question_indices):  # replace the entities with their ids
                    question_words_list[index] = question_entities[counter]
                question_int = [entity_vocab[w_q] if w_q.strip() in entity_vocab else entity_vocab['UNK'] for w_q in
                                question_words_list]
                question_len = len(question_int)
                questions[q_counter, 0:question_len] = question_int
                question_lengths[q_counter] = question_len
                answer_int = [entity_vocab[a] if a in entity_vocab else entity_vocab['UNK'] for a in q_answers]
                answers[q_counter] = answer_int[0]

                # memories
                num_q_text_memories = q.parsed_question['text_kb_num_facts']
                # in the training set, account for the discarded memories
                if 'black_lists' in q.parsed_question:
                    num_discarded = 0
                    for black_list in q.parsed_question['black_lists']:
                        num_discarded += len(black_list)
                    num_q_text_memories -= num_discarded
                num_text_memories[q_counter] = num_q_text_memories
                q_start_indices = np.asarray(q.parsed_question['text_kb_start_indices'])
                q_fact_lengths = np.asarray(
                    q.parsed_question['text_kb_lengths'])  # for each entity in question retrieve the fact
                q_black_lists = np.asarray(
                    q.parsed_question['black_lists']) if 'black_lists' in q.parsed_question else None
                sorted_index = np.argsort(q_fact_lengths)
                q_fact_lengths = q_fact_lengths[sorted_index]
                q_start_indices = q_start_indices[sorted_index]
                q_black_lists = q_black_lists[sorted_index] if q_black_lists is not None else None
                text_kb_facts = text_kb.facts_list
                mem_counter = 0
                for counter, start_index in enumerate(q_start_indices):
                    num_facts = q_fact_lengths[counter]
                    black_list_entity = set(q_black_lists[counter]) if q_black_lists is not None else None
                    if mem_counter < self.max_num_text_mem_slots:
                        for mem_entity_counter, mem_index in enumerate(xrange(start_index, start_index + num_facts)):
                            if black_list_entity is not None and mem_entity_counter in black_list_entity:
                                continue
                            mem = text_kb_facts[mem_index]
                            key = mem['key']
                            key_int = [entity_vocab[k] if k in entity_vocab else entity_vocab['UNK'] for k in key]
                            val = mem['value']
                            val_int = entity_vocab[val] if val in entity_vocab else entity_vocab['UNK']
                            key_len = int(mem['key_length'])
                            text_key_memories[q_counter][mem_counter][0:key_len] = key_int
                            text_val_memories[q_counter][mem_counter] = val_int
                            text_key_lengths[q_counter][mem_counter] = key_len
                            mem_counter += 1
                            if mem_counter == self.max_num_text_mem_slots:  # will use the first max_num_mem_slots slots
                                break

        return questions, question_lengths, answers, all_kb_memories, num_kb_memories, \
               text_key_memories, text_key_lengths, text_val_memories, num_text_memories
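
For context, a minimal sketch of how `Batcher` might be driven with KB memories only (the default flags). The file paths below are placeholders for the preprocessed files produced by the data-processing step, and the unpacking order follows the KB-only branch of `get_next_batch`.

```
# Hypothetical paths; the real ones come from the config/data-processing scripts.
batcher = Batcher(input_file='train_with_facts.json',
                  kb_file='kb_facts.json',
                  text_kb_file=None,      # unused when use_text_mem=False
                  batch_size=32,
                  vocab_dir='vocab/',
                  return_one_epoch=True,
                  use_kb_mem=True,
                  use_text_mem=False)

# get_next_batch() is a generator; with return_one_epoch=True it stops after one pass.
for questions, q_lengths, answers, kb_memories, num_kb_memories in batcher.get_next_batch():
    # questions:       [batch, max_q_length] padded entity/word ids
    # kb_memories:     [batch, max_num_kb_facts, 3] holding (e1, r, e2) ids
    # num_kb_memories: number of valid KB memory slots per question
    print(questions.shape, kb_memories.shape)
    break
```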

code/get_stats.py

+79
import json
import util
from collections import defaultdict


def get_fb_stats(freebase_data_file):
    """Counts facts, unique entities and unique relations in the Freebase dump."""
    with open(freebase_data_file) as fb:
        fact_counter = 0
        relation_set = set()
        entity_set = set()
        for line in fb:
            line = line.strip()
            line = line[1:-1]
            e1, r1, r2, e2 = [a.strip('"') for a in [x.strip() for x in line.split(',')]]
            r = r1 + '_' + r2
            fact_counter += 1
            relation_set.add(r)
            entity_set.add(e1)
            entity_set.add(e2)

    print("Total num of facts {}".format(fact_counter))
    print("Num unique entities {}".format(len(entity_set)))
    print("Num unique relations {}".format(len(relation_set)))


def get_questions_stats(train_data_file, dev_data_file):
    print('1. Getting the number of blanks')

    blank_str = '_blank_'
    num_blanks_map = defaultdict(int)
    word_freq_train = defaultdict(int)
    with open(train_data_file) as train_file:
        for counter, line in enumerate(util.verboserate(train_file)):
            line = line.strip()
            q_json = json.loads(line)
            q = q_json['sentence']
            count = q.count(blank_str)
            num_blanks_map[count] += 1
            words = q.split(' ')
            for word in words:
                word = word.strip()
                word_freq_train[word] += 1
            a_list = q_json['answerSubset']
            for a in a_list:
                word_freq_train[a] = word_freq_train[word] + 1

    print(num_blanks_map)

    print('2. Number of word types in the train set {}'.format(len(word_freq_train)))

    print('3. Checking overlap with the dev answers')
    dev_answers_present = set()
    dev_answers_oov = set()
    dev_answers = set()
    with open(dev_data_file) as dev_file:
        for line in dev_file:
            line = line.strip()
            dev_json = json.loads(line)
            a_list = dev_json['answerSubset']
            for a in a_list:
                if a in word_freq_train:
                    dev_answers_present.add(a)
                else:
                    dev_answers_oov.add(a)
                dev_answers.add(a)

    print('Number of unique dev answer strings {}'.format(len(dev_answers)))

    print('Number of oov answer strings in the dev set {}'.format(len(dev_answers_oov)))

    print('Number of dev answer strings with at least 1 occurrence in the train set {}'.format(
        len(dev_answers_present)))


freebase_data_file = "/home/rajarshi/research/graph-parser/data/spades/freebase.spades.txt"
train_data_file = "/home/rajarshi/research/graph-parser/data/spades/train.json"
dev_data_file = "/home/rajarshi/research/graph-parser/data/spades/dev.json"
# get_fb_stats(freebase_data_file)
get_questions_stats(train_data_file, dev_data_file)
