
Can't Fit a Binary Classifier that Uses Gemma Pre-trained Model Embeddings #2102

Open
rlcauvin opened this issue Feb 17, 2025 · 11 comments
Labels: Gemma (Gemma model specific issues)

Comments

@rlcauvin

rlcauvin commented Feb 17, 2025

Describe the bug

Using the "gemma2_2b_en" Gemma pre-trained model in a neural network fails during training with ValueError: Cannot get result() since the metric has not yet been built.

To Reproduce

Stripped down example here: https://colab.research.google.com/drive/1r8XkaQBeUxP5fp9i1QLaikFIdbhcrKMw?usp=sharing

Expected behavior

It should be possible to use a Gemma pre-trained model as a neural network layer in a binary classifier and successfully train the model.
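The intended shape of such a model is roughly the following; the layer sizes and function names here are illustrative placeholders, not the reporter's actual Colab code:

```python
# Rough sketch: Gemma-derived embeddings feeding a small binary-classification
# head. Names and sizes are hypothetical; the reporter's Colab defines the
# real model. keras is imported inside the builder so the sketch can be read
# without it installed.

def head_layer_sizes():
    """Dense head sizes, ending in a single sigmoid unit for binary output."""
    return [16, 1]

def build_classifier_head():
    import keras  # assumed installed (e.g. in Colab)
    hidden, out = head_layer_sizes()
    return keras.Sequential([
        keras.layers.Dense(hidden, activation="relu"),
        keras.layers.Dense(out, activation="sigmoid"),  # binary probability
    ])
```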

Additional context

This use of Gemma to generate embeddings for binary classification is based on this starting point by @jeffcarp.

Would you like to help us fix it?

No

github-actions bot added the Gemma (Gemma model specific issues) label Feb 17, 2025
@rlcauvin
Author

The reason I think I've isolated the problem to the Gemma encoding layer is that training the classifier works fine if I swap a keras.layers.TextVectorization layer in its place.

jeffcarp self-assigned this Feb 18, 2025
@jeffcarp
Member

I think the problem is related to instantiating a sub-model (keras_hub.models.GemmaCausalLM) within the context of another model. Can you try this?

self.gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2_2b_en", compile=False)

This unblocks training locally for me, but I'm seeing a TPU error when trying the fix in your Colab; I'm not sure if it's related.
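The pattern being suggested can be sketched as follows; the wrapper function names are illustrative, and the heavy keras_hub import is kept local since the preset download is large:

```python
# Sketch of the suggested fix: load the Gemma sub-model with compile=False so
# it does not build its own optimizer/metrics state inside the outer model,
# which appears to be what triggers the "metric has not yet been built" error.

def preset_load_kwargs():
    """Arguments for loading a preset as a sub-model of another model."""
    return {"compile": False}  # the suggested change

def build_gemma_sub_model():
    import keras_hub  # assumed installed (pip install keras-hub)
    return keras_hub.models.GemmaCausalLM.from_preset(
        "gemma2_2b_en", **preset_load_kwargs()
    )
```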

@rlcauvin
Author

Thanks, @jeffcarp. The change did get past the original error. Maybe it's the TPU error to which you're referring?

---------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
<ipython-input-14-d7863a64cd24> in <cell line: 0>()
      2 y_train = keras.ops.array([[1], [0], [0]])
      3 
----> 4 nn_model_history = nn_model.fit(
      5   x = x_train,
      6   y = y_train,

1 frames
/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57       e.message += " name: " + name
     58     raise core._status_to_exception(e) from None
---> 59   except TypeError as e:
     60     keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
     61     if keras_symbolic_tensors:

NotFoundError: Graph execution error:

Detected at node StatefulPartitionedCall defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 712, in start

  File "/usr/local/lib/python3.11/dist-packages/tornado/platform/asyncio.py", line 205, in start

  File "/usr/lib/python3.11/asyncio/base_events.py", line 608, in run_forever

  File "/usr/lib/python3.11/asyncio/base_events.py", line 1936, in _run_once

  File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 499, in process_one

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute

  File "/usr/local/lib/python3.11/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "<ipython-input-14-d7863a64cd24>", line 4, in <cell line: 0>

  File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 371, in fit

  File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 219, in function

  File "/usr/local/lib/python3.11/dist-packages/keras/src/backend/tensorflow/trainer.py", line 132, in multi_step_on_iterator

could not find registered transfer manager for platform Host -- check target linkage
	 [[{{node StatefulPartitionedCall}}]] [Op:__inference_multi_step_on_iterator_22334]

@jeffcarp
Member

I'm able to get training running on GPU, but the Colab instance doesn't have enough memory to load the full Gemma preset:
https://colab.research.google.com/drive/1NoMXBJV_RDH70rTK9ueSR7nRjibv8i2a?usp=sharing

The TPU error looks like it's related to a TF version mismatch:
https://www.kaggle.com/models/google/gemma/discussion/511235

@rlcauvin
Author

Is there a magic combination of package versions we can use to get it to run on the TPU?

@Gopi-Uppari

Gopi-Uppari commented Feb 26, 2025

Hi @rlcauvin,

I was able to reproduce the issue using your code on a TPU in Google Colab. To fix it, try setting run_eagerly=True in model.compile(). You can also check out this gist file for reference.
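A minimal sketch of that compile call; the optimizer, loss, and metrics below are placeholders, not the Colab's actual settings:

```python
# Sketch of the workaround: run_eagerly=True makes Keras execute the train
# step eagerly instead of tracing it into a graph, sidestepping the TPU
# "could not find registered transfer manager" graph-execution error.
# Optimizer/loss/metrics are placeholder values.

def eager_compile_kwargs():
    """Compile arguments with the eager-execution workaround applied."""
    return {
        "optimizer": "adam",
        "loss": "binary_crossentropy",
        "metrics": ["accuracy"],
        "run_eagerly": True,  # the workaround
    }

# Usage (requires keras; nn_model is the classifier from the Colab):
#   nn_model.compile(**eager_compile_kwargs())
```

Note that eager execution is considerably slower than graph execution, so this is better treated as a diagnostic workaround than a production setting.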

Thank you.

@rlcauvin
Author

Thanks, @Gopi-Uppari. Using run_eagerly=True in model.compile() did, indeed, get past the graph execution error. I'm not sure why, but I also had to set compile=False when loading the preset:

self.gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma2_2b_en", compile = False)

Now I'll see what happens when I deploy the binary classifier to a TensorFlow Serving endpoint.

@Gopi-Uppari

Hi @rlcauvin,

If you run into any issues during deployment, just let us know. If the issue is resolved for you, please feel free to close it.

Thank you.

@rlcauvin
Author

@Gopi-Uppari Exporting the model for TensorFlow serving fails:

nn_model.export("export/1/")

Error message:

/usr/local/lib/python3.11/dist-packages/keras/src/models/functional.py:237: UserWarning: The structure of `inputs` doesn't match the expected structure.
Expected: ['headline']
Received: inputs=Tensor(shape=(None,))
  warnings.warn(msg)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-17-120e717fb59e> in <cell line: 0>()
----> 1 nn_model.export("export/1/")

11 frames
/usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/saved_model_exported_concrete.py in _raise_untracked_capture_error(function_name, capture, internal_capture, node_path)
     96   if internal_capture is not None:
     97     msg += f"\n\tInternal Tensor = {internal_capture}"
---> 98   raise AssertionError(msg)
     99 
    100 

AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
	Function name = b'__inference_signature_wrapper___call___9482'
	Captured Tensor = <ResourceHandle(name="_0_SentencepieceOp", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::text::(anonymous namespace)::SentencepieceResource", dtype and shapes : "[  ]")>
	Trackable referencing this tensor = <tensorflow_text.python.ops.sentencepiece_tokenizer._SentencepieceModelResource object at 0x7dd038716790>
	Internal Tensor = Tensor("9468:0", shape=(), dtype=resource)

I have updated the Colab to show the attempted export and the error.

@Gopi-Uppari

Hi @rlcauvin,

I reproduced the issue. The error suggests that TensorFlow is failing to track a SentencepieceOp resource, likely coming from the GemmaEncoder layer. Since this layer uses a SentencePiece tokenizer, TensorFlow is unable to export it properly.

There is also a warning about an input structure mismatch:

Expected: {'headline': tensor}
Received: Tensor(shape=(None,))

This means your model expects a dictionary with a key 'headline', but you're passing a raw tensor instead.

To fix the issue, try using the provided code that correctly extracts and tracks the SentencePiece model, while ensuring the input format aligns with the model's expectations.


Could you take a look at this gist file for further reference?
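The gist isn't reproduced here, but the general shape of such a workaround is to wrap the model in a tf.Module so its resources (including the tokenizer) are tracked, and export an explicit dict-input signature. All names below are hypothetical, not taken from the gist:

```python
# Hedged sketch of an export workaround: assigning the model to an attribute
# of a tf.Module tracks its variables and tokenizer resources, and the
# tf.function signature matches the expected {'headline': tensor} input
# structure. Names are illustrative. tensorflow is imported locally so the
# sketch can be read without it installed.

def expected_input_keys():
    """Feature keys the serving signature accepts."""
    return ["headline"]

def export_with_tracked_resources(nn_model, export_dir):
    import tensorflow as tf

    class ExportModule(tf.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model  # attribute assignment tracks its resources

        @tf.function(input_signature=[{
            "headline": tf.TensorSpec(shape=(None,), dtype=tf.string)
        }])
        def serving_fn(self, inputs):
            return self.model(inputs)

    module = ExportModule(nn_model)
    tf.saved_model.save(
        module, export_dir,
        signatures={"serving_default": module.serving_fn},
    )
```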

Thank you.

@rlcauvin
Author

rlcauvin commented Mar 3, 2025

@Gopi-Uppari Thanks! With your provided code for defining the function and signature and using tf.saved_model.save, the model does save with no error or warning messages. Does the fact that we couldn't get model.export to work indicate there is a bug in model.export?

In any case, I deployed it to TensorFlow serving, and when I invoked the endpoint, the following error occurred:

2025-03-03 21:52:03.569415: I external/org_tensorflow/tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: FAILED_PRECONDITION: Could not find variable dense_16/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Resource localhost/dense_16/kernel/N10tensorflow3VarE does not exist.

	 [[{{function_node __inference_serving_fn_645315}}{{node nn_model_1/dense_16_1/Cast/ReadVariableOp}}]]

2025-03-03 21:52:03.569732: I external/org_tensorflow/tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: ABORTED: Stopping remaining executors.

Why would the variable dense_16/kernel have been deleted or not initialized?
