Transformers can create unconventional python module names when loading certain repositories #35570

Open
kory opened this issue Jan 8, 2025 · 2 comments · May be fixed by #36478
kory commented Jan 8, 2025

System Info

  • transformers version: 4.41.1
  • Platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
  • Python version: 3.10.15
  • Huggingface_hub version: 0.23.5
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@Rocketknight1 (maybe?)

Information

By convention (and for importability), Python module names cannot:

  • Start with anything other than a letter or an underscore
  • Contain hyphens

Transformers can create and load Python modules that break both of these conventions. This can cause unexpected behavior in code that uses the modules Transformers creates, for example when creating, saving, and loading PyTorch traces from disk.
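
For illustration, a quick check of why these names are a problem: every dotted component of a module path must be a valid Python identifier, and the repository-derived components here are not (this snippet is mine, not from the Transformers codebase):

# Each dotted part of a module path must be a valid Python identifier.
path = "transformers_modules.nomic-ai.nomic-bert-2048.40b98394640e630d5276807046089b233113aa87"
for part in path.split("."):
    print(part, part.isidentifier())
# "nomic-ai" and "nomic-bert-2048" contain hyphens, and the commit hash starts
# with a digit, so none of these can appear in an import statement or in
# serialized code that later gets re-parsed (e.g. a TorchScript archive).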

Tasks

Load a model from Hugging Face and trace it.

Reproduction

I try to load, trace, save to disk, and reload the model from this repo: https://huggingface.co/nomic-ai/nomic-bert-2048

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Define mean pooling function
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Create a wrapper class for tracing
class TransformerWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = mean_pooling(outputs, attention_mask)
        return F.normalize(pooled, p=2, dim=1)

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')
tokenizer.model_max_length = 128
base_model = AutoModel.from_pretrained('nomic-ai/nomic-embed-text-v1', trust_remote_code=True)
base_model.eval()

# Create wrapped model
wrapped_model = TransformerWrapper(base_model)

# Prepare example input for tracing
example_sentences = ['example sentence']
encoded_input = tokenizer(
    example_sentences,
    padding="max_length",
    truncation=True,
    return_tensors='pt'
)

with torch.no_grad():
    output = wrapped_model(encoded_input["input_ids"], encoded_input["attention_mask"])


# Trace the model
with torch.no_grad():
    traced_model = torch.jit.trace(
        wrapped_model,
        (
            encoded_input['input_ids'],
            encoded_input['attention_mask']
        )
    )

print(type(base_model))
      
torch.jit.save(traced_model, "my_model.pt")
torch.jit.load("my_model.pt") # this will fail

The model is loaded in an unconventionally named Python module:

print(type(base_model))
<class 'transformers_modules.nomic-ai.nomic-bert-2048.40b98394640e630d5276807046089b233113aa87.modeling_hf_nomic_bert.NomicBertModel'>

The module name is serialized inside the TorchScript trace. When the trace is loaded again, it fails to parse because the module name of the class does not follow Python conventions:

    return torch.jit.load(model_path)
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
RuntimeError: expected newline but found 'ident' here:
Serialized   File "code/__torch__.py", line 6
  training : bool
  _is_full_backward_hook : Optional[bool]
  model : __torch__.transformers_modules.nomic-ai.nomic-bert-2048.40b98394640e630d5276807046089b233113aa87.modeling_hf_nomic_bert.NomicBertModel
                                                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  def forward(self: __torch__.TransformerWrapper,
    input_ids: Tensor,
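
The offending name can be confirmed by inspecting the saved archive directly. A TorchScript archive is a zip file, and the serialized "Python-like" source under code/ embeds the fully qualified module path; a minimal sketch, assuming the trace was saved as my_model.pt as above:

import zipfile

# The top-level code/__torch__.py in the archive contains the line that
# torch.jit.load later fails to parse.
with zipfile.ZipFile("my_model.pt") as zf:
    code_files = [n for n in zf.namelist() if n.endswith("code/__torch__.py")]
    print(zf.read(code_files[0]).decode())  # shows the hyphenated module path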

Expected behavior

The module names created by Transformers should be sanitized to follow Python conventions. I was able to solve this problem with a simple modification:

kory@b3fde4f

I am unsure whether this is the best fix, or whether it would be considered safe for the package as a whole, but it does fix the tracing issue I'm hitting:

print(type(base_model))
<class 'transformers_modules.nomic_ai.nomic_bert_2048._40b98394640e630d5276807046089b233113aa87.modeling_hf_nomic_bert.NomicBertModel'>
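
For illustration, a minimal sketch of this kind of sanitization (this is not the exact change in kory@b3fde4f, just the idea):

import re

def sanitize_module_name(name: str) -> str:
    # Replace characters that are invalid in Python identifiers, and prefix an
    # underscore when the name would otherwise start with a digit.
    name = re.sub(r"[^0-9a-zA-Z_]", "_", name)
    if name and name[0].isdigit():
        name = "_" + name
    return name

print(sanitize_module_name("nomic-bert-2048"))  # nomic_bert_2048
print(sanitize_module_name("40b98394640e630d5276807046089b233113aa87"))  # prefixed with "_"
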
kory added the bug label Jan 8, 2025
kory changed the title from "Python code in HF repos can create unconventional python module names that break pyTorch tracing" to "Transformers can create unconventional python module names when loading certain repositories" Jan 8, 2025

github-actions bot commented Feb 8, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Rocketknight1
Member

Hi @kory, sorry for missing this PR earlier! I think it got lost in the pile of notifications around Christmas.

It seems like a real issue, and your fix looks reasonable - can you open a PR with it so we can run some tests and review it? It's easier if we have something concrete to start with, even if we end up making changes.
