diff --git a/mario/hyper_utils.py b/mario/hyper_utils.py
index 47f5ba2..3dc6ad6 100644
--- a/mario/hyper_utils.py
+++ b/mario/hyper_utils.py
@@ -290,6 +290,9 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs):
     import tempfile
     import shutil
     import os
+    import csv
+    import gzip
+    from tableauhyperapi import HyperProcess, Telemetry, Connection
 
     options = CsvOptions(**kwargs)
 
@@ -319,7 +322,7 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs):
             input_hyper_file_path=temp_hyper,
             schema=schema,
             table=table
-            )
+        )
     else:
         log.debug("Data source already contains row numbers")
 
@@ -331,29 +334,64 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs):
     if options.compress_using_gzip:
         compression_options = dict(method='gzip')
         file_path = file_path + '.gz'
+        open_func = gzip.open
+        mode = "wt"
     elif file_path.endswith('.gz'):
         compression_options = dict(method='gzip')
+        open_func = gzip.open
+        mode = "wt"
     else:
         compression_options = None
+        open_func = open
+        mode = "w"
 
-    mode = 'w'
-    header = True
-    offset = 0
+    # Get column names and build the streaming query
     column_names = ','.join(f'"{column}"' for column in columns)
     sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\" ORDER BY row_number"
+    offset = 0
 
-    while True:
-        query = f"{sql} LIMIT {options.chunk_size} OFFSET {offset}"
-        df_chunk = pantab.frame_from_hyper_query(temp_hyper, query)
-        if df_chunk.empty:
-            break
-        df_chunk.to_csv(file_path, index=False, mode=mode, header=header,
-                        compression=compression_options)
-        offset += options.chunk_size
-        if header:
-            header = False
-            mode = "a"
+    if options.use_pantab:
+        # Use pantab to stream the hyper extract to CSV in chunks
+        mode = 'w'
+        header = True
+
+        while True:
+            query = f"{sql} LIMIT {options.chunk_size} OFFSET {offset}"
+            df_chunk = pantab.frame_from_hyper_query(temp_hyper, query)
+            if df_chunk.empty:
+                break
+            df_chunk.to_csv(file_path, index=False, mode=mode, header=header,
+                            compression=compression_options)
+            offset += options.chunk_size
+            if header:
+                header = False
+                mode = "a"
+    else:
+        # Use the Tableau Hyper API to stream data to CSV
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hyper:
+            with Connection(endpoint=hyper.endpoint, database=temp_hyper) as connection:
+                # Stream rows with an iterator cursor; the context managers
+                # close the result and the file when iteration finishes
+                with connection.execute_query(sql) as result, \
+                        open_func(file_path, mode, newline='', encoding="utf-8") as f:
+                    writer = csv.writer(f)
+                    # Write the header row
+                    writer.writerow(columns)
+
+                    # Buffer rows and flush in chunks to limit memory use
+                    buffer = []
+                    for row in result:
+                        buffer.append(row)
+                        if len(buffer) >= options.chunk_size:
+                            writer.writerows(buffer)
+                            buffer.clear()
+                            offset += options.chunk_size
+
+                    # Write any remaining buffered rows
+                    if buffer:
+                        writer.writerows(buffer)
+                        offset += len(buffer)
 
 
 def save_dataframe_as_hyper(df, file_path, **kwargs):
diff --git a/mario/options.py b/mario/options.py
index 97703dc..6ac5d6b 100644
--- a/mario/options.py
+++ b/mario/options.py
@@ -28,6 +28,7 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.compress_using_gzip = kwargs.get('compress_using_gzip', False)
         self.do_not_modify_source = kwargs.get('do_not_modify_source', True)
+        self.use_pantab = kwargs.get('use_pantab', True)
 
 
 class HyperOptions(OutputOptions):
diff --git a/setup.py b/setup.py
index 349cf42..be172f2 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 setup(
     name='mario-pipeline-tools',
-    version='0.57',
+    version='0.58',
     packages=['mario'],
     url='https://github.com/JiscDACT/mario',
     license='all rights reserved',
diff --git a/test/test_data_extractor.py b/test/test_data_extractor.py
index 01acd93..e5e8587 100644
--- a/test/test_data_extractor.py
+++ b/test/test_data_extractor.py
@@ -590,6 +590,38 @@ def test_hyper_to_csv_without_copy_to_tmp():
     df = pd.read_csv(output_file)
     assert round(df['Sales'].sum(), 4) == 2326534.3543
 
+def test_hyper_to_csv_without_using_pantab():
+    dataset = dataset_from_json(os.path.join('test', 'dataset.json'))
+    dataset.measures = []
+    metadata = metadata_from_json(os.path.join('test', 'metadata.json'))
+    # Copy the source data to avoid overwriting during other pytest runs
+    shutil.copyfile(
+        src=os.path.join('test', 'orders.hyper'),
+        dst=os.path.join('test', 'orders_pantab_test.hyper')
+    )
+    configuration = Configuration(
+        file_path=os.path.join('test', 'orders_pantab_test.hyper')
+    )
+    extractor = HyperFile(
+        dataset_specification=dataset,
+        metadata=metadata,
+        configuration=configuration
+    )
+    output_folder = os.path.join('output', 'test_hyper_to_csv')
+    os.makedirs(output_folder, exist_ok=True)
+    output_file = os.path.join(output_folder, 'orders_without_pantab.csv')
+    extractor.save_data_as_csv(
+        file_path=output_file,
+        minimise=False,
+        compress_using_gzip=False,
+        do_not_modify_source=False,
+        use_pantab=False
+    )
+    assert extractor.get_total() == 10194
+    assert extractor.get_total(measure='Sales') == 2326534.3542999607
+
+    df = pd.read_csv(output_file)
+    assert round(df['Sales'].sum(), 4) == 2326534.3543
 
 def test_partitioning_extractor_partition_sql_no_data_in_partition():
     # Skip this test if we don't have a connection string
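Reviewer note: a minimal usage sketch of the new flag, not part of the patch. The file names are placeholders, and `save_hyper_as_csv` is assumed to discover the schema and table inside the extract as it does today.

```python
from mario.hyper_utils import save_hyper_as_csv

# Default behaviour is unchanged: pantab reads the extract in
# LIMIT/OFFSET chunks and appends each chunk to the CSV.
save_hyper_as_csv('orders.hyper', 'orders.csv')

# New in 0.58: stream rows through the Tableau Hyper API instead,
# avoiding the pandas round-trip for each chunk.
save_hyper_as_csv('orders.hyper', 'orders.csv', use_pantab=False)

# Both paths honour gzip output; '.gz' is appended to the file name.
save_hyper_as_csv('orders.hyper', 'orders.csv',
                  use_pantab=False, compress_using_gzip=True)
```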