
fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX #20842

Open · wants to merge 16 commits into master


@harshaljanjani (Contributor) commented Feb 2, 2025

This fix goes beyond the scope of the original issue. It adds support for Keras models with dictionary-based inputs, in particular when they are exported to the TF SavedModel format with the TensorFlow or JAX backend. Previously, exporting a model with dictionary inputs failed with ValueErrors caused by mismatches between the provided and expected input structures.

Key changes:

  • Added handling of model._input_names for dictionary-based inputs in the Functional and Model classes
  • Improved input signature generation to handle dictionary input structures
  • Added test coverage for exporting models with dictionary inputs
  • Added debug logging throughout the input signature handling flow

This PR aims to fix #20835, in which models with dictionary inputs fail to export to the SavedModel format.

Example of fixed functionality:

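import os
import tempfile

import numpy as np
import tensorflow as tf
from keras import layers, models

inputs = {
  "foo": layers.Input(shape=()),
  "bar": layers.Input(shape=()),
}
outputs = layers.Add()([inputs["foo"], inputs["bar"]])
model = models.Model(inputs, outputs)
ref_input = {"foo": tf.constant([1.0]), "bar": tf.constant([2.0])}
ref_output = model(ref_input)

# Export to a fresh temporary directory, then reload it with the SavedModel API.
temp_filepath = os.path.join(tempfile.mkdtemp(), "exported_model")
model.export(temp_filepath, format="tf_saved_model")
revived_model = tf.saved_model.load(temp_filepath)
revived_output = revived_model.serve(ref_input)
# The reloaded `serve` endpoint should reproduce the original model's output.
np.testing.assert_allclose(ref_output, revived_output)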
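In the example above, the exported `serve` endpoint has to accept a dict of named tensors, which is what the dictionary-aware input signature is for. The following is a minimal plain-TensorFlow sketch of such a signature, not the code in this PR: `make_dict_signature` is a hypothetical helper, and it assumes every named input is a float32 tensor with an unknown leading batch dimension (matching the `shape=()` inputs above).

```python
import tensorflow as tf


# Hypothetical helper (not a Keras or TensorFlow API): build one TensorSpec per
# named input, assuming float32 tensors with an unknown leading batch dimension.
def make_dict_signature(input_shapes, dtype=tf.float32):
    return {
        name: tf.TensorSpec(shape=(None, *shape), dtype=dtype, name=name)
        for name, shape in input_shapes.items()
    }


# A dict of TensorSpecs nested inside `input_signature` lets the traced
# function take a single positional argument that is a dict of tensors.
@tf.function(input_signature=[make_dict_signature({"foo": (), "bar": ()})])
def serve(inputs):
    return inputs["foo"] + inputs["bar"]


result = serve({"foo": tf.constant([1.0]), "bar": tf.constant([2.0])})

# A dict whose keys do not match the signature is rejected at call time.
# The exception type has varied across TF versions, so both are caught here.
try:
    serve({"foo": tf.constant([1.0])})
    mismatch_rejected = False
except (TypeError, ValueError):
    mismatch_rejected = True
```

Because the signature is fixed up front, the function is traced once for this dict structure, and callers that pass a differently keyed dict get an error instead of a silent retrace.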

…cases)

- Improves input structure validation in Model and Functional classes

- Adds strict validation with clear error messages for mismatches
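The commit notes above mention strict validation with clear error messages for mismatched input structures. As a rough, framework-free sketch of that idea (the function `check_dict_inputs` and its messages are hypothetical, not the PR's actual implementation), a dict of inputs can be compared against the expected input names before any tensors are built:

```python
def check_dict_inputs(expected_names, inputs):
    """Hypothetical validator: require `inputs` to be a dict whose keys are
    exactly `expected_names`, and return the values in that order."""
    if not isinstance(inputs, dict):
        raise ValueError(
            f"Expected a dict with keys {sorted(expected_names)}, "
            f"but received {type(inputs).__name__}."
        )
    missing = sorted(set(expected_names) - set(inputs))
    unexpected = sorted(set(inputs) - set(expected_names))
    if missing or unexpected:
        raise ValueError(
            "Input structure mismatch: "
            f"missing keys {missing}, unexpected keys {unexpected}."
        )
    # Order values by the expected names so positional consumers see a
    # stable layout regardless of the caller's dict insertion order.
    return [inputs[name] for name in expected_names]


# Usage: key order in the caller's dict does not matter.
ordered = check_dict_inputs(["foo", "bar"], {"bar": 2.0, "foo": 1.0})
```

Reporting both the missing and the unexpected keys in one message tells the caller exactly how their dict differs from the model's inputs, rather than failing on the first mismatch.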
@codecov-commenter commented Feb 3, 2025

Codecov Report

Attention: Patch coverage is 59.61538% with 42 lines in your changes missing coverage. Please review.

Project coverage is 82.20%. Comparing base (e045b6a) to head (72371d7).
Report is 4 commits behind head on master.

Files with missing lines            Patch %   Lines
keras/src/models/functional.py      55.76%    12 Missing and 11 partials ⚠️
keras/src/models/model.py           21.05%    10 Missing and 5 partials ⚠️
keras/src/export/export_utils.py    71.42%    2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20842      +/-   ##
==========================================
- Coverage   82.26%   82.20%   -0.06%     
==========================================
  Files         561      561              
  Lines       52693    53035     +342     
  Branches     8146     8228      +82     
==========================================
+ Hits        43347    43600     +253     
- Misses       7344     7391      +47     
- Partials     2002     2044      +42     
Flag               Coverage Δ
keras              82.02% <59.61%> (-0.06%) ⬇️
keras-jax          64.14% <54.80%> (-0.11%) ⬇️
keras-numpy        58.89% <23.07%> (-0.10%) ⬇️
keras-openvino     32.36% <1.92%>  (-0.19%) ⬇️
keras-tensorflow   64.60% <46.15%> (-0.24%) ⬇️
keras-torch        64.16% <39.42%> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown.

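As a worked check of the report above, the project-level percentages can be reproduced from the hit, miss, and partial counts in the coverage diff. This sketch assumes Codecov's definition of coverage as hits / (hits + misses + partials) and its default of rounding down to two decimals (a repository's codecov.yml can change the rounding); the helper name `codecov_percent` is illustrative only.

```python
# Figures copied from the coverage diff in the Codecov comment.
MASTER = {"hits": 43347, "misses": 7344, "partials": 2002}
HEAD = {"hits": 43600, "misses": 7391, "partials": 2044}

# Per-file (missing, partials) counts from the "Files with missing lines" table.
PATCH_GAPS = {
    "keras/src/models/functional.py": (12, 11),
    "keras/src/models/model.py": (10, 5),
    "keras/src/export/export_utils.py": (2, 2),
}


def codecov_percent(hits, misses, partials, precision=2):
    """Coverage percentage, rounded down to `precision` decimal places."""
    total = hits + misses + partials
    scale = 10 ** precision
    # Integer floor division rounds down without floating-point error.
    return (100 * hits * scale // total) / scale


master_pct = codecov_percent(**MASTER)   # 82.26
head_pct = codecov_percent(**HEAD)       # 82.2, shown as 82.20%
delta = round(head_pct - master_pct, 2)  # -0.06
uncovered_patch_lines = sum(m + p for m, p in PATCH_GAPS.values())  # 42
```

Rounding down rather than to nearest is what makes the head figure 82.20% (the unrounded value is about 82.2099%), and the per-file missing and partial lines sum to the 42 uncovered patch lines reported at the top of the comment.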

- Added shape logging for preprocessed and crossed features

- Added debug messages in Functional model processing (To be removed)
@harshaljanjani (Contributor, Author) commented Feb 3, 2025

TODO: Remove debugging prints.
Edit: Done and ready for review.

@harshaljanjani harshaljanjani marked this pull request as ready for review February 3, 2025 19:54
@harshaljanjani (Contributor, Author) commented:

Ping: @mattdangerw @fchollet.

@jeffcarp (Member) left a comment

Thanks for the contribution! Left some comments.

@harshaljanjani (Contributor, Author) commented:

Thanks for the review! I'll address the comments as soon as possible.

@harshaljanjani (Contributor, Author) commented:

Please take a look at the changes and let me know whether this is the intended behavior. Looking forward to your guidance, thanks!

@harshaljanjani harshaljanjani changed the title fix(export+models): Enhance input signature handling for dictionary-based models fix(export+models): Enhance support for dictionary-based model input signatures in TensorFlow and JAX Feb 9, 2025
@harshaljanjani (Contributor, Author) commented:

Ping: @fchollet, looking forward to an update on this PR. Thanks!

@fchollet (Collaborator) commented:

Thanks for the update!

This branch has conflicts that must be resolved

Please resolve merge conflicts.

@jeffcarp does the PR look good?

@harshaljanjani (Contributor, Author) commented:

@fchollet
Resolved merge conflicts, thanks!

@jeffcarp (Member) left a comment

LGTM % some nits

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Feb 18, 2025
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Feb 19, 2025
@harshaljanjani (Contributor, Author) commented:

@fchollet @jeffcarp
Fixed! Thanks for reviewing the PR.

@jeffcarp (Member) left a comment

Thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Feb 19, 2025
Labels: awaiting review, ready to pull (Ready to be merged into the codebase), size:M
Projects: None yet
Development: successfully merging this pull request may close these issues:
  • ValueErrors when calling Model.export() for TF SavedModel format on Keras Models with dict inputs
6 participants