Skip to content

Commit

Permalink
Update extractors.py
Browse files Browse the repository at this point in the history
Cleaned up a bit
  • Loading branch information
LukasMut authored Dec 4, 2024
1 parent 1f95a73 commit c8b5ea6
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions thingsvision/core/extraction/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,25 @@ def load_model_from_source(self) -> None:
else:
weights = None
self.model = model(weights=weights)
prep_function_name = self.get_keras_preprocessing(self.model_name)
if prep_function_name is None: # we do not want to change the default preprocessing if we couldn't find a keras preprocess_input function
return

# first get the module name, the the preprocess_input attached to it
prep_function = getattr(getattr(tensorflow_models, prep_function_name), 'preprocess_input')

# different models take different sized inputs. has to be accounted for
resize_dim = self.model.layers[0].input_shape[0][-2] # -2 and -3 are h and w.
self.preprocess = tf.keras.Sequential([Lambda(prep_function), tf.keras.layers.experimental.preprocessing.Resizing(resize_dim, resize_dim)])
preproc_fun_name = self.get_keras_preprocessing(self.model_name)
if isinstance(preproc_fun_name, str):
# get preprocessing function associated with a specific model
preproc_fun = self.get_preproc_fun(preproc_fun_name)
# different models take differently sized inputs. this has to be accounted for.
resize_dim = self.model.layers[0].input_shape[0][-2] # -2 and -3 are H and W dims.
self.preprocess = tf.keras.Sequential([Lambda(preproc_fun), tf.keras.layers.experimental.preprocessing.Resizing(resize_dim, resize_dim)])
else:
raise ValueError(
f"\nCould not find {self.model_name} among TensorFlow models.\n"
)


@staticmethod
def get_preproc_fun(preproc_fun_name: str) -> Callable:
"""Get the preprocessing function associated with a specific model."""
return getattr(getattr(tensorflow_models, preproc_fun_name), "preprocess_input"))


def get_keras_preprocessing(self, model_name:str) -> Union[str, None]:
"""Get the preprocessing function for the corresponding model from `tensorflow.keras.applications.*`"""

Expand All @@ -233,10 +237,11 @@ def get_keras_preprocessing(self, model_name:str) -> Union[str, None]:
(r'^Xception$', 'xception')
]
# Try each pattern
for pattern, preprocess_value in patterns:
for pattern, preproc_val in patterns:
if re.match(pattern, model_name):
return preprocess_value
# If no match found
return preproc_val

# If no match is found, print a warning message
warnings.warn(f"No preprocessing function found for model {model_name}, so falling back to default preprocessing.\nOften, models that come from Keras Applications have their own preprocessing functions, therefore this may create inaccurate results. If you need to manually specify a preprocessing function, please do so under the `transforms` argument when creating your DataSet")
return None

Expand Down

0 comments on commit c8b5ea6

Please sign in to comment.