-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy pathutils.py
More file actions
164 lines (144 loc) · 5.65 KB
/
utils.py
File metadata and controls
164 lines (144 loc) · 5.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding: utf-8 -*-
"""
# @Time : 2019/5/24
# @Author : Jiaqi&Zecheng
# @File : utils.py
# @Software: PyCharm
"""
import json
from nltk.stem import WordNetLemmatizer
# Question words / filler tokens that should never be treated as candidate
# value tokens by group_values.
VALUE_FILTER = ['what', 'how', 'list', 'give', 'show', 'find', 'id', 'order', 'when']
# Tokens that signal an aggregation in a question (presumably consumed by
# code outside this view — not referenced in this file).
AGG = ['average', 'sum', 'max', 'min', 'minimum', 'maximum', 'between']
# Shared WordNet lemmatizer instance.
wordnet_lemmatizer = WordNetLemmatizer()
def load_dataSets(args):
    """Load question data and table schemas and join them by ``db_id``.

    Args:
        args: namespace with attributes ``table_path`` and ``data_path``,
            each the path of a JSON file (Spider-style ``tables.json`` and
            question file).

    Returns:
        tuple: ``(datas, tables)`` — ``datas`` is the list of question
        entries, each enriched in place with ``names``, ``table_names``,
        ``col_set``, ``col_table`` and ``keys``; ``tables`` maps each
        ``db_id`` to its (augmented) schema dict.
    """
    with open(args.table_path, 'r', encoding='utf8') as f:
        table_datas = json.load(f)
    with open(args.data_path, 'r', encoding='utf8') as f:
        datas = json.load(f)

    tables = {}
    for table in table_datas:
        # Deduplicated column names, preserving first-seen order.
        col_set = []
        for col_name in (x[1] for x in table['column_names']):
            if col_name not in col_set:
                col_set.append(col_name)
        table['col_set'] = col_set
        table['schema_content'] = [col[1] for col in table['column_names']]
        table['col_table'] = [col[0] for col in table['column_names']]
        tables[table['db_id']] = table

    for d in datas:
        schema = tables[d['db_id']]
        d['names'] = schema['schema_content']
        d['table_names'] = schema['table_names']
        d['col_set'] = schema['col_set']
        d['col_table'] = schema['col_table']
        # Map every foreign-key column to its partner (both directions),
        # then every primary-key column to itself.
        keys = {}
        for fk_a, fk_b in schema['foreign_keys']:
            keys[fk_a] = fk_b
            keys[fk_b] = fk_a
        for pk in schema['primary_keys']:
            keys[pk] = pk
        d['keys'] = keys
    return datas, tables
def group_header(toks, idx, num_toks, header_toks):
    """Greedily match the longest token span starting at *idx* against
    *header_toks*.

    Returns ``(end_index, matched_string)`` for the longest span whose
    space-joined form is in *header_toks*, or ``(idx, None)`` if none is.
    """
    end = num_toks
    while end > idx:
        candidate = " ".join(toks[idx:end])
        if candidate in header_toks:
            return end, candidate
        end -= 1
    return idx, None
def fully_part_header(toks, idx, num_toks, header_toks):
    """Like ``group_header`` but only multi-token spans may match.

    Returns ``(end_index, matched_string)`` for the longest span of at
    least two tokens found in *header_toks*, else ``(idx, None)``.
    """
    for end in range(num_toks, idx, -1):
        span = toks[idx:end]
        if len(span) <= 1:
            continue
        joined = " ".join(span)
        if joined in header_toks:
            return end, joined
    return idx, None
def partial_header(toks, idx, header_toks):
    """Find a multi-token span starting at *idx* whose tokens all occur in
    exactly one header of at most three tokens.

    Returns ``(end_index, header)`` when exactly one header matches the
    longest such span, otherwise ``(idx, None)``.
    """
    def covered_by(cand, head):
        # Every distinct candidate token must appear in the header (the
        # count comparison means duplicate candidate tokens never match),
        # and the header itself must be short.
        return len(set(cand) & set(head)) == len(cand) and len(head) <= 3

    for end in range(len(toks) - 1, idx, -1):
        span = toks[idx:end]
        if len(span) <= 1:
            continue
        matches = [head for head in header_toks if covered_by(span, head)]
        if len(matches) == 1:
            return end, matches[0]
    return idx, None
def symbol_filter(questions):
    """Normalize quote symbols in a token sequence.

    A token wrapped in quote characters is split into three tokens
    (``'``, inner text, ``'``); a token with only a leading or trailing
    quote is split into two; bare quote tokens (including ```` `` ````
    and ``''``) are normalized to a single ``'``.  Everything else
    passes through unchanged.

    NOTE(review): the original file contained mojibake ('鈥�') where
    Unicode curly quotes were intended, so curly-quoted tokens were never
    matched; the quote characters are restored here.

    Args:
        questions: iterable of string tokens.

    Returns:
        list: new token list with quotes normalized.
    """
    quote_chars = ["'", '"', '`', '\u2018', '\u2019', '\u201c', '\u201d']
    filtered = []
    for q_val in questions:
        if len(q_val) > 2 and q_val[0] in quote_chars and q_val[-1] in quote_chars:
            filtered.append("'")
            filtered.append("".join(q_val[1:-1]))
            filtered.append("'")
        elif len(q_val) > 2 and q_val[0] in quote_chars:
            filtered.append("'")
            filtered.append("".join(q_val[1:]))
        elif len(q_val) > 2 and q_val[-1] in quote_chars:
            filtered.append("".join(q_val[0:-1]))
            filtered.append("'")
        elif q_val in quote_chars + ['``', "''"]:
            filtered.append("'")
        else:
            filtered.append(q_val)
    return filtered
def group_values(toks, idx, num_toks):
    """Match the longest run of capitalized tokens starting at *idx*.

    A multi-token span matches when every token starts uppercase; a single
    token matches when it starts uppercase, is alphanumeric (lowercased),
    and is not in ``VALUE_FILTER``.

    Returns ``(end_index, token_list)`` or ``(idx, None)``.
    """
    def starts_upper(span):
        return all(tok[0].isupper() for tok in span)

    for end in range(num_toks, idx, -1):
        span = toks[idx:end]
        if len(span) > 1:
            if starts_upper(span):
                return end, span
        elif len(span) == 1:
            tok = span[0]
            if tok[0].isupper() and tok.lower() not in VALUE_FILTER and tok.lower().isalnum():
                return end, span
    return idx, None
def group_digital(toks, idx):
    """Return True when the token at *idx* is all digits once any ':' and
    '.' characters are removed (so times like "12:30" and decimals like
    "3.5" count as digital)."""
    stripped = toks[idx].replace(':', '').replace('.', '')
    return stripped.isdigit()
def group_symbol(toks, idx, num_toks):
    """Look for the closing quote of a quoted span.

    If the token before *idx* is ``'``, scan at most the next three
    tokens for another ``'``.  Returns ``(closing_index, inner_tokens)``
    on success, otherwise ``(idx, None)``.
    """
    if toks[idx - 1] != "'":
        return idx, None
    limit = min(3, num_toks - idx)
    for offset in range(limit):
        if toks[idx + offset] == "'":
            return idx + offset, toks[idx:idx + offset]
    return idx, None
def num2year(tok):
    """Heuristic: does *tok* look like a 4-digit year (1600-2199)?"""
    text = str(tok)
    if len(text) != 4 or not text.isdigit():
        return False
    century = int(text[:2])
    return 15 < century < 22
def set_header(toks, header_toks, tok_concol, idx, num_toks):
    """Extend *toks* with token groups from ``tok_concol[idx:]`` — stopping
    right after the first multi-token group — then return the header whose
    token set equals the accumulated token set, or None.

    NOTE: mutates *toks* in place via ``+=``.
    """
    for pos in range(idx, num_toks):
        group = tok_concol[pos]
        toks += group
        if len(group) > 1:
            break
    accumulated = set(toks)
    for head in header_toks:
        if set(head) == accumulated:
            return head
    return None