diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index 7c301dd..a56f194 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -14,9 +14,13 @@ jobs: - uses: actions/checkout@v3 - name: install run: | - python -m venv env/ - source env/bin/activate - python -m pip install -r requirements.txt + python -m venv env + # Upgrade wheel & setuptools first + env/bin/python -m pip install --upgrade wheel setuptools + # Install patched pip from GitHub commit to fix tarfile vulnerability + env/bin/python -m pip install "git+https://github.com/pypa/pip@f2b9231" + # Install project dependencies + env/bin/python -m pip install -r requirements.txt - uses: pypa/gh-action-pip-audit@v1.0.7 with: virtual-environment: env/ diff --git a/mario/hyper_utils.py b/mario/hyper_utils.py index 708a81e..47f5ba2 100644 --- a/mario/hyper_utils.py +++ b/mario/hyper_utils.py @@ -294,11 +294,16 @@ def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs): options = CsvOptions(**kwargs) with tempfile.TemporaryDirectory() as temp_dir: - temp_hyper = os.path.join(temp_dir, 'temp.hyper') - shutil.copyfile( - src=hyper_file, - dst=temp_hyper - ) + if options.do_not_modify_source: + logging.info('Copy the source hyper file into the temp directory.') + temp_hyper = os.path.join(temp_dir, 'temp.hyper') + shutil.copyfile( + src=hyper_file, + dst=temp_hyper + ) + else: + logging.info('Use the source hyper file directly without creating a temp copy.') + temp_hyper = hyper_file schema, table = get_default_table_and_schema(temp_hyper) diff --git a/mario/options.py b/mario/options.py index 7d06c6f..97703dc 100644 --- a/mario/options.py +++ b/mario/options.py @@ -27,6 +27,7 @@ class CsvOptions(OutputOptions): def __init__(self, **kwargs): super().__init__(**kwargs) self.compress_using_gzip = kwargs.get('compress_using_gzip', False) + self.do_not_modify_source = kwargs.get('do_not_modify_source', True) class HyperOptions(OutputOptions): diff --git a/test/test_data_extractor.py b/test/test_data_extractor.py index d41870f..01acd93 100644 --- a/test/test_data_extractor.py +++ b/test/test_data_extractor.py @@ -558,6 +558,38 @@ def test_hyper_to_csv(): df = pd.read_csv(output_file) assert round(df['Sales'].sum(), 4) == 2326534.3543 +def test_hyper_to_csv_without_copy_to_tmp(): + dataset = dataset_from_json(os.path.join('test', 'dataset.json')) + dataset.measures = [] + metadata = metadata_from_json(os.path.join('test', 'metadata.json')) + # Copy the source data to avoid overwriting during other pytest runs + shutil.copyfile( + src=os.path.join('test', 'orders.hyper'), + dst=os.path.join('test', 'orders_copy.hyper') + ) + configuration = Configuration( + file_path=os.path.join('test', 'orders_copy.hyper') + ) + extractor = HyperFile( + dataset_specification=dataset, + metadata=metadata, + configuration=configuration + ) + output_folder = os.path.join('output', 'test_hyper_to_csv') + os.makedirs(output_folder, exist_ok=True) + output_file = os.path.join(output_folder, 'orders_without_temp.csv') + extractor.save_data_as_csv( + file_path=output_file, + minimise=False, + compress_using_gzip=False, + do_not_modify_source=False + ) + assert extractor.get_total() == 10194 + assert extractor.get_total(measure='Sales') == 2326534.3542999607 + + df = pd.read_csv(output_file) + assert round(df['Sales'].sum(), 4) == 2326534.3543 + def test_partitioning_extractor_partition_sql_no_data_in_partition(): # Skip this test if we don't have a connection string