diff --git a/mario/hyper_utils.py b/mario/hyper_utils.py index 9b24f47..8c737dd 100644 --- a/mario/hyper_utils.py +++ b/mario/hyper_utils.py @@ -347,8 +347,6 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs): # Get column names column_names = ','.join(f'"{column}"' for column in columns) - sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\" ORDER BY row_number" - offset = 0 if options.use_pantab: # Use pantab to stream hyper to csv @@ -356,6 +354,8 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs): mode = 'w' header = True + sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\" ORDER BY row_number" + offset = 0 while True: query = f"{sql} LIMIT {options.chunk_size} OFFSET {offset}" @@ -375,22 +375,30 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs): with HyperProcess(Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, 'test') as hyper: with Connection(endpoint=hyper.endpoint, database=temp_hyper) as connection: + # Get min and max value from row_number + sql_range = f"SELECT MIN(row_number), MAX(row_number) FROM \"{schema}\".\"{table}\"" + value_range = connection.execute_query(sql_range) + start_val, end_val = list(value_range)[0] + logging.info(f"Get row_number range: [{str(start_val)}, {str(end_val)}]") + + sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\"" + with open_func(file_path, mode, newline='', encoding="utf-8") as f: writer = csv.writer(f) # write header writer.writerow(columns) - while True: - query = f"{sql} LIMIT {options.chunk_size} OFFSET {offset}" + while start_val <= end_val: + chunk_end = min(start_val+options.chunk_size-1, end_val) + query = f"{sql} WHERE row_number BETWEEN {start_val} AND {chunk_end}" + logging.info(f"Query between {start_val} and {chunk_end}") + result = connection.execute_query(query) rows = list(result) - if not rows: - break - writer.writerows(rows) - offset += options.chunk_size + start_val += options.chunk_size def save_dataframe_as_hyper(df, file_path, **kwargs): diff --git a/setup.py b/setup.py index 1ba179a..e989d3a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='mario-pipeline-tools', - version='0.59', + version='0.60', packages=['mario'], url='https://github.com/JiscDACT/mario', license='all rights reserved', diff --git a/test/test_data_extractor.py b/test/test_data_extractor.py index e5e8587..0e146cf 100644 --- a/test/test_data_extractor.py +++ b/test/test_data_extractor.py @@ -553,7 +553,7 @@ def test_hyper_to_csv(): compress_using_gzip=False ) assert extractor.get_total() == 10194 - assert extractor.get_total(measure='Sales') == 2326534.354299952 + assert round(extractor.get_total(measure='Sales'), 2) == 2326534.35 df = pd.read_csv(output_file) assert round(df['Sales'].sum(), 4) == 2326534.3543 @@ -585,7 +585,7 @@ def test_hyper_to_csv_without_copy_to_tmp(): do_not_modify_source=False ) assert extractor.get_total() == 10194 - assert extractor.get_total(measure='Sales') == 2326534.3542999607 + assert round(extractor.get_total(measure='Sales'), 2) == 2326534.35 df = pd.read_csv(output_file) assert round(df['Sales'].sum(), 4) == 2326534.3543 @@ -618,7 +618,7 @@ def test_hyper_to_csv_without_using_pantab(): use_pantab=False ) assert extractor.get_total() == 10194 - assert extractor.get_total(measure='Sales') == 2326534.3542999607 + assert round(extractor.get_total(measure='Sales'), 2) == 2326534.35 df = pd.read_csv(output_file) assert round(df['Sales'].sum(), 4) == 2326534.3543