
Commit b3d4223

Make dbt data diffs concurrent (#776)

Authored by sungchun12 (Sung Won Chung) and dlawin
1 parent 29b48b0 · commit b3d4223

* v0 of concurrency
* concurrent logging
* remove todo
* remove todo
* better var name
* add node name to logger
* format string logs
* add optional logger param
* avoid extra threads
* use thread pools
* not multithreaded at the connection level anymore
* show errors as they happen
* show full stacktrace on error
* rearrange trace
* more logs for debugging
* update for threads mocking
* clear log params
* remove extra space
* remove long traceback
* Ensure log_message is optional (Co-authored-by: Dan Lawin <[email protected]>)
* map threaded result to proper model id
* explicit type and optional
* rm submodules again

Co-authored-by: Sung Won Chung <[email protected]>
Co-authored-by: Dan Lawin <[email protected]>
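The core pattern this commit adopts is a ThreadPoolExecutor plus a dict that maps each future back to its model, so errors can be reported per model as futures complete. A minimal, runnable sketch of that pattern; the model names and the diff_one task are placeholders, not the project's API:

# Sketch: one task per model, a futures-to-model map, and per-model
# error reporting as the futures complete.
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)

def diff_one(model: str) -> str:
    # stand-in for the real diff work
    return f"diffed {model}"

models = ["orders", "customers", "payments"]
futures = {}

with ThreadPoolExecutor(max_workers=4) as executor:
    for model in models:
        future = executor.submit(diff_one, model)
        futures[future] = model  # remember which future belongs to which model

for future in as_completed(futures):
    model = futures[future]
    try:
        print(future.result())  # re-raises any exception from the task
    except Exception as e:
        logger.error(f"Diff task failed for {model}: {e}")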

File tree: 6 files changed, +110 −64 lines changed

data_diff/databases/base.py (+10 −5)

@@ -931,7 +931,7 @@ def name(self):
     def compile(self, sql_ast):
         return self.dialect.compile(Compiler(self), sql_ast)
 
-    def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
+    def query(self, sql_ast: Union[Expr, Generator], res_type: type = None, log_message: Optional[str] = None):
         """Query the given SQL code/AST, and attempt to convert the result to type 'res_type'
 
         If given a generator, it will execute all the yielded sql queries with the same thread and cursor.
@@ -956,7 +956,10 @@ def query(self, sql_ast: Union[Expr, Generator], res_type: type = None):
         if sql_code is SKIP:
             return SKIP
 
-        logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
+        if log_message:
+            logger.debug("Running SQL (%s): %s \n%s", self.name, log_message, sql_code)
+        else:
+            logger.debug("Running SQL (%s):\n%s", self.name, sql_code)
 
         if self._interactive and isinstance(sql_ast, Select):
             explained_sql = self.compile(Explain(sql_ast))
@@ -1022,7 +1025,7 @@ def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         Note: This method exists instead of select_table_schema(), just because not all databases support
         accessing the schema using a SQL query.
         """
-        rows = self.query(self.select_table_schema(path), list)
+        rows = self.query(self.select_table_schema(path), list, log_message=path)
        if not rows:
             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
 
@@ -1044,7 +1047,7 @@ def query_table_unique_columns(self, path: DbPath) -> List[str]:
         """Query the table for its unique columns for table in 'path', and return {column}"""
         if not self.SUPPORTS_UNIQUE_CONSTAINT:
             raise NotImplementedError("This database doesn't support 'unique' constraints")
-        res = self.query(self.select_table_unique_columns(path), List[str])
+        res = self.query(self.select_table_unique_columns(path), List[str], log_message=path)
         return list(res)
 
     def _process_table_schema(
@@ -1086,7 +1089,9 @@ def _refine_coltypes(
         fields = [Code(self.dialect.normalize_uuid(self.dialect.quote(c), String_UUID())) for c in text_columns]
 
         samples_by_row = self.query(
-            table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size), list
+            table(*table_path).select(*fields).where(Code(where) if where else SKIP).limit(sample_size),
+            list,
+            log_message=table_path,
         )
         if not samples_by_row:
             raise ValueError(f"Table {table_path} is empty.")

data_diff/dbt.py (+27 −16)

@@ -8,6 +8,7 @@
 import pydantic
 import rich
 from rich.prompt import Prompt
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from data_diff.errors import (
     DataDiffCustomSchemaNoConfigError,
@@ -80,7 +81,6 @@ def dbt_diff(
     production_schema_flag: Optional[str] = None,
 ) -> None:
     print_version_info()
-    diff_threads = []
     set_entrypoint_name(os.getenv("DATAFOLD_TRIGGERED_BY", "CLI-dbt"))
     dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
     models = dbt_parser.get_models(dbt_selection)
@@ -112,7 +112,11 @@ def dbt_diff(
     else:
         dbt_parser.set_connection()
 
-    with log_status_handler.status if log_status_handler else nullcontext():
+    futures = {}
+
+    with log_status_handler.status if log_status_handler else nullcontext(), ThreadPoolExecutor(
+        max_workers=dbt_parser.threads
+    ) as executor:
         for model in models:
             if log_status_handler:
                 log_status_handler.set_prefix(f"Diffing {model.alias} \n")
@@ -140,12 +144,12 @@ def dbt_diff(
 
             if diff_vars.primary_keys:
                 if is_cloud:
-                    diff_thread = run_as_daemon(
+                    future = executor.submit(
                         _cloud_diff, diff_vars, config.datasource_id, api, org_meta, log_status_handler
                     )
-                    diff_threads.append(diff_thread)
                 else:
-                    _local_diff(diff_vars, json_output)
+                    future = executor.submit(_local_diff, diff_vars, json_output, log_status_handler)
+                futures[future] = model
             else:
                 if json_output:
                     print(
@@ -165,10 +169,12 @@ def dbt_diff(
                     + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
                 )
 
-    # wait for all threads
-    if diff_threads:
-        for thread in diff_threads:
-            thread.join()
+    for future in as_completed(futures):
+        model = futures[future]
+        try:
+            future.result()  # if error occurred, it will be raised here
+        except Exception as e:
+            logger.error(f"An error occurred during the execution of a diff task: {model.unique_id} - {e}")
 
     _extension_notification()
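In dbt_diff(), the executor is entered alongside the optional rich status context in a single with statement; nullcontext keeps the code path uniform when no status handler is attached. A small sketch of that shape, assuming a falsy status means plain logging:

# Sketch: an optional context manager and the executor entered together.
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor

status = None  # stand-in for log_status_handler.status (a rich Status)

with (status if status else nullcontext()), ThreadPoolExecutor(max_workers=4) as executor:
    future = executor.submit(print, "diffing inside the pool")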

@@ -265,15 +271,17 @@ def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str,
     return prod_database, prod_schema, prod_alias
 
 
-def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
+def _local_diff(
+    diff_vars: TDiffVars, json_output: bool = False, log_status_handler: Optional[LogStatusHandler] = None
+) -> None:
+    if log_status_handler:
+        log_status_handler.diff_started(diff_vars.dev_path[-1])
     dev_qualified_str = ".".join(diff_vars.dev_path)
     prod_qualified_str = ".".join(diff_vars.prod_path)
     diff_output_str = _diff_output_base(dev_qualified_str, prod_qualified_str)
 
-    table1 = connect_to_table(
-        diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads
-    )
-    table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys), diff_vars.threads)
+    table1 = connect_to_table(diff_vars.connection, prod_qualified_str, tuple(diff_vars.primary_keys))
+    table2 = connect_to_table(diff_vars.connection, dev_qualified_str, tuple(diff_vars.primary_keys))
 
     try:
         table1_columns = table1.get_schema()
@@ -373,6 +381,9 @@ def _local_diff(diff_vars: TDiffVars, json_output: bool = False) -> None:
         diff_output_str += no_differences_template()
     rich.print(diff_output_str)
 
+    if log_status_handler:
+        log_status_handler.diff_finished(diff_vars.dev_path[-1])
+
 
 def _initialize_api() -> Optional[DatafoldAPI]:
     datafold_host = os.environ.get("DATAFOLD_HOST")
@@ -406,7 +417,7 @@ def _cloud_diff(
     log_status_handler: Optional[LogStatusHandler] = None,
 ) -> None:
     if log_status_handler:
-        log_status_handler.cloud_diff_started(diff_vars.dev_path[-1])
+        log_status_handler.diff_started(diff_vars.dev_path[-1])
     diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
     payload = TCloudApiDataDiff(
         data_source1_id=datasource_id,
@@ -476,7 +487,7 @@ def _cloud_diff(
         rich.print(diff_output_str)
 
         if log_status_handler:
-            log_status_handler.cloud_diff_finished(diff_vars.dev_path[-1])
+            log_status_handler.diff_finished(diff_vars.dev_path[-1])
     except BaseException as ex:  # Catch KeyboardInterrupt too
         error = ex
     finally:
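_local_diff now reports its own start and finish through the shared status handler, mirroring what _cloud_diff already did, so the per-model board covers both execution paths. A stand-in sketch of the hook placement; all names here are illustrative:

# Sketch: a worker that reports progress only when a handler is wired in.
from typing import Optional

class StatusBoard:  # stand-in for the real LogStatusHandler
    def diff_started(self, model):
        print(f"started {model}")

    def diff_finished(self, model):
        print(f"finished {model}")

def local_diff(model: str, status: Optional[StatusBoard] = None) -> None:
    if status:
        status.diff_started(model)
    # ... run the actual table comparison here ...
    if status:
        status.diff_finished(model)

local_diff("orders", StatusBoard())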

data_diff/dbt_parser.py (+3 −3)

@@ -446,17 +446,17 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str
 
             from_meta = [name for name, params in node.columns.items() if pk_tag in params.meta] or None
             if from_meta:
-                logger.debug("Found PKs via META: " + str(from_meta))
+                logger.debug(f"Found PKs via META [{node.name}]: " + str(from_meta))
                 return from_meta
 
             from_tags = [name for name, params in node.columns.items() if pk_tag in params.tags] or None
             if from_tags:
-                logger.debug("Found PKs via Tags: " + str(from_tags))
+                logger.debug(f"Found PKs via Tags [{node.name}]: " + str(from_tags))
                 return from_tags
             if node.unique_id in unique_columns:
                 from_uniq = unique_columns.get(node.unique_id)
                 if from_uniq is not None:
-                    logger.debug("Found PKs via Uniqueness tests: " + str(from_uniq))
+                    logger.debug(f"Found PKs via Uniqueness tests [{node.name}]: {str(from_uniq)}")
                     return list(from_uniq)
 
         except (KeyError, IndexError, TypeError) as e:

data_diff/joindiff_tables.py (+42 −19)

@@ -162,7 +162,7 @@ def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tre
             yield from self._diff_segments(None, table1, table2, info_tree, None)
         else:
             yield from self._bisect_and_diff_tables(table1, table2, info_tree)
-            logger.info("Diffing complete")
+            logger.info(f"Diffing complete: {table1.table_path} <> {table2.table_path}")
         if self.materialize_to_table:
             logger.info("Materialized diff to table '%s'.", ".".join(self.materialize_to_table))
 
@@ -193,8 +193,8 @@ def _diff_segments(
             partial(self._collect_stats, 1, table1, info_tree),
             partial(self._collect_stats, 2, table2, info_tree),
             partial(self._test_null_keys, table1, table2),
-            partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols),
-            partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols),
+            partial(self._sample_and_count_exclusive, db, diff_rows, a_cols, b_cols, table1, table2),
+            partial(self._count_diff_per_column, db, diff_rows, list(a_cols), is_diff_cols, table1, table2),
             partial(
                 self._materialize_diff,
                 db,
@@ -205,8 +205,8 @@ def _diff_segments(
             else None,
         ):
             assert len(a_cols) == len(b_cols)
-            logger.debug("Querying for different rows")
-            diff = db.query(diff_rows, list)
+            logger.debug(f"Querying for different rows: {table1.table_path}")
+            diff = db.query(diff_rows, list, log_message=table1.table_path)
             info_tree.info.set_diff(diff, schema=tuple(diff_rows.schema.items()))
             for is_xa, is_xb, *x in diff:
                 if is_xa and is_xb:
@@ -227,7 +227,7 @@ def _diff_segments(
                 yield "+", tuple(b_row)
 
     def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
-        logger.debug("Testing for duplicate keys")
+        logger.debug(f"Testing for duplicate keys: {table1.table_path} <> {table2.table_path}")
 
         # Test duplicate keys
         for ts in [table1, table2]:
@@ -240,24 +240,24 @@ def _test_duplicate_keys(self, table1: TableSegment, table2: TableSegment):
 
             unvalidated = list(set(key_columns) - set(unique))
             if unvalidated:
-                logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated}")
+                logger.info(f"Validating that the are no duplicate keys in columns: {unvalidated} for {ts.table_path}")
                 # Validate that there are no duplicate keys
                 self.stats["validated_unique_keys"] = self.stats.get("validated_unique_keys", []) + [unvalidated]
                 q = t.select(total=Count(), total_distinct=Count(Concat(this[unvalidated]), distinct=True))
-                total, total_distinct = ts.database.query(q, tuple)
+                total, total_distinct = ts.database.query(q, tuple, log_message=ts.table_path)
                 if total != total_distinct:
                     raise ValueError("Duplicate primary keys")
 
     def _test_null_keys(self, table1, table2):
-        logger.debug("Testing for null keys")
+        logger.debug(f"Testing for null keys: {table1.table_path} <> {table2.table_path}")
 
         # Test null keys
         for ts in [table1, table2]:
             t = ts.make_select()
             key_columns = ts.key_columns
 
             q = t.select(*this[key_columns]).where(or_(this[k] == None for k in key_columns))
-            nulls = ts.database.query(q, list)
+            nulls = ts.database.query(q, list, log_message=ts.table_path)
             if nulls:
                 if self.skip_null_keys:
                     logger.warning(
@@ -267,7 +267,7 @@ def _test_null_keys(self, table1, table2):
                     raise ValueError(f"NULL values in one or more primary keys of {ts.table_path}")
 
     def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
-        logger.debug(f"Collecting stats for table #{i}")
+        logger.debug(f"Collecting stats for table #{i}: {table_seg.table_path}")
         db = table_seg.database
 
         # Metrics
@@ -288,7 +288,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
         )
         col_exprs["count"] = Count()
 
-        res = db.query(table_seg.make_select().select(**col_exprs), tuple)
+        res = db.query(table_seg.make_select().select(**col_exprs), tuple, log_message=table_seg.table_path)
 
         for col_name, value in safezip(col_exprs, res):
             if value is not None:
@@ -303,7 +303,7 @@ def _collect_stats(self, i, table_seg: TableSegment, info_tree: InfoTree):
             else:
                 self.stats[stat_name] = value
 
-        logger.debug("Done collecting stats for table #%s", i)
+        logger.debug("Done collecting stats for table #%s: %s", i, table_seg.table_path)
 
     def _create_outer_join(self, table1, table2):
         db = table1.database
@@ -334,23 +334,46 @@ def _create_outer_join(self, table1, table2):
         diff_rows = all_rows.where(or_(this[c] == 1 for c in is_diff_cols))
         return diff_rows, a_cols, b_cols, is_diff_cols, all_rows
 
-    def _count_diff_per_column(self, db, diff_rows, cols, is_diff_cols):
-        logger.debug("Counting differences per column")
-        is_diff_cols_counts = db.query(diff_rows.select(sum_(this[c]) for c in is_diff_cols), tuple)
+    def _count_diff_per_column(
+        self,
+        db,
+        diff_rows,
+        cols,
+        is_diff_cols,
+        table1: Optional[TableSegment] = None,
+        table2: Optional[TableSegment] = None,
+    ):
+        logger.info(type(table1))
+        logger.debug(f"Counting differences per column: {table1.table_path} <> {table2.table_path}")
+        is_diff_cols_counts = db.query(
+            diff_rows.select(sum_(this[c]) for c in is_diff_cols),
+            tuple,
+            log_message=f"{table1.table_path} <> {table2.table_path}",
+        )
         diff_counts = {}
         for name, count in safezip(cols, is_diff_cols_counts):
             diff_counts[name] = diff_counts.get(name, 0) + (count or 0)
         self.stats["diff_counts"] = diff_counts
 
-    def _sample_and_count_exclusive(self, db, diff_rows, a_cols, b_cols):
+    def _sample_and_count_exclusive(
+        self,
+        db,
+        diff_rows,
+        a_cols,
+        b_cols,
+        table1: Optional[TableSegment] = None,
+        table2: Optional[TableSegment] = None,
+    ):
         if isinstance(db, (Oracle, MsSQL)):
             exclusive_rows_query = diff_rows.where((this.is_exclusive_a == 1) | (this.is_exclusive_b == 1))
         else:
             exclusive_rows_query = diff_rows.where(this.is_exclusive_a | this.is_exclusive_b)
 
         if not self.sample_exclusive_rows:
-            logger.debug("Counting exclusive rows")
-            self.stats["exclusive_count"] = db.query(exclusive_rows_query.count(), int)
+            logger.debug(f"Counting exclusive rows: {table1.table_path} <> {table2.table_path}")
+            self.stats["exclusive_count"] = db.query(
+                exclusive_rows_query.count(), int, log_message=f"{table1.table_path} <> {table2.table_path}"
+            )
            return
 
         logger.info("Counting and sampling exclusive rows")

data_diff/utils.py (+14 −14)

@@ -485,31 +485,31 @@ def __init__(self):
         super().__init__()
         self.status = Status("")
         self.prefix = ""
-        self.cloud_diff_status = {}
+        self.diff_status = {}
 
     def emit(self, record):
         log_entry = self.format(record)
-        if self.cloud_diff_status:
-            self._update_cloud_status(log_entry)
+        if self.diff_status:
+            self._update_diff_status(log_entry)
         else:
             self.status.update(self.prefix + log_entry)
 
     def set_prefix(self, prefix_string):
         self.prefix = prefix_string
 
-    def cloud_diff_started(self, model_name):
-        self.cloud_diff_status[model_name] = "[yellow]In Progress[/]"
-        self._update_cloud_status()
+    def diff_started(self, model_name):
+        self.diff_status[model_name] = "[yellow]In Progress[/]"
+        self._update_diff_status()
 
-    def cloud_diff_finished(self, model_name):
-        self.cloud_diff_status[model_name] = "[green]Finished [/]"
-        self._update_cloud_status()
+    def diff_finished(self, model_name):
+        self.diff_status[model_name] = "[green]Finished [/]"
+        self._update_diff_status()
 
-    def _update_cloud_status(self, log=None):
-        cloud_status_string = "\n"
-        for model_name, status in self.cloud_diff_status.items():
-            cloud_status_string += f"{status} {model_name}\n"
-        self.status.update(f"{cloud_status_string}{log or ''}")
+    def _update_diff_status(self, log=None):
+        status_string = "\n"
+        for model_name, status in self.diff_status.items():
+            status_string += f"{status} {model_name}\n"
+        self.status.update(f"{status_string}{log or ''}")
 
 
 class UnknownMeta(type):
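With the cloud_ prefix dropped, one handler tracks every in-flight model regardless of whether it diffs locally or in the cloud. A hedged usage sketch, assuming LogStatusHandler is importable from data_diff.utils as the diff suggests:

# Hypothetical usage of the renamed API; the same hooks now drive the
# status board for both local and cloud diffs.
from data_diff.utils import LogStatusHandler

handler = LogStatusHandler()
handler.diff_started("orders")     # board shows: [yellow]In Progress[/] orders
handler.diff_started("customers")  # a second row appears
handler.diff_finished("orders")    # orders flips to [green]Finished [/]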
