Skip to content

Commit b1cbd13

Browse files
Merge pull request #12 from diffgram/make-py-torch-optional
make pytorch optional
2 parents ed4d850 + 7015f69 commit b1cbd13

File tree

3 files changed

+8
-4
lines changed

3 files changed

+8
-4
lines changed

sdk/diffgram/core/directory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from diffgram.file.file import File
22
from ..regular.regular import refresh_from_dict
33
import logging
4-
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
54
from diffgram.tensorflow_diffgram.diffgram_tensorflow_dataset import DiffgramTensorflowDataset
65
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
76
from multiprocessing.pool import ThreadPool as Pool
@@ -155,6 +154,7 @@ def to_pytorch(self, transform = None):
155154
Transforms the file list inside the dataset into a pytorch dataset.
156155
:return:
157156
"""
157+
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
158158
file_id_list = self.file_id_list
159159
pytorch_dataset = DiffgramPytorchDataset(
160160
project = self.client,

sdk/diffgram/core/sliced_directory.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from diffgram.core.directory import Directory
2-
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
32
from diffgram.tensorflow_diffgram.diffgram_tensorflow_dataset import DiffgramTensorflowDataset
43
import urllib
54

@@ -37,7 +36,7 @@ def to_pytorch(self, transform = None):
3736
Transforms the file list inside the dataset into a pytorch dataset.
3837
:return:
3938
"""
40-
39+
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
4140
pytorch_dataset = DiffgramPytorchDataset(
4241
project = self.client,
4342
diffgram_file_id_list = self.file_id_list,

sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
from torch.utils.data import Dataset, DataLoader
2-
import torch as torch # type: ignore
32
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
43

4+
try:
5+
import torch as torch # type: ignore
6+
except ModuleNotFoundError:
7+
raise ModuleNotFoundError(
8+
"'torch' module should be installed to convert the Dataset into torch (pytorch) format"
9+
)
510

611
class DiffgramPytorchDataset(DiffgramDatasetIterator, Dataset):
712

0 commit comments

Comments
 (0)