24
24
BLOCK_SIZE = 8192
25
25
26
26
27
- def get_local_binary_path (name : str , url : str ) -> str :
27
+ def get_local_binary_path (name : str , url : str , tmp_dir : Optional [ str ] = None ) -> str :
28
28
"""
29
29
Returns the path to the executable previously downloaded with the name argument. If
30
30
None is found, the executable at the url argument will be downloaded and stored
31
31
under name for future uses.
32
32
:param name: The name that will be given to the folder containing the extracted data
33
33
:param url: The URL of the zip file
34
+ :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
34
35
"""
35
36
NUMBER_ATTEMPTS = 5
36
- with FileLock (os .path .join (tempfile .gettempdir (), name + ".lock" )):
37
- path = get_local_binary_path_if_exists (name , url )
37
+ tmp_dir = tmp_dir or tempfile .gettempdir ()
38
+ lock = FileLock (os .path .join (tmp_dir , name + ".lock" ))
39
+ with lock :
40
+ path = get_local_binary_path_if_exists (name , url , tmp_dir = tmp_dir )
38
41
if path is None :
39
42
logger .debug (
40
43
f"Local environment { name } not found, downloading environment from { url } "
@@ -45,7 +48,7 @@ def get_local_binary_path(name: str, url: str) -> str:
45
48
if path is not None :
46
49
break
47
50
try :
48
- download_and_extract_zip (url , name )
51
+ download_and_extract_zip (url , name , tmp_dir = tmp_dir )
49
52
except Exception :
50
53
if attempt + 1 < NUMBER_ATTEMPTS :
51
54
logger .warning (
@@ -54,7 +57,7 @@ def get_local_binary_path(name: str, url: str) -> str:
54
57
)
55
58
else :
56
59
raise
57
- path = get_local_binary_path_if_exists (name , url )
60
+ path = get_local_binary_path_if_exists (name , url , tmp_dir = tmp_dir )
58
61
59
62
if path is None :
60
63
raise FileNotFoundError (
@@ -64,15 +67,16 @@ def get_local_binary_path(name: str, url: str) -> str:
64
67
return path
65
68
66
69
67
- def get_local_binary_path_if_exists (name : str , url : str ) -> Optional [str ]:
70
+ def get_local_binary_path_if_exists (name : str , url : str , tmp_dir : str ) -> Optional [str ]:
68
71
"""
69
72
Recursively searches for a Unity executable in the extracted files folders. This is
70
73
platform dependent : It will only return a Unity executable compatible with the
71
74
computer's OS. If no executable is found, None will be returned.
72
75
:param name: The name/identifier of the executable
73
76
:param url: The url the executable was downloaded from (for verification)
77
+ :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
74
78
"""
75
- _ , bin_dir = get_tmp_dir ( )
79
+ _ , bin_dir = get_tmp_dirs ( tmp_dir )
76
80
extension = None
77
81
78
82
if platform == "linux" or platform == "linux2" :
@@ -100,27 +104,27 @@ def get_local_binary_path_if_exists(name: str, url: str) -> Optional[str]:
100
104
return None
101
105
102
106
103
- def _get_tmp_dir_helper () :
104
- TEMPDIR = "/tmp" if platform == "darwin" else tempfile .gettempdir ()
107
+ def _get_tmp_dir_helper (tmp_dir : Optional [ str ] = None ) -> Tuple [ str , str ] :
108
+ tmp_dir = tmp_dir or ( "/tmp" if platform == "darwin" else tempfile .gettempdir () )
105
109
MLAGENTS = "ml-agents-binaries"
106
110
TMP_FOLDER_NAME = "tmp"
107
111
BINARY_FOLDER_NAME = "binaries"
108
- mla_directory = os .path .join (TEMPDIR , MLAGENTS )
112
+ mla_directory = os .path .join (tmp_dir , MLAGENTS )
109
113
if not os .path .exists (mla_directory ):
110
114
os .makedirs (mla_directory )
111
115
os .chmod (mla_directory , 16877 )
112
- zip_directory = os .path .join (TEMPDIR , MLAGENTS , TMP_FOLDER_NAME )
116
+ zip_directory = os .path .join (tmp_dir , MLAGENTS , TMP_FOLDER_NAME )
113
117
if not os .path .exists (zip_directory ):
114
118
os .makedirs (zip_directory )
115
119
os .chmod (zip_directory , 16877 )
116
- bin_directory = os .path .join (TEMPDIR , MLAGENTS , BINARY_FOLDER_NAME )
120
+ bin_directory = os .path .join (tmp_dir , MLAGENTS , BINARY_FOLDER_NAME )
117
121
if not os .path .exists (bin_directory ):
118
122
os .makedirs (bin_directory )
119
123
os .chmod (bin_directory , 16877 )
120
- return ( zip_directory , bin_directory )
124
+ return zip_directory , bin_directory
121
125
122
126
123
- def get_tmp_dir ( ) -> Tuple [str , str ]:
127
+ def get_tmp_dirs ( tmp_dir : Optional [ str ] = None ) -> Tuple [str , str ]:
124
128
"""
125
129
Returns the path to the folder containing the downloaded zip files and the extracted
126
130
binaries. If these folders do not exist, they will be created.
@@ -130,21 +134,24 @@ def get_tmp_dir() -> Tuple[str, str]:
130
134
# Should only be able to error out 3 times (once for each subdir).
131
135
for _attempt in range (3 ):
132
136
try :
133
- return _get_tmp_dir_helper ()
137
+ return _get_tmp_dir_helper (tmp_dir )
134
138
except FileExistsError :
135
139
continue
136
- return _get_tmp_dir_helper ()
140
+ return _get_tmp_dir_helper (tmp_dir )
137
141
138
142
139
- def download_and_extract_zip (url : str , name : str ) -> None :
143
+ def download_and_extract_zip (
144
+ url : str , name : str , tmp_dir : Optional [str ] = None
145
+ ) -> None :
140
146
"""
141
147
Downloads a zip file under a URL, extracts its contents into a folder with the name
142
148
argument and gives chmod 755 to all the files it contains. Files are downloaded and
143
149
extracted into special folders in the temp folder of the machine.
144
150
:param url: The URL of the zip file
145
151
:param name: The name that will be given to the folder containing the extracted data
152
+ :param: tmp_dir: Optional override for the temporary directory to save binaries and zips in.
146
153
"""
147
- zip_dir , bin_dir = get_tmp_dir ( )
154
+ zip_dir , bin_dir = get_tmp_dirs ( tmp_dir )
148
155
url_hash = "-" + hashlib .md5 (url .encode ()).hexdigest ()
149
156
binary_path = os .path .join (bin_dir , name + url_hash )
150
157
if os .path .exists (binary_path ):
@@ -206,7 +213,7 @@ def load_remote_manifest(url: str) -> Dict[str, Any]:
206
213
"""
207
214
Converts a remote yaml file into a Python dictionary
208
215
"""
209
- tmp_dir , _ = get_tmp_dir ()
216
+ tmp_dir , _ = get_tmp_dirs ()
210
217
try :
211
218
request = urllib .request .urlopen (url , timeout = 30 )
212
219
except urllib .error .HTTPError as e : # type: ignore
0 commit comments