-
Notifications
You must be signed in to change notification settings - Fork 0
/
build_query_result.py
222 lines (202 loc) · 9.59 KB
/
build_query_result.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import json
import re
def _load_json(path):
    """Parse the JSON document stored at *path* and close the file afterwards."""
    with open(path) as handle:
        return json.load(handle)


# One table per domain, read from the db/ directory.
d_train = _load_json('createData/multiwoz21/db/train_db.json')
d_rest = _load_json('createData/multiwoz21/db/restaurant_db.json')
d_hotel = _load_json('createData/multiwoz21/db/hotel_db.json')
d_police = _load_json('createData/multiwoz21/db/police_db.json')
d_hosp = _load_json('createData/multiwoz21/db/hospital_db.json')
d_attr = _load_json('createData/multiwoz21/db/attraction_db.json')
# The taxi domain is not read from disk: a single inline row listing
# colours, car makes and a phone-number pattern stands in for a table.
d_taxi = [{
    "taxi_colors": ["black", "white", "red", "yellow", "blue", "grey"],
    "taxi_types": ["toyota", "skoda", "bmw", "honda", "ford", "audi", "lexus", "volvo", "volkswagen", "tesla"],
    "taxi_phone": ["^[0-9]{10}$"],
}]
# Domain name -> table looked up by get_results().
entity_db_map = {
    'train': d_train,
    'restaurant': d_rest,
    'police': d_police,
    'hospital': d_hosp,
    'attraction': d_attr,
    'taxi': d_taxi,
    'hotel': d_hotel,
}
# Dialogue data that the rest of the script annotates with queries/results.
d_data = _load_json('createData/multiwoz21/delex.json')
# Slot values that place no constraint on the search.
_WILDCARD_VALUES = frozenset(('not mentioned', 'dontcare', 'none', ''))


def _matches_exactly(row, semi):
    """Return True when *row* equals every constraining slot of *semi*.

    Slots absent from *row* are ignored. Each compared slot of *semi* is
    normalised IN PLACE (lower-cased, stripped, and for 'name' rewritten as
    below), so the caller's dict is modified; *row* is left untouched.
    """
    for k in semi.keys():
        if k not in row:
            continue
        # normalise
        val = row[k].lower().strip()
        semi[k] = semi[k].lower().strip()
        if k == 'name':
            # Plain substring removal: 'the' is dropped inside words too.
            val = val.replace('the', '')
            semi[k] = semi[k].replace('the', '')
            val = val.replace("b & b", "bed and breakfast")
            semi[k] = semi[k].replace("b & b", "bed and breakfast")
            val = val.replace("restaurant", "")
            semi[k] = semi[k].replace("restaurant", "")
            # NOTE(review): the row side removes " hotel" (with a leading
            # space) while the query side removes "hotel"; names that start
            # with "hotel" therefore never match. Kept as in the original --
            # confirm which form is intended.
            if "hotel" in val and 'gonville' not in val:
                val = val.replace(" hotel", "")
            if "hotel" in semi[k] and 'gonville' not in semi[k]:
                semi[k] = semi[k].replace("hotel", "")
            val = val.strip()
            semi[k] = semi[k].strip()
        if val != semi[k] and semi[k] not in _WILDCARD_VALUES:
            return False
    return True


def _time_slot_matches(slot, wanted, actual):
    """Hour-granularity comparison used for 'arriveBy' / 'leaveAt'.

    'afternoon' / 'after lunch' accept rows whose hour is <= 16 and 'morning'
    rows whose hour is <= 11. Otherwise the hour before the first ':' is
    compared: for 'arriveBy' the row hour must not exceed the wanted hour,
    for 'leaveAt' the row hour must be at least the wanted hour minus one.
    A wanted value starting with ':' or any non-numeric hour is a non-match.
    """
    try:
        if wanted in ("afternoon", "after lunch"):
            return int(actual[:2]) <= 16
        if wanted == "morning":
            return int(actual[:2]) <= 11
        if wanted[0] == ':':
            return False
        wanted_hour = int(wanted.split(':')[0][-2:])
        actual_hour = int(actual.split(':')[0][-2:])
    except ValueError:
        # Unparseable time text cannot satisfy a time constraint.
        return False
    if slot == "arriveBy":
        return wanted_hour >= actual_hour
    return wanted_hour - 1 <= actual_hour


def _matches_relaxed(row, semi):
    """Like an exact match, but time slots are compared by hour only.

    Works on normalised local copies of the values: neither *row* nor *semi*
    is modified.
    """
    for k in semi.keys():
        if k not in row:
            continue
        actual = row[k].lower().strip()
        wanted = semi[k].lower().strip()
        if wanted in _WILDCARD_VALUES:
            continue
        if k in ("arriveBy", "leaveAt"):
            if not _time_slot_matches(k, wanted, actual):
                return False
        elif actual != wanted:
            return False
    return True


def get_results(semi, ent):
    """Return the rows of domain *ent* that satisfy the slot constraints *semi*.

    Args:
        semi: dict mapping slot name to a string value. Values equal to
            'not mentioned', 'dontcare', 'none' or '' are not constraints.
            Slots that a row does not have are ignored. The dict is
            normalised in place during the exact pass (see below).
        ent: key into ``entity_db_map`` selecting the table to scan.

    Returns:
        A list of the matching row dicts (the table's own objects, in table
        order). If the exact pass finds nothing and *semi* contains a
        'leaveAt' or 'arriveBy' key, a second pass is made in which those two
        slots are compared by hour only.

    Raises:
        KeyError: if *ent* is not a key of ``entity_db_map``.
    """
    db = entity_db_map[ent]
    ret = [row for row in db if _matches_exactly(row, semi)]
    if ret or not ("leaveAt" in semi or "arriveBy" in semi):
        return ret
    # Nothing matched exactly: retry with hour-level time comparison.
    return [row for row in db if _matches_relaxed(row, semi)]
def check_query_semi(semi):
    """Return True if any slot in *semi* carries a value.

    A slot counts as carrying a value unless it equals 'not mentioned', ''
    or 'none'. Note that 'dontcare' therefore counts as a value here.
    Returns False (rather than falling through to None) when no slot does.
    """
    empty_markers = ('not mentioned', '', 'none')
    return any(value not in empty_markers for value in semi.values())
def _first_token_with_colon(tokens):
    """Return the first token containing ':' or None if there is none."""
    for token in tokens:
        if ':' in token:
            return token
    return None


def _read_train_slots_from_text(text, semi):
    """Overwrite train slots in *semi* with values spotted in *text*.

    The text is split on single spaces. For a token containing 'arriv' the
    first ':'-bearing token among it and the next four becomes 'arriveBy';
    likewise 'leav' / 'depart' set 'leaveAt'. A token of the form 'TR'
    followed only by decimal digits becomes 'trainID'. Later hits overwrite
    earlier ones. *semi* is modified in place.
    """
    tokens = text.split(' ')
    for i, token in enumerate(tokens):
        window = tokens[i:i + 5]
        if 'arriv' in token:
            when = _first_token_with_colon(window)
            if when is not None:
                semi['arriveBy'] = when
        if 'leav' in token or 'depart' in token:
            when = _first_token_with_colon(window)
            if when is not None:
                semi['leaveAt'] = when
        if token[:2] == 'TR' and token[2:].isdecimal():
            semi['trainID'] = token


def _merge_booking_slots(domain_meta):
    """Copy every 'book' slot except 'booked' into the domain's 'semi' dict.

    Returns that same 'semi' dict (mutated in place, not a copy).
    """
    query = domain_meta['semi']
    for slot, value in domain_meta['book'].items():
        if slot != 'booked':
            query[slot] = value
    return query


for dial_k in d_data.keys():
    all_results = []
    all_queries = []
    goal = d_data[dial_k]['goal']
    # mark the topics that are mentioned in the goal
    topics_allowed = [
        t for t in ['train', 'attraction', 'taxi', 'restaurant', 'hotel']
        if goal[t]
    ]
    current_topic = ''
    # mark the topics that are finished
    topics_done = {'train': False, 'restaurant': False, 'attraction': False, 'hotel': False, 'taxi': False}
    # go through the conversation log and add the queries and results keys to it
    for utt in d_data[dial_k]['log']:
        meta = utt['metadata']
        text = utt['text']
        # Turns without metadata get no 'results' / 'queries' keys.
        if meta == {}:
            continue
        utt['results'] = []
        utt['queries'] = []
        # Topics of the goal that carry a value in this turn and are not done.
        possible_topic = [
            k for k in meta.keys()
            if k in topics_allowed
            and k != 'bus'
            and check_query_semi(meta[k]['semi'])
            and not topics_done[k]
        ]
        # determine the current topic, i.e. if 2 topics are possible then the
        # current one is marked done and the other becomes current
        if len(possible_topic) == 1:
            current_topic = possible_topic[0]
        elif len(possible_topic) == 2:
            topics_done[current_topic] = True
            if current_topic == possible_topic[0]:
                current_topic = possible_topic[1]
            else:
                current_topic = possible_topic[0]
        elif len(possible_topic) >= 3:
            print('ERROR' * 100)
            print(possible_topic)
        if current_topic == '':
            # No topic has been established yet: nothing to query.
            continue
        domain_meta = meta[current_topic]
        if current_topic == 'train':
            _read_train_slots_from_text(text, domain_meta['semi'])
        if current_topic != 'taxi':
            # Taxi turns record a query but never any results.
            utt['results'].extend(get_results(domain_meta['semi'], current_topic))
        # Booking slots are merged only after the lookup, so they do not
        # take part in the search for this turn.
        q = _merge_booking_slots(domain_meta)
        # extend() with a tuple stores the topic and the query as two
        # consecutive list items.
        utt['queries'].extend((current_topic, q))
        all_queries.append((current_topic, q))
        all_results.append(utt['results'])
    d_data[dial_k]['all_queries'] = all_queries
    d_data[dial_k]['all_results'] = all_results
    print('Adding ', len(all_queries), 'queries')
    print('Adding ', sum(map(len, all_results)), 'results')
# Write through a context manager so the output is flushed and closed.
with open('createData/multiwoz21/delex_query_results.json', 'w') as out_file:
    json.dump(d_data, out_file)