# Lookup table: format preset id -> that preset's `convertible_formats`,
# built once from every preset in `format_presets.PRESETLIST`.
CONVERTIBLE_FORMATS = {
    preset.id: preset.convertible_formats for preset in format_presets.PRESETLIST
}
6464
6565
def extract_ext_from_header(res):
    """
    Extract file extension (without dot) from the `Content-Disposition` header
    of the HTTP response `res`, or return `None` if it cannot be determined.

    The extension is taken from the header's `filename` parameter only, so
    surrounding quotes and any parameters that follow the filename are never
    included in the result. Returns `None` when `res` is falsy, when the header
    is absent, when it carries no filename, or when the filename has no
    extension; callers can then fall back to deriving it from the path.
    """
    # Local import: this module is not otherwise needed by the rest of the file.
    from email.message import Message

    if not res:
        return None
    content_dis = res.headers.get("content-disposition")
    if not content_dis:
        return None

    # Let the stdlib header parser handle quoting and `;`-separated parameters
    # instead of splitting the raw header value on ".".
    parsed = Message()
    parsed["content-disposition"] = content_dis
    filename = parsed.get_filename()
    if not filename:
        return None

    # splitext gives ".ext" (or "" when there is none); drop the leading dot.
    ext = os.path.splitext(filename)[1][1:]
    return ext or None
73+
74+
6675def extract_path_ext (path , default_ext = None ):
6776 """
6877 Extract file extension (without dot) from `path` or return `default_ext` if
@@ -141,22 +150,23 @@ def download(path, default_ext=None):
141150 cache_file = get_cache_filename (key )
142151 if not config .UPDATE and not cache_is_outdated (path , cache_file ):
143152 config .LOGGER .info ("\t Using cached file for: {}" .format (path ))
144- return cache_file
153+ return cache_file , None
145154
146155 config .LOGGER .info ("\t Downloading {}" .format (path ))
147156
148157 # Write file to temporary file
149158 with tempfile .NamedTemporaryFile (delete = False ) as tempf :
150159 tempf .close ()
151- write_path_to_filename (path , tempf .name )
160+ ext = write_path_to_filename (path , tempf .name )
152161 # Get extension of file or use `default_ext` if none found
153- ext = extract_path_ext (path , default_ext = default_ext )
162+ if not ext :
163+ ext = extract_path_ext (path , default_ext = default_ext )
154164 filename = copy_file_to_storage (tempf .name , ext = ext )
155165 FILECACHE .set (key , bytes (filename , "utf-8" ))
156166 config .LOGGER .info ("\t --- Downloaded {}" .format (filename ))
157167 os .unlink (tempf .name )
158168
159- return filename
169+ return filename , ext
160170
161171
162172def download_and_convert_video (path , ffmpeg_settings = None ):
@@ -218,9 +228,11 @@ def write_path_to_filename(path, write_to_file):
218228 if is_valid_url (path ):
219229 # CASE A: path is a URL (http://, https://, or file://, etc.)
220230 r = config .DOWNLOAD_SESSION .get (path , stream = True )
231+ default_ext = extract_ext_from_header (r )
221232 r .raise_for_status ()
222233 for chunk in r :
223234 f .write (chunk )
235+ return default_ext
224236 else :
225237 # CASE B: path points to a local filesystem file
226238 with open (path , "rb" ) as fobj :
@@ -464,6 +476,7 @@ def process_file(self):
464476
465477class DownloadFile (File ):
466478 allowed_formats = []
479+ ext = None
467480
468481 def __init__ (self , path , ** kwargs ):
469482 self .path = path .strip ()
@@ -474,19 +487,13 @@ def validate(self):
474487 Ensure `self.path` has one of the extensions in `self.allowed_formats`.
475488 """
476489 assert self .path , "{} must have a path" .format (self .__class__ .__name__ )
477- ext = extract_path_ext (self .path , default_ext = self .default_ext )
478- # don't validate for single-digit extension, or no extension
479- if len (ext ) > 1 :
480- assert ext in self .allowed_formats , (
481- "{} must have one of the "
482- "following extensions: {} (instead, got '{}' from '{}')" .format (
483- self .__class__ .__name__ , self .allowed_formats , ext , self .path
484- )
485- )
486490
487491 def process_file (self ):
488492 try :
489- self .filename = download (self .path , default_ext = self .default_ext )
493+ self .filename , self .ext = download (self .path , default_ext = self .default_ext )
494+ # don't validate for single-digit extension, or no extension
495+ if not self .ext :
496+ self .ext = extract_path_ext (self .path )
490497 return self .filename
491498 # Catch errors related to reading file path and handle silently
492499 except HTTP_CAUGHT_EXCEPTIONS as err :
@@ -553,8 +560,10 @@ def process_file(self):
553560 if self .filename :
554561 try :
555562 image_path = config .get_storage_path (self .filename )
556- ext = extract_path_ext (image_path )
557- if ext == "svg" :
563+ extension = self .ext
564+ if not extension :
565+ extension = extract_path_ext (image_path )
566+ if extension == "svg" :
558567 ElementTree .parse (image_path )
559568 else :
560569 self .filename = process_image (image_path )
@@ -675,13 +684,17 @@ def validate(self):
675684 Ensure `self.path` has one of the extensions in `self.allowed_formats`.
676685 """
677686 assert self .path , "{} must have a path" .format (self .__class__ .__name__ )
678- ext = extract_path_ext (self .path , default_ext = self .default_ext )
687+ extension = self .ext
688+ if not extension :
689+ extension = extract_path_ext (self .path , default_ext = self .default_ext )
679690 if (
680- ext not in self .allowed_formats
681- and ext not in CONVERTIBLE_FORMATS [format_presets .VIDEO_HIGH_RES ]
691+ extension not in self .allowed_formats
692+ and extension not in CONVERTIBLE_FORMATS [format_presets .VIDEO_HIGH_RES ]
682693 ):
683694 raise ValueError (
684- "Incompatible extension {} for VideoFile at {}" .format (ext , self .path )
695+ "Incompatible extension {} for VideoFile at {}" .format (
696+ self .ext , self .path
697+ )
685698 )
686699
687700 def process_unsupported_video_file (self ):
@@ -701,16 +714,20 @@ def process_unsupported_video_file(self):
701714 config .FAILED_FILES .append (self )
702715
703716 def process_file (self ):
704- ext = extract_path_ext (self .path , default_ext = self .default_ext )
717+ extension = self .ext
718+ if not extension :
719+ extension = extract_path_ext (self .path , default_ext = self .default_ext )
705720 if (
706- ext not in self .allowed_formats
707- and ext not in CONVERTIBLE_FORMATS [format_presets .VIDEO_HIGH_RES ]
721+ extension not in self .allowed_formats
722+ and extension not in CONVERTIBLE_FORMATS [format_presets .VIDEO_HIGH_RES ]
708723 ):
709724 raise ValueError (
710- "Incompatible extension {} for VideoFile at {}" .format (ext , self .path )
725+ "Incompatible extension {} for VideoFile at {}" .format (
726+ extension , self .path
727+ )
711728 )
712729 try :
713- if ext not in self .allowed_formats :
730+ if extension not in self .allowed_formats :
714731 # Handle videos that don't have an .mp4 or .webm extension
715732 self .filename = self .process_unsupported_video_file ()
716733 else :
@@ -915,7 +932,9 @@ def validate(self):
915932 info is specified in `self.subtitlesformat`.
916933 """
917934 assert self .path , "{} must have a path" .format (self .__class__ .__name__ )
918- ext = extract_path_ext (self .path , default_ext = self .subtitlesformat )
935+ ext = self .ext
936+ if not ext :
937+ ext = extract_path_ext (self .path , default_ext = self .subtitlesformat )
919938 convertible_exts = CONVERTIBLE_FORMATS [self .get_preset ()]
920939 if (
921940 ext != self .default_ext
0 commit comments