-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathsubset_utils.py
179 lines (144 loc) · 5.69 KB
/
subset_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import config_reader
import database_helper
from db_connect import MySqlConnection
def columns_to_copy(table, relationships, conn):
    """Build the comma-separated column list used when copying ``table``.

    Columns are normally emitted as ``<quoted table>.<quoted column>``.
    If ``table`` is the foreign-key side of a configured dependency break
    (one chosen to break a dependency cycle) and that break is not among
    the opportunistically preserved ones, the foreign-key columns that
    point at the break's target table are emitted as
    ``NULL as <quoted column>`` instead.

    Args:
        table: Table identifier, optionally qualified as ``schema.table``.
        relationships: Iterable of mappings providing the ``'fk_table'``,
            ``'target_table'`` and ``'fk_columns'`` keys.
        conn: Connection handed through to the database helper's
            ``get_table_columns``.

    Returns:
        The select items joined with ``,`` as a single string.
    """
    preserved = config_reader.get_preserve_fk_opportunistically()
    # Targets of the breaks that originate at this table and must be nulled.
    broken_targets = {
        brk.target_table
        for brk in config_reader.get_dependency_breaks()
        if brk.fk_table == table and brk not in preserved
    }

    # Every foreign-key column of this table that references a broken target.
    nulled = {
        column
        for rel in relationships
        if rel['fk_table'] == table and rel['target_table'] in broken_targets
        for column in rel['fk_columns']
    }

    helper = database_helper.get_specific_helper()
    all_columns = helper.get_table_columns(table_name(table), schema_name(table), conn)

    select_items = []
    for column in all_columns:
        if column in nulled:
            select_items.append('NULL as {}'.format(quoter(column)))
        else:
            select_items.append('{}.{}'.format(quoter(table_name(table)), quoter(column)))
    return ','.join(select_items)
def upstream_filter_match(target, table_columns):
    """Collect the conditions of the configured upstream filters that apply.

    A filter applies when it has a ``"table"`` entry equal to ``target``,
    and/or when it has a ``"column"`` entry contained in ``table_columns``.
    A filter that matches on both grounds contributes its condition twice.

    Args:
        target: Table identifier compared against each filter's ``"table"``.
        table_columns: Container of column names checked against each
            filter's ``"column"``.

    Returns:
        List of the matching filters' ``"condition"`` values, in
        configuration order.
    """
    conditions = []
    for spec in config_reader.get_upstream_filters():
        matches_table = "table" in spec and target == spec["table"]
        if matches_table:
            conditions.append(spec["condition"])

        matches_column = "column" in spec and spec["column"] in table_columns
        if matches_column:
            conditions.append(spec["condition"])
    return conditions
def redact_relationships(relationships):
    """Drop the relationships that correspond to configured dependency breaks.

    A relationship is removed when its ``(fk_table, target_table)`` pair is
    found in the value returned by ``config_reader.get_dependency_breaks()``.
    NOTE(review): this relies on the break entries comparing equal to plain
    2-tuples -- confirm against config_reader.

    Args:
        relationships: Iterable of mappings with ``'fk_table'`` and
            ``'target_table'`` keys.

    Returns:
        A new list holding the relationships that were kept, in input order.
    """
    breaks = config_reader.get_dependency_breaks()

    def survives(rel):
        return (rel['fk_table'], rel['target_table']) not in breaks

    return list(filter(survives, relationships))
def find(f, seq):
    """Return the first item of ``seq`` for which ``f(item)`` is truthy.

    Returns ``None`` when no item satisfies ``f`` (including an empty
    ``seq``).
    """
    return next((candidate for candidate in seq if f(candidate)), None)
def compute_upstream_tables(target_tables, order):
    """List every table in the layers that follow the first target layer.

    ``order`` is walked layer by layer. The first layer containing any of
    ``target_tables`` is itself skipped; the tables of all later layers are
    returned, flattened, in order.

    Args:
        target_tables: Collection of table identifiers to look for.
        order: Sequence of layers, each an iterable of table identifiers.

    Returns:
        Flat list of tables from the layers after the first one that
        contains a target; empty when no layer contains a target.
    """
    layers = iter(order)
    # Consume layers up to and including the first one holding a target.
    for layer in layers:
        if any(target in layer for target in target_tables):
            break
    # Whatever the iterator still holds is upstream of the targets.
    return [table for layer in layers for table in layer]
def compute_downstream_tables(passthrough_tables, disconnected_tables, order):
    """Flatten ``order`` in reverse, leaving out excluded tables.

    Args:
        passthrough_tables: Tables to leave out of the result.
        disconnected_tables: Further tables to leave out of the result.
        order: Sequence of layers, each an iterable of table identifiers.

    Returns:
        List of the remaining tables, in the reverse of their flattened
        position in ``order``.
    """
    kept = [
        table
        for layer in order
        for table in layer
        if table not in passthrough_tables and table not in disconnected_tables
    ]
    kept.reverse()
    return kept
def compute_disconnected_tables(target_tables, passthrough_tables, all_tables, relationships):
    """Return the tables not connected to any target or passthrough table.

    Tables are grouped with a ``UnionFind``: every table in ``all_tables``
    is registered, and each relationship links its ``'fk_table'`` with its
    ``'target_table'``. A table is reported when the representative of its
    group differs from the representative of every target and passthrough
    table.

    Args:
        target_tables: Table identifiers whose groups count as connected.
        passthrough_tables: Additional tables whose groups count as
            connected.
        all_tables: Re-iterable collection of every table identifier.
        relationships: Iterable of mappings with ``'fk_table'`` and
            ``'target_table'`` keys.

    Returns:
        List of the tables from ``all_tables`` outside those groups, in
        ``all_tables`` order.
    """
    components = UnionFind()
    for table in all_tables:
        components.make_set(table)
    for rel in relationships:
        components.link(rel['fk_table'], rel['target_table'])

    # Representatives of every group reachable from a target or passthrough table.
    reachable_roots = {components.find(table) for table in target_tables}
    reachable_roots |= {components.find(table) for table in passthrough_tables}

    disconnected = []
    for table in all_tables:
        if components.find(table) not in reachable_roots:
            disconnected.append(table)
    return disconnected
def fully_qualified_table(table):
    """Return ``table`` with its schema and table parts individually quoted.

    An identifier containing a ``.`` yields ``<quoted schema>.<quoted
    table>``; an identifier without one yields just the quoted table name.
    """
    if '.' not in table:
        return quoter(table_name(table))
    return '.'.join([quoter(schema_name(table)), quoter(table_name(table))])
def schema_name(table):
    """Return the part of ``table`` before its first ``.``.

    Returns ``None`` when ``table`` contains no ``.``.
    """
    schema, dot, _rest = table.partition('.')
    if not dot:
        return None
    return schema
def table_name(table):
    """Return the table part of a possibly schema-qualified identifier.

    The identifier is split on ``.``: with no dot the whole string is
    returned, otherwise the second dot-separated component is returned.
    """
    parts = table.split('.')
    if len(parts) == 1:
        return parts[0]
    return parts[1]
def columns_tupled(columns):
    """Return the quoted column names, comma-separated, inside parentheses."""
    quoted = map(quoter, columns)
    return '(' + ','.join(quoted) + ')'
def columns_joined(columns):
    """Return the quoted column names joined with commas."""
    quoted_columns = []
    for column in columns:
        quoted_columns.append(quoter(column))
    return ','.join(quoted_columns)
def quoter(id):
    """Quote an identifier for the configured database type.

    The identifier is wrapped in double quotes when
    ``config_reader.get_db_type()`` is ``'postgres'`` and in backticks
    otherwise. Any occurrence of the quote character inside the identifier
    is doubled, which is how both quoting styles represent a literal quote
    character; without this an identifier containing the quote character
    would terminate the quoting early and produce malformed SQL.

    Args:
        id: Identifier string to quote (name kept for caller
            compatibility even though it shadows the builtin).

    Returns:
        The quoted identifier string.
    """
    q = '"' if config_reader.get_db_type() == 'postgres' else '`'
    # Escape embedded quote characters by doubling them.
    return q + id.replace(q, q + q) + q
def print_progress(target, idx, count):
    """Print a one-line progress message: position, total and target."""
    message = 'Processing {} of {}: {}'.format(idx, count, target)
    print(message)
class UnionFind:
    """Disjoint-set structure over arbitrary hashable elements.

    Elements are mapped to consecutive integer ids on first use. Sets are
    merged by rank, and root lookups compress the paths they walk.
    """

    def __init__(self):
        # element -> integer id
        self.elementsToId = dict()
        # integer id -> element
        self.elements = []
        # integer id -> parent id (a root is its own parent)
        self.roots = []
        # integer id -> rank, only meaningful for roots
        self.ranks = []

    def __len__(self):
        """Return the number of registered elements."""
        return len(self.roots)

    def make_set(self, elem):
        """Register ``elem`` as a singleton set if it is not yet known."""
        self.id_of(elem)

    def find(self, elem):
        """Return the representative element of the set containing ``elem``.

        Returns ``None`` when ``elem`` has never been registered.
        """
        # .get() makes the "unknown element" branch reachable; plain
        # indexing would raise KeyError before the None check.
        x = self.elementsToId.get(elem)
        if x is None:
            return None

        rootId = self.find_internal(x)
        return self.elements[rootId]

    def find_internal(self, x):
        """Return the root id for id ``x``, compressing the path walked."""
        x0 = x
        # First pass: climb to the root.
        while self.roots[x] != x:
            x = self.roots[x]

        # Second pass: point every node on the path straight at the root.
        while self.roots[x0] != x:
            y = self.roots[x0]
            self.roots[x0] = x
            x0 = y

        return x

    def id_of(self, elem):
        """Return the id of ``elem``, registering it as a new set if needed."""
        if elem not in self.elementsToId:
            idx = len(self.roots)
            self.elements.append(elem)
            self.elementsToId[elem] = idx
            self.roots.append(idx)
            self.ranks.append(0)

        return self.elementsToId[elem]

    def link(self, elem1, elem2):
        """Merge the sets containing ``elem1`` and ``elem2``.

        Unknown elements are registered first. The root with the lower rank
        is attached under the other; on a tie ``elem2``'s root goes under
        ``elem1``'s root, whose rank is then incremented.
        """
        x = self.id_of(elem1)
        y = self.id_of(elem2)

        xr = self.find_internal(x)
        yr = self.find_internal(y)
        if xr == yr:
            return

        xd = self.ranks[xr]
        yd = self.ranks[yr]
        if xd < yd:
            self.roots[xr] = yr
        elif yd < xd:
            self.roots[yr] = xr
        else:
            self.roots[yr] = xr
            self.ranks[xr] = self.ranks[xr] + 1

    def members_of(self, elem):
        """Return every element in the same set as ``elem``.

        Members are listed in registration order.

        Raises:
            ValueError: If ``elem`` has never been registered.
        """
        # .get() so the ValueError below is actually raised for unknown
        # elements instead of a KeyError from plain indexing.
        elem_id = self.elementsToId.get(elem)
        if elem_id is None:
            raise ValueError("tried calling membersOf on an unknown element")

        elemRoot = self.find_internal(elem_id)
        return [
            member
            for idx, member in enumerate(self.elements)
            if self.find_internal(idx) == elemRoot
        ]
def mysql_db_name_hack(target, conn):
    """Re-qualify ``target`` with the connection's database name for MySQL.

    When ``conn`` is a ``MySqlConnection`` and ``target`` contains a ``.``,
    the result is ``conn.db_name`` joined to the table part of ``target``
    with a ``.``. In every other case ``target`` is returned unchanged.
    """
    if isinstance(conn, MySqlConnection) and '.' in target:
        return conn.db_name + '.' + table_name(target)
    return target