From 701e04912d3bda633fc18938cb1b5320f70a8b9e Mon Sep 17 00:00:00 2001 From: hrobarts Date: Fri, 16 Feb 2024 12:14:44 +0000 Subject: [PATCH 1/7] Add Korn and Walnut data examples --- Wrappers/Python/cil/utilities/dataexample.py | 58 +++++++++++++++++++- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index b922332e1a..72414e70c8 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -24,7 +24,9 @@ import os import os.path import sys -from cil.io import NEXUSDataReader +from zipfile import ZipFile +from urllib.request import urlretrieve +from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader data_dir = os.path.abspath(os.path.join( os.path.dirname(__file__), @@ -158,6 +160,58 @@ def get(cls, **kwargs): loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) return loader.read() + +class WALNUT(DATA): + @classmethod + def get(cls, **kwargs): + + ddir = kwargs.get('data_dir', data_dir) + cls.retrieve_data(**kwargs) + loader = ZEISSDataReader(file_name=ddir+'/valnut/valnut_2014-03-21_643_28/tomo-A/valnut_tomo-A.txrm') + return loader.read() + + def retrieve_data(**kwargs): + ddir = kwargs.get('data_dir', data_dir) + if os.path.isdir(ddir+'/valnut') == False: + print('Downloading Walnut dataset to ' + ddir) + urlretrieve('https://zenodo.org/record/4822516/files/walnut.zip', ddir + '/walnut.zip') + myzip = ZipFile(ddir+'/walnut.zip', 'r') + myzip.extractall(path=ddir) + os.remove(ddir+"/walnut.zip") + print("Complete") + else: + print('Data folder exists at ' + ddir +'/valnut') + + return ddir + +class KORN(DATA): + @classmethod + def get(cls, **kwargs): + + ddir = kwargs.get('data_dir', data_dir) + cls.retrieve_data(**kwargs) + loader = NikonDataReader(file_name=ddir+'/Korn i kasse/47209 testscan korn01_recon.xtekct') + return loader.read() + + def retrieve_data(**kwargs): + ddir = kwargs.get('data_dir', data_dir) + if os.path.isdir(ddir+'/Korn i kasse') == False: + print('Downloading Korn dataset to ' + ddir) + urlretrieve('https://zenodo.org/record/6874123/files/korn.zip', ddir + '/korn.zip') + myzip = ZipFile(ddir+'/korn.zip', 'r') + myzip.extractall(path=ddir) + os.remove(ddir+"/korn.zip") + print("Complete") + else: + print('Data folder exists at ' + ddir +'/Korn i kasse') + + return ddir + +class RemoteData(object): + def __init__(self, **kwargs): + self.data_dir = kwargs.get('data_dir', data_dir) + + class TestData(object): '''Class to return test data @@ -507,4 +561,4 @@ def scikit_random_noise(image, mode='gaussian', seed=None, clip=True, **kwargs): if clip: out = np.clip(out, low_clip, 1.0) - return out + return out \ No newline at end of file From 3ba33a3a1bd930ac595e671990b6cbc3784957cf Mon Sep 17 00:00:00 2001 From: hrobarts Date: Fri, 16 Feb 2024 17:50:08 +0000 Subject: [PATCH 2/7] Add internal data class, require data_dir for external --- Wrappers/Python/cil/utilities/dataexample.py | 138 +++++++++---------- 1 file changed, 66 insertions(+), 72 deletions(-) diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index 72414e70c8..f29ebd11aa 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -25,58 +25,57 @@ import os.path import sys from zipfile import ZipFile -from urllib.request import urlretrieve +from urllib.request import urlopen +from io import BytesIO from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader -data_dir = os.path.abspath(os.path.join( - os.path.dirname(__file__), - '../data/') -) - -# this is the default location after a conda install -data_dir = os.path.abspath( - os.path.join(sys.prefix, 'share','cil') -) +# # this is the default location after a conda install +# data_dir = os.path.abspath( +# os.path.join(sys.prefix, 'share','cil') +# ) class DATA(object): @classmethod def dfile(cls): return None + +class INTERNALDATA(DATA): + data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil')) @classmethod def get(cls, size=None, scale=(0,1), **kwargs): - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) loader = TestData(data_dir=ddir) return loader.load(cls.dfile(), size, scale, **kwargs) -class BOAT(DATA): +class BOAT(INTERNALDATA): @classmethod def dfile(cls): return TestData.BOAT -class CAMERA(DATA): +class CAMERA(INTERNALDATA): @classmethod def dfile(cls): return TestData.CAMERA -class PEPPERS(DATA): +class PEPPERS(INTERNALDATA): @classmethod def dfile(cls): return TestData.PEPPERS -class RESOLUTION_CHART(DATA): +class RESOLUTION_CHART(INTERNALDATA): @classmethod def dfile(cls): return TestData.RESOLUTION_CHART -class SIMPLE_PHANTOM_2D(DATA): +class SIMPLE_PHANTOM_2D(INTERNALDATA): @classmethod def dfile(cls): return TestData.SIMPLE_PHANTOM_2D -class SHAPES(DATA): +class SHAPES(INTERNALDATA): @classmethod def dfile(cls): return TestData.SHAPES -class RAINBOW(DATA): +class RAINBOW(INTERNALDATA): @classmethod def dfile(cls): return TestData.RAINBOW -class SYNCHROTRON_PARALLEL_BEAM_DATA(DATA): +class SYNCHROTRON_PARALLEL_BEAM_DATA(INTERNALDATA): @classmethod def get(cls, **kwargs): ''' @@ -93,11 +92,11 @@ def get(cls, **kwargs): The DLS dataset ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs')) return loader.read() -class SIMULATED_PARALLEL_BEAM_DATA(DATA): +class SIMULATED_PARALLEL_BEAM_DATA(INTERNALDATA): @classmethod def get(cls, **kwargs): ''' @@ -114,11 +113,11 @@ def get(cls, **kwargs): The simulated spheres dataset ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs')) return loader.read() -class SIMULATED_CONE_BEAM_DATA(DATA): +class SIMULATED_CONE_BEAM_DATA(INTERNALDATA): @classmethod def get(cls, **kwargs): ''' @@ -135,11 +134,11 @@ def get(cls, **kwargs): The simulated spheres dataset ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs')) return loader.read() -class SIMULATED_SPHERE_VOLUME(DATA): +class SIMULATED_SPHERE_VOLUME(INTERNALDATA): @classmethod def get(cls, **kwargs): ''' @@ -156,63 +155,58 @@ def get(cls, **kwargs): The simulated spheres volume ''' - ddir = kwargs.get('data_dir', data_dir) + ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) return loader.read() class WALNUT(DATA): @classmethod - def get(cls, **kwargs): - - ddir = kwargs.get('data_dir', data_dir) - cls.retrieve_data(**kwargs) - loader = ZEISSDataReader(file_name=ddir+'/valnut/valnut_2014-03-21_643_28/tomo-A/valnut_tomo-A.txrm') - return loader.read() + def get(cls, data_dir): + WALNUT = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + try: + loader = ZEISSDataReader(file_name=os.path.join(data_dir,WALNUT)) + return loader.read() + except(FileNotFoundError): + raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__))) - def retrieve_data(**kwargs): - ddir = kwargs.get('data_dir', data_dir) - if os.path.isdir(ddir+'/valnut') == False: - print('Downloading Walnut dataset to ' + ddir) - urlretrieve('https://zenodo.org/record/4822516/files/walnut.zip', ddir + '/walnut.zip') - myzip = ZipFile(ddir+'/walnut.zip', 'r') - myzip.extractall(path=ddir) - os.remove(ddir+"/walnut.zip") - print("Complete") - else: - print('Data folder exists at ' + ddir +'/valnut') - - return ddir + def download_data(data_dir): + zip_url = 'https://zenodo.org/record/4822516/files/walnut.zip' + if input("Are you sure you want to download the dataset from " + zip_url + " ? (y/n)") == "y": + print('Downloading Walnut dataset to ' + data_dir) + with urlopen(zip_url) as response: + with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: + zipfile.extractall(path = data_dir) + print("Download complete") class KORN(DATA): @classmethod - def get(cls, **kwargs): - - ddir = kwargs.get('data_dir', data_dir) - cls.retrieve_data(**kwargs) - loader = NikonDataReader(file_name=ddir+'/Korn i kasse/47209 testscan korn01_recon.xtekct') - return loader.read() - - def retrieve_data(**kwargs): - ddir = kwargs.get('data_dir', data_dir) - if os.path.isdir(ddir+'/Korn i kasse') == False: - print('Downloading Korn dataset to ' + ddir) - urlretrieve('https://zenodo.org/record/6874123/files/korn.zip', ddir + '/korn.zip') - myzip = ZipFile(ddir+'/korn.zip', 'r') - myzip.extractall(path=ddir) - os.remove(ddir+"/korn.zip") - print("Complete") - else: - print('Data folder exists at ' + ddir +'/Korn i kasse') - - return ddir + def get(cls, data_dir): + KORN = os.path.join('Korn i kasse','47209 testscan korn01_recon.xtekct') + try: + loader = ZEISSDataReader(file_name=os.path.join(data_dir,KORN)) + return loader.read() + except(FileNotFoundError): + raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__))) -class RemoteData(object): - def __init__(self, **kwargs): - self.data_dir = kwargs.get('data_dir', data_dir) + def download_data(data_dir): + zip_url = 'https://zenodo.org/record/6874123/files/korn.zip' + if input("Are you sure you want to download the dataset from " + zip_url + " ? (y/n)") == "y": + print('Downloading Walnut dataset to ' + data_dir) + with urlopen(zip_url) as response: + with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: + zipfile.extractall(path = data_dir) + print("Download complete") + +# class RemoteData(object): +# WALNUT_URL = 'https://zenodo.org/record/4822516/files/walnut.zip' +# WALNUT = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + +# KORN_URL = 'https://zenodo.org/record/6874123/files/korn.zip' +# KORN = '/Korn i kasse/47209 testscan korn01_recon.xtekct' - - class TestData(object): '''Class to return test data @@ -233,8 +227,8 @@ class TestData(object): SHAPES = 'shapes.png' RAINBOW = 'rainbow.png' - def __init__(self, **kwargs): - self.data_dir = kwargs.get('data_dir', data_dir) + def __init__(self, data_dir): + self.data_dir = data_dir def load(self, which, size=None, scale=(0,1), **kwargs): ''' From 6a9d77513439aea22473e86ee78a755e1a30e42c Mon Sep 17 00:00:00 2001 From: hrobarts Date: Mon, 19 Feb 2024 16:36:56 +0000 Subject: [PATCH 3/7] Update filepath in test_io --- Wrappers/Python/test/test_io.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Wrappers/Python/test/test_io.py b/Wrappers/Python/test/test_io.py index 41c89e8c11..5b5d3fa92f 100644 --- a/Wrappers/Python/test/test_io.py +++ b/Wrappers/Python/test/test_io.py @@ -23,13 +23,14 @@ from cil.framework import AcquisitionGeometry import numpy as np import os +import sys from cil.framework import ImageGeometry from cil.io import TXRMDataReader, NEXUSDataReader, NikonDataReader, ZEISSDataReader from cil.io import TIFFWriter, TIFFStackReader from cil.io.utilities import HDF5_utilities from cil.processors import Slicer from utils import has_astra, has_nvidia -from cil.utilities.dataexample import data_dir +import cil.utilities.dataexample from cil.utilities.quality_measures import mse from cil.utilities import dataexample import shutil @@ -66,6 +67,10 @@ # change basedir to point to the location of the walnut dataset which can # be downloaded from https://zenodo.org/record/4822516 # basedir = os.path.abspath('/home/edo/scratch/Data/Walnut/valnut_2014-03-21_643_28/tomo-A/') + +data_dir = os.path.abspath( + os.path.join(sys.prefix, 'share','cil') +) basedir = data_dir filename = os.path.join(basedir, "valnut_tomo-A.txrm") has_file = os.path.isfile(filename) From 13f3a70b61902069046a6e6cbf5261450537eded Mon Sep 17 00:00:00 2001 From: hrobarts Date: Tue, 20 Feb 2024 17:30:04 +0000 Subject: [PATCH 4/7] Reorganise and add unit tests with mock http response --- Wrappers/Python/cil/utilities/dataexample.py | 69 +++++++++++--------- Wrappers/Python/test/test_dataexample.py | 66 +++++++++++++++++++ 2 files changed, 103 insertions(+), 32 deletions(-) diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index f29ebd11aa..277d0e0607 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -46,6 +46,28 @@ def get(cls, size=None, scale=(0,1), **kwargs): ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) loader = TestData(data_dir=ddir) return loader.load(cls.dfile(), size, scale, **kwargs) + +class REMOTEDATA(DATA): + PATH = '' + URL = '' + + @classmethod + def download_from_url(cls, data_dir): + with urlopen(cls.URL) as response: + with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: + zipfile.extractall(path = data_dir) + + @classmethod + def download_data(cls, data_dir): + if os.path.isfile(os.path.join(data_dir, cls.PATH)): + print("Dataset already exists in " + data_dir) + else: + if input("Are you sure you want to download the dataset from " + cls.URL + " ? (y/n)") == "y": + print('Downloading dataset from ' + cls.URL) + cls.download_from_url(data_dir) + print('Download complete') + else: + print('Download cancelled') class BOAT(INTERNALDATA): @classmethod @@ -160,52 +182,35 @@ def get(cls, **kwargs): loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) return loader.read() -class WALNUT(DATA): +class WALNUT(REMOTEDATA): + + PATH = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + URL = 'https://zenodo.org/record/4822516/files/walnut.zip' + @classmethod def get(cls, data_dir): - WALNUT = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') try: - loader = ZEISSDataReader(file_name=os.path.join(data_dir,WALNUT)) + loader = ZEISSDataReader(file_name=os.path.join(data_dir,cls.PATH)) return loader.read() except(FileNotFoundError): raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \ Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__))) - - def download_data(data_dir): - zip_url = 'https://zenodo.org/record/4822516/files/walnut.zip' - if input("Are you sure you want to download the dataset from " + zip_url + " ? (y/n)") == "y": - print('Downloading Walnut dataset to ' + data_dir) - with urlopen(zip_url) as response: - with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: - zipfile.extractall(path = data_dir) - print("Download complete") - -class KORN(DATA): + + + +class KORN(REMOTEDATA): + PATH = os.path.join('Korn i kasse','47209 testscan korn01_recon.xtekct') + URL = 'https://zenodo.org/record/6874123/files/korn.zip' + @classmethod def get(cls, data_dir): - KORN = os.path.join('Korn i kasse','47209 testscan korn01_recon.xtekct') try: - loader = ZEISSDataReader(file_name=os.path.join(data_dir,KORN)) + loader = NikonDataReader(file_name=os.path.join(data_dir, cls.PATH)) return loader.read() except(FileNotFoundError): raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \ Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__))) - - def download_data(data_dir): - zip_url = 'https://zenodo.org/record/6874123/files/korn.zip' - if input("Are you sure you want to download the dataset from " + zip_url + " ? (y/n)") == "y": - print('Downloading Walnut dataset to ' + data_dir) - with urlopen(zip_url) as response: - with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: - zipfile.extractall(path = data_dir) - print("Download complete") - -# class RemoteData(object): -# WALNUT_URL = 'https://zenodo.org/record/4822516/files/walnut.zip' -# WALNUT = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') - -# KORN_URL = 'https://zenodo.org/record/6874123/files/korn.zip' -# KORN = '/Korn i kasse/47209 testscan korn01_recon.xtekct' + class TestData(object): '''Class to return test data diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index 850b37a831..84e1e2a5df 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -26,6 +26,10 @@ from testclass import CCPiTestClass import platform import numpy as np +from unittest.mock import patch, MagicMock +from urllib import request +from zipfile import ZipFile +from io import StringIO initialise_tests() @@ -149,3 +153,65 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self): .set_angles(np.linspace(0,360,300,False)) self.assertEqual(ag_expected,image.geometry,msg="Acquisition geometry mismatch") + +class TestRemoteData(unittest.TestCase): + + def setUp(self): + self.tmp_file = 'tmp.txt' + self.tmp_zip = 'tmp.zip' + with ZipFile(self.tmp_zip, 'w') as zipped_file: + zipped_file.writestr(self.tmp_file, np.array([1, 2, 3])) + with open(self.tmp_zip, 'rb') as zipped_file: + self.zipped_bytes = zipped_file.read() + + def tearDown(self): + if os.path.exists(self.tmp_file): + os.remove(self.tmp_file) + if os.path.exists(self.tmp_zip): + os.remove(self.tmp_zip) + + def mock_urlopen(self, mock_urlopen): + mock_response = MagicMock() + mock_response.read.return_value = self.zipped_bytes + mock_response.__enter__.return_value = mock_response + mock_urlopen.return_value = mock_response + + @patch('cil.utilities.dataexample.urlopen') + def test_unzip_remote_data(self, mock_urlopen): + self.mock_urlopen(mock_urlopen) + dataexample.REMOTEDATA.download_from_url('.') + self.assertTrue(os.path.isfile(self.tmp_file)) + + @patch('cil.utilities.dataexample.input', return_value='n') + @patch('cil.utilities.dataexample.urlopen') + def test_download_data_input_n(self, mock_urlopen, input): + self.mock_urlopen(mock_urlopen) + + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + + dataexample.WALNUT.download_data('.') + + self.assertFalse(os.path.isfile(self.tmp_file)) + self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n') + + # return to standard print output + sys.stdout = sys.__stdout__ + + @patch('cil.utilities.dataexample.input', return_value='y') + @patch('cil.utilities.dataexample.urlopen') + def test_download_data_input_y(self, mock_urlopen, input): + self.mock_urlopen(mock_urlopen) + + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + + dataexample.WALNUT.download_data('.') + self.assertTrue(os.path.isfile(self.tmp_file)) + + # return to standard print output + sys.stdout = sys.__stdout__ + + \ No newline at end of file From 3258b555bff56a4cf2d0743411dddec16a0fc096 Mon Sep 17 00:00:00 2001 From: hrobarts Date: Tue, 20 Feb 2024 17:52:22 +0000 Subject: [PATCH 5/7] Tidy --- Wrappers/Python/cil/utilities/dataexample.py | 9 ++++----- Wrappers/Python/test/test_io.py | 1 - 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index 277d0e0607..7b417afbbd 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -29,11 +29,6 @@ from io import BytesIO from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader -# # this is the default location after a conda install -# data_dir = os.path.abspath( -# os.path.join(sys.prefix, 'share','cil') -# ) - class DATA(object): @classmethod def dfile(cls): @@ -51,6 +46,10 @@ class REMOTEDATA(DATA): PATH = '' URL = '' + @classmethod + def get(cls, data_dir): + return None + @classmethod def download_from_url(cls, data_dir): with urlopen(cls.URL) as response: diff --git a/Wrappers/Python/test/test_io.py b/Wrappers/Python/test/test_io.py index 5b5d3fa92f..b770e8348a 100644 --- a/Wrappers/Python/test/test_io.py +++ b/Wrappers/Python/test/test_io.py @@ -30,7 +30,6 @@ from cil.io.utilities import HDF5_utilities from cil.processors import Slicer from utils import has_astra, has_nvidia -import cil.utilities.dataexample from cil.utilities.quality_measures import mse from cil.utilities import dataexample import shutil From f1657107cf16aa1f7d060833883f0dbd2c09e940 Mon Sep 17 00:00:00 2001 From: hrobarts Date: Thu, 29 Feb 2024 20:34:04 +0000 Subject: [PATCH 6/7] Add USB and sandstone datasets --- Wrappers/Python/cil/utilities/dataexample.py | 154 +++++++++++++++---- Wrappers/Python/test/test_dataexample.py | 43 ++++-- 2 files changed, 149 insertions(+), 48 deletions(-) diff --git a/Wrappers/Python/cil/utilities/dataexample.py b/Wrappers/Python/cil/utilities/dataexample.py index 7b417afbbd..8274626145 100644 --- a/Wrappers/Python/cil/utilities/dataexample.py +++ b/Wrappers/Python/cil/utilities/dataexample.py @@ -27,6 +27,7 @@ from zipfile import ZipFile from urllib.request import urlopen from io import BytesIO +from scipy.io import loadmat from cil.io import NEXUSDataReader, NikonDataReader, ZEISSDataReader class DATA(object): @@ -34,69 +35,80 @@ class DATA(object): def dfile(cls): return None -class INTERNALDATA(DATA): +class CILDATA(DATA): data_dir = os.path.abspath(os.path.join(sys.prefix, 'share','cil')) @classmethod def get(cls, size=None, scale=(0,1), **kwargs): - ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = TestData(data_dir=ddir) return loader.load(cls.dfile(), size, scale, **kwargs) class REMOTEDATA(DATA): - PATH = '' + + FOLDER = '' URL = '' + FILE_SIZE = '' @classmethod def get(cls, data_dir): return None @classmethod - def download_from_url(cls, data_dir): + def _download_and_extract_from_url(cls, data_dir): with urlopen(cls.URL) as response: with BytesIO(response.read()) as bytes, ZipFile(bytes) as zipfile: zipfile.extractall(path = data_dir) @classmethod def download_data(cls, data_dir): - if os.path.isfile(os.path.join(data_dir, cls.PATH)): + ''' + Download a dataset from a remote repository + + Parameters + ---------- + data_dir: str, optional + The path to the data directory where the downloaded data should be stored + + ''' + if os.path.isdir(os.path.join(data_dir, cls.FOLDER)): print("Dataset already exists in " + data_dir) else: - if input("Are you sure you want to download the dataset from " + cls.URL + " ? (y/n)") == "y": + if input("Are you sure you want to download " + cls.FILE_SIZE + " dataset from " + cls.URL + " ? (y/n)") == "y": print('Downloading dataset from ' + cls.URL) - cls.download_from_url(data_dir) + cls._download_and_extract_from_url(os.path.join(data_dir,cls.FOLDER)) print('Download complete') else: print('Download cancelled') -class BOAT(INTERNALDATA): +class BOAT(CILDATA): @classmethod def dfile(cls): return TestData.BOAT -class CAMERA(INTERNALDATA): +class CAMERA(CILDATA): @classmethod def dfile(cls): return TestData.CAMERA -class PEPPERS(INTERNALDATA): +class PEPPERS(CILDATA): @classmethod def dfile(cls): return TestData.PEPPERS -class RESOLUTION_CHART(INTERNALDATA): +class RESOLUTION_CHART(CILDATA): @classmethod def dfile(cls): return TestData.RESOLUTION_CHART -class SIMPLE_PHANTOM_2D(INTERNALDATA): +class SIMPLE_PHANTOM_2D(CILDATA): @classmethod def dfile(cls): return TestData.SIMPLE_PHANTOM_2D -class SHAPES(INTERNALDATA): +class SHAPES(CILDATA): @classmethod def dfile(cls): return TestData.SHAPES -class RAINBOW(INTERNALDATA): +class RAINBOW(CILDATA): @classmethod def dfile(cls): return TestData.RAINBOW -class SYNCHROTRON_PARALLEL_BEAM_DATA(INTERNALDATA): +class SYNCHROTRON_PARALLEL_BEAM_DATA(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -113,11 +125,11 @@ def get(cls, **kwargs): The DLS dataset ''' - ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), '24737_fd_normalised.nxs')) return loader.read() -class SIMULATED_PARALLEL_BEAM_DATA(INTERNALDATA): +class SIMULATED_PARALLEL_BEAM_DATA(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -134,11 +146,11 @@ def get(cls, **kwargs): The simulated spheres dataset ''' - ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_parallel_beam.nxs')) return loader.read() -class SIMULATED_CONE_BEAM_DATA(INTERNALDATA): +class SIMULATED_CONE_BEAM_DATA(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -155,11 +167,11 @@ def get(cls, **kwargs): The simulated spheres dataset ''' - ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_cone_beam.nxs')) return loader.read() -class SIMULATED_SPHERE_VOLUME(INTERNALDATA): +class SIMULATED_SPHERE_VOLUME(CILDATA): @classmethod def get(cls, **kwargs): ''' @@ -176,41 +188,117 @@ def get(cls, **kwargs): The simulated spheres volume ''' - ddir = kwargs.get('data_dir', INTERNALDATA.data_dir) + ddir = kwargs.get('data_dir', CILDATA.data_dir) loader = NEXUSDataReader() loader.set_up(file_name=os.path.join(os.path.abspath(ddir), 'sim_volume.nxs')) return loader.read() class WALNUT(REMOTEDATA): - - PATH = os.path.join('valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') + ''' + A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + ''' + FOLDER = 'walnut' URL = 'https://zenodo.org/record/4822516/files/walnut.zip' + FILE_SIZE = '6.4 GB' @classmethod def get(cls, data_dir): + ''' + A microcomputed tomography dataset of a walnut from https://zenodo.org/records/4822516 + This function returns the raw projection data from the .txrm file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) + + Returns + ------- + ImageData + The walnut dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'valnut','valnut_2014-03-21_643_28','tomo-A','valnut_tomo-A.txrm') try: - loader = ZEISSDataReader(file_name=os.path.join(data_dir,cls.PATH)) + loader = ZEISSDataReader(file_name=filepath) return loader.read() except(FileNotFoundError): - raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__))) + raise(FileNotFoundError("Dataset .txrm file not found in specifed data_dir: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) - +class USB(REMOTEDATA): + ''' + A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + ''' + FOLDER = 'USB' + URL = 'https://zenodo.org/record/4822516/files/usb.zip' + FILE_SIZE = '3.2 GB' + + @classmethod + def get(cls, data_dir): + ''' + A microcomputed tomography dataset of a usb memory stick from https://zenodo.org/records/4822516 + This function returns the raw projection data from the .txrm file + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.WALNUT.download_data(data_dir) + + Returns + ------- + ImageData + The usb dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'gruppe 4','gruppe 4_2014-03-20_1404_12','tomo-A','gruppe 4_tomo-A.txrm') + try: + loader = ZEISSDataReader(file_name=filepath) + return loader.read() + except(FileNotFoundError): + raise(FileNotFoundError("Dataset .txrm file not found in: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) + class KORN(REMOTEDATA): - PATH = os.path.join('Korn i kasse','47209 testscan korn01_recon.xtekct') + ''' + A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + ''' + FOLDER = 'korn' URL = 'https://zenodo.org/record/6874123/files/korn.zip' + FILE_SIZE = '2.9 GB' @classmethod def get(cls, data_dir): + ''' + A microcomputed tomography dataset of a sunflower seeds in a box from https://zenodo.org/records/6874123 + This function returns the raw projection data from the .xtekct file + + Parameters + ---------- + data_dir: str + The path to the directory where the dataset is stored. Data can be downloaded with dataexample.KORN.download_data(data_dir) + + Returns + ------- + ImageData + The korn dataset + ''' + filepath = os.path.join(data_dir, cls.FOLDER, 'Korn i kasse','47209 testscan korn01_recon.xtekct') try: - loader = NikonDataReader(file_name=os.path.join(data_dir, cls.PATH)) + loader = NikonDataReader(file_name=filepath) return loader.read() except(FileNotFoundError): - raise(FileNotFoundError("Dataset not found in specifed data_dir: {} \n \ - Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(data_dir, cls.__name__))) + raise(FileNotFoundError("Dataset .xtekct file not found in: {} \n \ + Specify a different data_dir or download data with dataexample.{}.download_data(data_dir)".format(filepath, cls.__name__))) + + +class SANDSTONE(REMOTEDATA): + ''' + A synchrotron x-ray tomography dataset of sandstone from https://zenodo.org/records/4912435 + A small subset of the data containing selected projections and 4 slices of the reconstruction + ''' + FOLDER = 'sandstone' + URL = 'https://zenodo.org/records/4912435/files/small.zip' + FILE_SIZE = '227 MB' - class TestData(object): '''Class to return test data diff --git a/Wrappers/Python/test/test_dataexample.py b/Wrappers/Python/test/test_dataexample.py index 84e1e2a5df..be0a57f946 100644 --- a/Wrappers/Python/test/test_dataexample.py +++ b/Wrappers/Python/test/test_dataexample.py @@ -22,7 +22,7 @@ from cil.framework.framework import ImageGeometry,AcquisitionGeometry from cil.utilities import dataexample from cil.utilities import noise -import os, sys +import os, sys, shutil from testclass import CCPiTestClass import platform import numpy as np @@ -157,6 +157,8 @@ def test_load_SIMULATED_CONE_BEAM_DATA(self): class TestRemoteData(unittest.TestCase): def setUp(self): + + self.data_list = ['WALNUT','USB','KORN','SANDSTONE'] self.tmp_file = 'tmp.txt' self.tmp_zip = 'tmp.zip' with ZipFile(self.tmp_zip, 'w') as zipped_file: @@ -165,11 +167,17 @@ def setUp(self): self.zipped_bytes = zipped_file.read() def tearDown(self): - if os.path.exists(self.tmp_file): - os.remove(self.tmp_file) + for data in self.data_list: + test_func = getattr(dataexample, data) + if os.path.exists(os.path.join(test_func.FOLDER)): + shutil.rmtree(test_func.FOLDER) + if os.path.exists(self.tmp_zip): os.remove(self.tmp_zip) + if os.path.exists(self.tmp_file): + os.remove(self.tmp_file) + def mock_urlopen(self, mock_urlopen): mock_response = MagicMock() mock_response.read.return_value = self.zipped_bytes @@ -179,7 +187,7 @@ def mock_urlopen(self, mock_urlopen): @patch('cil.utilities.dataexample.urlopen') def test_unzip_remote_data(self, mock_urlopen): self.mock_urlopen(mock_urlopen) - dataexample.REMOTEDATA.download_from_url('.') + dataexample.REMOTEDATA._download_and_extract_from_url('.') self.assertTrue(os.path.isfile(self.tmp_file)) @patch('cil.utilities.dataexample.input', return_value='n') @@ -187,14 +195,16 @@ def test_unzip_remote_data(self, mock_urlopen): def test_download_data_input_n(self, mock_urlopen, input): self.mock_urlopen(mock_urlopen) - # redirect print output - capturedOutput = StringIO() - sys.stdout = capturedOutput - - dataexample.WALNUT.download_data('.') - - self.assertFalse(os.path.isfile(self.tmp_file)) - self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n') + data_list = ['WALNUT','USB','KORN','SANDSTONE'] + for data in data_list: + # redirect print output + capturedOutput = StringIO() + sys.stdout = capturedOutput + test_func = getattr(dataexample, data) + test_func.download_data('.') + + self.assertFalse(os.path.isfile(self.tmp_file)) + self.assertEqual(capturedOutput.getvalue(),'Download cancelled\n') # return to standard print output sys.stdout = sys.__stdout__ @@ -206,10 +216,13 @@ def test_download_data_input_y(self, mock_urlopen, input): # redirect print output capturedOutput = StringIO() - sys.stdout = capturedOutput + sys.stdout = capturedOutput - dataexample.WALNUT.download_data('.') - self.assertTrue(os.path.isfile(self.tmp_file)) + + for data in self.data_list: + test_func = getattr(dataexample, data) + test_func.download_data('.') + self.assertTrue(os.path.isfile(os.path.join(test_func.FOLDER,self.tmp_file))) # return to standard print output sys.stdout = sys.__stdout__ From 013794806ede18ef909c791d2bccdcbcf6e457eb Mon Sep 17 00:00:00 2001 From: hrobarts Date: Thu, 29 Feb 2024 20:35:32 +0000 Subject: [PATCH 7/7] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26cb61c50d..8a5a9d5993 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ - New unit tests have been implemented for operators and functions to check for in place errors and the behaviour of `out`. - Bug fix for missing factor of 1/2 in SIRT update objective and catch in place errors in the SIRT constraint - Allow Masker to take integer arrays in addition to boolean + - Add remote data class to example data to enable download of relevant datasets from remote repositories * 23.1.0 - Fix bug in IndicatorBox proximal_conjugate