7 | 7 | import os |
8 | 8 | import requests |
9 | 9 | import warnings |
10 | | -import yaml |
11 | | -from typing import Iterable, List, Dict, Any, Optional |
12 | | - |
13 | | - |
14 | | - |
15 | | -def _gather_files_from_response(resp: requests.Response) -> List[Dict[str, Any]]: |
16 | | - """ |
17 | | - Normalize Figshare API responses into a list of file dicts. |
18 | | - |
19 | | - Supports: |
20 | | - 1) Article endpoint: https://api.figshare.com/v2/articles/{id} |
21 | | - -> JSON object with key 'files' (list) |
22 | | - |
23 | | - 2) Files endpoint: https://api.figshare.com/v2/articles/{id}/files[?...] |
24 | | - -> JSON list of file objects (possibly paginated with Link headers) |
25 | | - """ |
26 | | - data = resp.json() |
27 | | - if isinstance(data, dict) and "files" in data and isinstance(data["files"], list): |
28 | | - return data["files"] |
29 | | - if isinstance(data, list): |
30 | | - return data |
31 | | - raise ValueError("Unexpected Figshare API response structure; expected dict with 'files' " |
32 | | - "or a list of file objects.") |
33 | | - |
34 | | - |
35 | | -def _iter_paginated_files(url: str, session: Optional[requests.Session] = None) -> Iterable[Dict[str, Any]]: |
36 | | - """ |
37 | | - Iterate over all files, following 'Link: <...>; rel=\"next\"' pagination if present. |
38 | | - Works for both the article endpoint (no pagination) and the files endpoint (may paginate). |
39 | | - """ |
40 | | - sess = session or requests.Session() |
41 | | - next_url = url |
42 | | - |
43 | | - while next_url: |
44 | | - resp = sess.get(next_url) |
45 | | - if resp.status_code != 200: |
46 | | - raise Exception(f"Failed to get dataset details from Figshare: {resp.text}") |
47 | | - |
48 | | - for f in _gather_files_from_response(resp): |
49 | | - yield f |
50 | 10 | |
51 | | - # RFC5988-style 'Link' header pagination |
52 | | - link = resp.headers.get("Link") or resp.headers.get("link") |
53 | | - next_url = None |
54 | | - if link: |
55 | | - parts = [p.strip() for p in link.split(",")] |
56 | | - for part in parts: |
57 | | - if 'rel="next"' in part: |
58 | | - start = part.find("<") + 1 |
59 | | - end = part.find(">", start) |
60 | | - if start > 0 and end > start: |
61 | | - next_url = part[start:end] |
62 | | - break |
| 11 | +import yaml |
63 | 12 | |
64 | 13 | def download( |
65 | 14 | name: str='all', |
@@ -97,73 +46,81 @@ def download( |
97 | 46 | local_path = Path(local_path) |
98 | 47 | |
99 | 48 | if not local_path.exists(): |
100 | | - local_path.mkdir(parents=True, exist_ok=True) |
| 49 | + Path.mkdir(local_path) |
101 | 50 | # Get the dataset details |
102 | 51 | with resources.open_text('coderdata', 'dataset.yml') as f: |
103 | 52 | data_information = yaml.load(f, Loader=yaml.FullLoader) |
104 | 53 | url = data_information['figshare'] |
| 54 | + |
| 55 | + response = requests.get(url) |
| 56 | + if response.status_code != 200: |
| 57 | + raise Exception( |
| 58 | + f"Failed to get dataset details from Figshare: {response.text}" |
| 59 | + ) |
| 60 | + |
| 61 | + data = response.json() |
105 | 62 | |
106 | | - name = (name or "all").casefold() |
107 | | - session = requests.Session() |
108 | | - all_files = list(_iter_paginated_files(url, session=session)) |
| 63 | + # making sure that we are case insensitive |
| 64 | + name = name.casefold() |
109 | 65 | |
| 66 | + # Filter files by the specified prefix |
110 | 67 | if name != "all": |
111 | 68 | filtered_files = [ |
112 | | - f for f in all_files |
113 | | - if (f.get('name', '').casefold().startswith(name)) or ('genes' in f.get('name', '').casefold()) |
114 | | - ] |
| 69 | + file |
| 70 | + for file |
| 71 | + in data['files'] |
| 72 | + if file['name'].startswith(name) or 'genes' in file['name'] |
| 73 | + ] |
115 | 74 | else: |
116 | | - filtered_files = all_files |
| 75 | + filtered_files = data['files'] |
117 | 76 | |
| 77 | + # Group files by name and select the one with the highest ID |
118 | 78 | unique_files = {} |
119 | 79 | for file in filtered_files: |
120 | | - fname = file.get('name') |
121 | | - fid = file.get('id') |
122 | | - if fname is None or fid is None: |
123 | | - continue |
124 | | - file_name = local_path.joinpath(fname) |
125 | | - if (file_name not in unique_files) or (fid > unique_files[file_name]['id']): |
126 | | - unique_files[file_name] = {'file_info': file, 'id': fid} |
| 80 | + file_name = local_path.joinpath(file['name']) |
| 81 | + file_id = file['id'] |
| 82 | + if ( |
| 83 | + file_name not in unique_files |
| 84 | + or file_id > unique_files[file_name]['id'] |
| 85 | + ): |
| 86 | + unique_files[file_name] = {'file_info': file, 'id': file_id} |
127 | 87 | |
128 | 88 | for file_name, file_data in unique_files.items(): |
129 | 89 | file_info = file_data['file_info'] |
130 | 90 | file_id = str(file_info['id']) |
131 | | - file_url = f"https://api.figshare.com/v2/file/download/{file_id}" |
132 | | - file_md5sum = file_info.get('supplied_md5') |
133 | | - |
134 | | - if file_name.exists() and not exist_ok: |
135 | | - warnings.warn( |
136 | | - f"{file_name} already exists. Use argument 'exist_ok=True' to overwrite the existing file." |
137 | | - ) |
138 | | - |
| 91 | + file_url = "https://api.figshare.com/v2/file/download/" + file_id |
| 92 | + file_md5sum = file_info['supplied_md5'] |
139 | 93 | retry_count = 10 |
| 94 | + # Download the file |
140 | 95 | while retry_count > 0: |
141 | | - with session.get(file_url, stream=True) as r: |
| 96 | + with requests.get(file_url, stream=True) as r: |
142 | 97 | r.raise_for_status() |
143 | | - with open(file_name, 'wb') as f: |
144 | | - for chunk in r.iter_content(chunk_size=8192): |
145 | | - f.write(chunk) |
146 | | - |
147 | | - if file_md5sum: |
148 | | - with open(file_name, 'rb') as f: |
149 | | - check_md5sum = md5(f.read()).hexdigest() |
150 | | - if file_md5sum == check_md5sum: |
151 | | - break |
152 | | - else: |
153 | | - retry_count -= 1 |
154 | | - if retry_count > 0: |
155 | | - warnings.warn( |
156 | | - f"{file_name} failed MD5 verification " |
157 | | - f"(expected: {file_md5sum}, got: {check_md5sum}). Retrying..." |
| 98 | + if file_name.exists() and not exist_ok: |
| 99 | + warnings.warn( |
| 100 | + f"{file_name} already exists. Use argument 'exist_ok=True'" |
| 101 | + " to overwrite existing file." |
158 | 102 | ) |
159 | | - else: |
| 103 | + else: |
| 104 | + with open(file_name, 'wb') as f: |
| 105 | + for chunk in r.iter_content(chunk_size=8192): |
| 106 | + f.write(chunk) |
| 107 | + with open(file_name, 'rb') as f: |
| 108 | + check_md5sum = md5(f.read()).hexdigest() |
| 109 | + if file_md5sum == check_md5sum: |
160 | 110 | break |
161 | | - |
162 | | - if retry_count == 0 and file_md5sum: |
| 111 | + elif retry_count > 0: |
| 112 | + warnings.warn( |
| 113 | + f"{file_name} could not be downloaded successfully. " |
| 114 | + f"(expected md5sum: {file_md5sum} - " |
| 115 | + f"calculated md5sum: {check_md5sum})... retrying..." |
| 116 | + ) |
| 117 | + retry_count = retry_count - 1 |
| 118 | + if retry_count == 0: |
163 | 119 | warnings.warn( |
164 | | - f"{file_name} could not be downloaded with a matching MD5 after retries." |
165 | | - ) |
| 120 | + f"{file_name} could not be downloaded. Try again." |
| 121 | + ) |
166 | 122 | else: |
167 | 123 | print(f"Downloaded '{file_url}' to '{file_name}'") |
168 | 124 | |
| 125 | + return |
169 | 126 | |
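
For context, the removed `_iter_paginated_files` helper walks RFC 5988 `Link: <...>; rel="next"` headers when the Figshare files endpoint paginates, while the article endpoint returns a single object with a `files` list. Below is a minimal standalone sketch of that pattern, assuming only what the docstrings in the diff describe; the function name `iter_figshare_files` and its parameters are illustrative and not part of this commit.

```python
import requests

def iter_figshare_files(url, session=None):
    """Yield Figshare file dicts, following 'Link: <...>; rel="next"' pagination."""
    sess = session or requests.Session()
    next_url = url
    while next_url:
        resp = sess.get(next_url)
        resp.raise_for_status()
        data = resp.json()
        # Article endpoint -> dict with a 'files' list; files endpoint -> bare list.
        files = data.get("files", []) if isinstance(data, dict) else data
        for f in files:
            yield f
        # Look for an RFC 5988 'next' link; stop when there is none.
        next_url = None
        for part in (resp.headers.get("Link") or "").split(","):
            if 'rel="next"' in part:
                start = part.find("<") + 1
                end = part.find(">", start)
                if 0 < start < end:
                    next_url = part[start:end]
                break
```

Iterating `iter_figshare_files("https://api.figshare.com/v2/articles/<article_id>/files")` would then yield every file object regardless of whether the server paginates the listing.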
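The download loop in the diff streams each file to disk and compares its MD5 against Figshare's `supplied_md5`, retrying on mismatch. A rough standalone sketch of that verify-and-retry pattern follows; the helper name `download_verified` and the retry count are illustrative assumptions, not this commit's API.

```python
import warnings
from hashlib import md5
from pathlib import Path

import requests

def download_verified(file_url, dest, expected_md5, retries=10):
    """Stream a file to disk and retry until its MD5 matches the supplied checksum."""
    dest = Path(dest)
    for _ in range(retries):
        with requests.get(file_url, stream=True) as r:
            r.raise_for_status()
            with open(dest, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Verify the checksum before accepting the download.
        if md5(dest.read_bytes()).hexdigest() == expected_md5:
            return True
        warnings.warn(f"{dest} failed MD5 verification; retrying...")
    return False
```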