Skip to content

Commit a74e193

Browse files
authored
Merge pull request #372 from AtKristijan/get-ext-content-disposition
Get ext content disposition
2 parents 2c4a528 + ac1cbca commit a74e193

File tree

2 files changed

+51
-27
lines changed

2 files changed

+51
-27
lines changed

ricecooker/classes/files.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@
6363
CONVERTIBLE_FORMATS = {p.id: p.convertible_formats for p in format_presets.PRESETLIST}
6464

6565

66+
def extract_ext_from_header(res):
67+
if res:
68+
content_dis = res.headers.get("content-disposition")
69+
if content_dis:
70+
ext = content_dis.split(".")
71+
return ext[-1]
72+
return None
73+
74+
6675
def extract_path_ext(path, default_ext=None):
6776
"""
6877
Extract file extension (without dot) from `path` or return `default_ext` if
@@ -141,22 +150,23 @@ def download(path, default_ext=None):
141150
cache_file = get_cache_filename(key)
142151
if not config.UPDATE and not cache_is_outdated(path, cache_file):
143152
config.LOGGER.info("\tUsing cached file for: {}".format(path))
144-
return cache_file
153+
return cache_file, None
145154

146155
config.LOGGER.info("\tDownloading {}".format(path))
147156

148157
# Write file to temporary file
149158
with tempfile.NamedTemporaryFile(delete=False) as tempf:
150159
tempf.close()
151-
write_path_to_filename(path, tempf.name)
160+
ext = write_path_to_filename(path, tempf.name)
152161
# Get extension of file or use `default_ext` if none found
153-
ext = extract_path_ext(path, default_ext=default_ext)
162+
if not ext:
163+
ext = extract_path_ext(path, default_ext=default_ext)
154164
filename = copy_file_to_storage(tempf.name, ext=ext)
155165
FILECACHE.set(key, bytes(filename, "utf-8"))
156166
config.LOGGER.info("\t--- Downloaded {}".format(filename))
157167
os.unlink(tempf.name)
158168

159-
return filename
169+
return filename, ext
160170

161171

162172
def download_and_convert_video(path, ffmpeg_settings=None):
@@ -218,9 +228,11 @@ def write_path_to_filename(path, write_to_file):
218228
if is_valid_url(path):
219229
# CASE A: path is a URL (http://, https://, or file://, etc.)
220230
r = config.DOWNLOAD_SESSION.get(path, stream=True)
231+
default_ext = extract_ext_from_header(r)
221232
r.raise_for_status()
222233
for chunk in r:
223234
f.write(chunk)
235+
return default_ext
224236
else:
225237
# CASE B: path points to a local filesystem file
226238
with open(path, "rb") as fobj:
@@ -464,6 +476,7 @@ def process_file(self):
464476

465477
class DownloadFile(File):
466478
allowed_formats = []
479+
ext = None
467480

468481
def __init__(self, path, **kwargs):
469482
self.path = path.strip()
@@ -474,19 +487,13 @@ def validate(self):
474487
Ensure `self.path` has one of the extensions in `self.allowed_formats`.
475488
"""
476489
assert self.path, "{} must have a path".format(self.__class__.__name__)
477-
ext = extract_path_ext(self.path, default_ext=self.default_ext)
478-
# don't validate for single-digit extension, or no extension
479-
if len(ext) > 1:
480-
assert ext in self.allowed_formats, (
481-
"{} must have one of the "
482-
"following extensions: {} (instead, got '{}' from '{}')".format(
483-
self.__class__.__name__, self.allowed_formats, ext, self.path
484-
)
485-
)
486490

487491
def process_file(self):
488492
try:
489-
self.filename = download(self.path, default_ext=self.default_ext)
493+
self.filename, self.ext = download(self.path, default_ext=self.default_ext)
494+
# don't validate for single-digit extension, or no extension
495+
if not self.ext:
496+
self.ext = extract_path_ext(self.path)
490497
return self.filename
491498
# Catch errors related to reading file path and handle silently
492499
except HTTP_CAUGHT_EXCEPTIONS as err:
@@ -553,8 +560,10 @@ def process_file(self):
553560
if self.filename:
554561
try:
555562
image_path = config.get_storage_path(self.filename)
556-
ext = extract_path_ext(image_path)
557-
if ext == "svg":
563+
extension = self.ext
564+
if not extension:
565+
extension = extract_path_ext(image_path)
566+
if extension == "svg":
558567
ElementTree.parse(image_path)
559568
else:
560569
self.filename = process_image(image_path)
@@ -675,13 +684,17 @@ def validate(self):
675684
Ensure `self.path` has one of the extensions in `self.allowed_formats`.
676685
"""
677686
assert self.path, "{} must have a path".format(self.__class__.__name__)
678-
ext = extract_path_ext(self.path, default_ext=self.default_ext)
687+
extension = self.ext
688+
if not extension:
689+
extension = extract_path_ext(self.path, default_ext=self.default_ext)
679690
if (
680-
ext not in self.allowed_formats
681-
and ext not in CONVERTIBLE_FORMATS[format_presets.VIDEO_HIGH_RES]
691+
extension not in self.allowed_formats
692+
and extension not in CONVERTIBLE_FORMATS[format_presets.VIDEO_HIGH_RES]
682693
):
683694
raise ValueError(
684-
"Incompatible extension {} for VideoFile at {}".format(ext, self.path)
695+
"Incompatible extension {} for VideoFile at {}".format(
696+
self.ext, self.path
697+
)
685698
)
686699

687700
def process_unsupported_video_file(self):
@@ -701,16 +714,20 @@ def process_unsupported_video_file(self):
701714
config.FAILED_FILES.append(self)
702715

703716
def process_file(self):
704-
ext = extract_path_ext(self.path, default_ext=self.default_ext)
717+
extension = self.ext
718+
if not extension:
719+
extension = extract_path_ext(self.path, default_ext=self.default_ext)
705720
if (
706-
ext not in self.allowed_formats
707-
and ext not in CONVERTIBLE_FORMATS[format_presets.VIDEO_HIGH_RES]
721+
extension not in self.allowed_formats
722+
and extension not in CONVERTIBLE_FORMATS[format_presets.VIDEO_HIGH_RES]
708723
):
709724
raise ValueError(
710-
"Incompatible extension {} for VideoFile at {}".format(ext, self.path)
725+
"Incompatible extension {} for VideoFile at {}".format(
726+
extension, self.path
727+
)
711728
)
712729
try:
713-
if ext not in self.allowed_formats:
730+
if extension not in self.allowed_formats:
714731
# Handle videos that don't have an .mp4 or .webm extension
715732
self.filename = self.process_unsupported_video_file()
716733
else:
@@ -915,7 +932,9 @@ def validate(self):
915932
info is specified in `self.subtitlesformat`.
916933
"""
917934
assert self.path, "{} must have a path".format(self.__class__.__name__)
918-
ext = extract_path_ext(self.path, default_ext=self.subtitlesformat)
935+
ext = self.ext
936+
if not ext:
937+
ext = extract_path_ext(self.path, default_ext=self.subtitlesformat)
919938
convertible_exts = CONVERTIBLE_FORMATS[self.get_preset()]
920939
if (
921940
ext != self.default_ext

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,12 @@ def fake_thumbnail_file():
793793

794794
@pytest.fixture
795795
def exercise_image_file():
796-
return _ExerciseImageFile("tests/testcontent/exercises/no-wifi.png")
796+
path = os.path.abspath(
797+
os.path.join(
798+
os.path.dirname(__file__), "testcontent", "exercises", "no-wifi.png"
799+
)
800+
)
801+
return _ExerciseImageFile(path)
797802

798803

799804
@pytest.fixture

0 commit comments

Comments
 (0)