152 changes: 152 additions & 0 deletions keras_hub/src/utils/transformers/export/gpt2.py
@@ -0,0 +1,152 @@
import keras.ops as ops
import transformers


def get_gpt2_config(keras_model):
"""Convert Keras GPT-2 config to Hugging Face GPT2Config."""
return transformers.GPT2Config(
vocab_size=keras_model.vocabulary_size,
n_positions=keras_model.max_sequence_length,
n_embd=keras_model.hidden_dim,
n_layer=keras_model.num_layers,
n_head=keras_model.num_heads,
n_inner=keras_model.intermediate_dim,
activation_function="gelu_new",
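# The values below follow the standard Hugging Face GPT2Config defaults.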
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
summary_type="cls_index",
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
)


def get_gpt2_weights_map(keras_model, include_lm_head=False):
"""Create a weights map for a given GPT-2 model."""
weights_map = {}

# Token and position embeddings
weights_map["transformer.wte.weight"] = keras_model.get_layer(
"token_embedding"
).embeddings
weights_map["transformer.wpe.weight"] = keras_model.get_layer(
"position_embedding"
).position_embeddings

for i in range(keras_model.num_layers):
# Attention weights
# KerasHub uses Dense layers:
# kernel shape [hidden_dim, num_heads, key_dim]
# HF uses Conv1D: weight shape [hidden_dim, 3 * hidden_dim]
q_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._query_dense.kernel
k_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._key_dense.kernel
v_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._value_dense.kernel
Comment on lines +49 to +57 (Contributor, severity: medium)

Accessing private layer attributes like _self_attention_layer and its sub-layers makes this code brittle. If the internal structure of TransformerDecoder changes, this export script will break. It would be more robust to expose these weights via a public API on the layer to create a more stable interface.
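For illustration, a minimal sketch of what such a public accessor could look like (hypothetical method name, not part of the current TransformerDecoder API):

    def attention_weights(self):
        """Return the q/k/v/output kernels and biases of the self-attention block."""
        attn = self._self_attention_layer
        return {
            "query": (attn._query_dense.kernel, attn._query_dense.bias),
            "key": (attn._key_dense.kernel, attn._key_dense.bias),
            "value": (attn._value_dense.kernel, attn._value_dense.bias),
            "output": (attn._output_dense.kernel, attn._output_dense.bias),
        }

The export script would then read weights through this method instead of reaching into private sub-layers.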

q_b = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._query_dense.bias
k_b = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._key_dense.bias
v_b = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._value_dense.bias

# Flatten the head dimensions to match HF Conv1D input
q_w = ops.reshape(q_w, (keras_model.hidden_dim, keras_model.hidden_dim))
k_w = ops.reshape(k_w, (keras_model.hidden_dim, keras_model.hidden_dim))
v_w = ops.reshape(v_w, (keras_model.hidden_dim, keras_model.hidden_dim))

# Concatenate Q, K, V
c_attn_w = ops.concatenate([q_w, k_w, v_w], axis=-1)
weights_map[f"transformer.h.{i}.attn.c_attn.weight"] = c_attn_w

# Reshape biases
q_b = ops.reshape(q_b, [-1])
k_b = ops.reshape(k_b, [-1])
v_b = ops.reshape(v_b, [-1])

c_attn_b = ops.concatenate([q_b, k_b, v_b], axis=-1)
weights_map[f"transformer.h.{i}.attn.c_attn.bias"] = c_attn_b

# Attention projection
c_proj_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._output_dense.kernel
c_proj_w = ops.reshape(
c_proj_w, (keras_model.hidden_dim, keras_model.hidden_dim)
)
weights_map[f"transformer.h.{i}.attn.c_proj.weight"] = c_proj_w
weights_map[f"transformer.h.{i}.attn.c_proj.bias"] = (
keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer._output_dense.bias
)

# Layer norms
weights_map[f"transformer.h.{i}.ln_1.weight"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer_norm.gamma
weights_map[f"transformer.h.{i}.ln_1.bias"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._self_attention_layer_norm.beta
weights_map[f"transformer.h.{i}.ln_2.weight"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_layer_norm.gamma
weights_map[f"transformer.h.{i}.ln_2.bias"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_layer_norm.beta

# MLP
c_fc_w = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_intermediate_dense.kernel
weights_map[f"transformer.h.{i}.mlp.c_fc.weight"] = c_fc_w
weights_map[f"transformer.h.{i}.mlp.c_fc.bias"] = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_intermediate_dense.bias
c_proj_w_mlp = keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_output_dense.kernel
weights_map[f"transformer.h.{i}.mlp.c_proj.weight"] = c_proj_w_mlp
weights_map[f"transformer.h.{i}.mlp.c_proj.bias"] = (
keras_model.get_layer(
f"transformer_layer_{i}"
)._feedforward_output_dense.bias
)

# Final layer norm
weights_map["transformer.ln_f.weight"] = keras_model.get_layer(
"layer_norm"
).gamma
weights_map["transformer.ln_f.bias"] = keras_model.get_layer(
"layer_norm"
).beta

if include_lm_head:
# lm_head is tied to token embeddings
weights_map["lm_head.weight"] = weights_map["transformer.wte.weight"]

return weights_map


def get_gpt2_tokenizer_config(tokenizer):
return {
"model_type": "gpt2",
"bos_token": "<|endoftext|>",
"eos_token": "<|endoftext|>",
"unk_token": "<|endoftext|>",
}
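As a sanity check on the Dense-to-Conv1D conversion above, a minimal, standalone sketch of the shape arithmetic (NumPy only, GPT-2 base dimensions assumed):

    import numpy as np

    hidden_dim, num_heads, key_dim = 768, 12, 64  # hidden_dim == num_heads * key_dim
    # KerasHub stores each of q/k/v as a Dense kernel of shape [hidden_dim, num_heads, key_dim].
    q_w = np.zeros((hidden_dim, num_heads, key_dim))
    k_w = np.zeros((hidden_dim, num_heads, key_dim))
    v_w = np.zeros((hidden_dim, num_heads, key_dim))
    # Flatten the head axes, then concatenate into the [hidden_dim, 3 * hidden_dim] HF Conv1D layout.
    q_w = q_w.reshape(hidden_dim, hidden_dim)
    k_w = k_w.reshape(hidden_dim, hidden_dim)
    v_w = v_w.reshape(hidden_dim, hidden_dim)
    c_attn_w = np.concatenate([q_w, k_w, v_w], axis=-1)
    assert c_attn_w.shape == (hidden_dim, 3 * hidden_dim)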
71 changes: 71 additions & 0 deletions keras_hub/src/utils/transformers/export/gpt2_test.py
@@ -0,0 +1,71 @@
import os
import shutil
import tempfile

import keras.ops as ops
from absl.testing import parameterized
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export.hf_exporter import (
export_to_safetensors,
)


class GPT2ExportTest(TestCase):
@parameterized.named_parameters(
("gpt2_base_en", "gpt2_base_en"),
)
def test_gpt2_export(self, preset):
# Create a temporary directory to save the converted model.
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, preset)

# Load Keras model.
keras_model = GPT2CausalLM.from_preset(preset)

# Export to Hugging Face format.
export_to_safetensors(keras_model, output_path)

# Load the converted model with Hugging Face Transformers.
hf_model = AutoModelForCausalLM.from_pretrained(output_path)
hf_tokenizer = AutoTokenizer.from_pretrained(output_path)

# Assertions for config parameters.
self.assertEqual(
keras_model.backbone.hidden_dim, hf_model.config.hidden_size
)
self.assertEqual(
keras_model.backbone.num_layers, hf_model.config.n_layer
)
self.assertEqual(keras_model.backbone.num_heads, hf_model.config.n_head)
self.assertEqual(
keras_model.backbone.intermediate_dim, hf_model.config.n_inner
)
self.assertEqual(
keras_model.backbone.vocabulary_size, hf_model.config.vocab_size
)
self.assertEqual(
keras_model.backbone.max_sequence_length,
hf_model.config.n_positions,
)

# Test logits.
prompt = "Hello, my name is"
token_ids = ops.array(keras_model.preprocessor.tokenizer([prompt]))
padding_mask = ops.ones_like(token_ids, dtype="int32")
keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask}
keras_logits = keras_model(keras_inputs)

hf_inputs = hf_tokenizer(prompt, return_tensors="pt")
hf_logits = hf_model(**hf_inputs).logits

keras_logits_np = ops.convert_to_numpy(keras_logits)
hf_logits_np = hf_logits.detach().cpu().numpy()

self.assertAllClose(keras_logits_np, hf_logits_np, atol=1e-3, rtol=1e-3)

# Clean up the temporary directory.
shutil.rmtree(temp_dir)
91 changes: 70 additions & 21 deletions keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -4,25 +4,35 @@
import warnings

import keras
import torch

# Model-specific export helpers (config, weights map, tokenizer config)
from keras_hub.src.utils.transformers.export.gemma import get_gemma_config
from keras_hub.src.utils.transformers.export.gemma import (
get_gemma_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_config
from keras_hub.src.utils.transformers.export.gpt2 import (
get_gpt2_tokenizer_config,
)
from keras_hub.src.utils.transformers.export.gpt2 import get_gpt2_weights_map

MODEL_CONFIGS = {
"GemmaBackbone": get_gemma_config,
"GPT2Backbone": get_gpt2_config,
# Add for future models, e.g., "MistralBackbone": get_mistral_config
}

MODEL_EXPORTERS = {
"GemmaBackbone": get_gemma_weights_map,
"GPT2Backbone": get_gpt2_weights_map,
# Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
}

MODEL_TOKENIZER_CONFIGS = {
"GemmaTokenizer": get_gemma_tokenizer_config,
"GPT2Tokenizer": get_gpt2_tokenizer_config,
# Add for future models, e.g., "MistralTokenizer":
# get_mistral_tokenizer_config
}
@@ -54,23 +64,54 @@ def export_backbone(backbone, path, include_lm_head=False):
weights_dict = get_weights_fn(backbone, include_lm_head=include_lm_head)
if not weights_dict:
raise ValueError("No weights to save.")

# Save config
os.makedirs(path, exist_ok=True)
config_path = os.path.join(path, "config.json")

# Handle Config Objects vs Dicts
config_to_save = hf_config
if hasattr(hf_config, "to_dict"):
config_to_save = hf_config.to_dict()

with open(config_path, "w") as f:
json.dump(hf_config, f)
json.dump(config_to_save, f, indent=2)

# Save weights based on backend
weights_path = os.path.join(path, "model.safetensors")
if backend == "torch":
from safetensors.torch import save_file

weights_dict_contiguous = {
k: v.value.contiguous() if hasattr(v, "value") else v.contiguous()
for k, v in weights_dict.items()
}
save_file(
weights_dict_contiguous, weights_path, metadata={"format": "pt"}
)
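# Normalize Keras variables and backend tensors to CPU torch tensors before saving.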
weights_dict_torch = {}
for k, v in weights_dict.items():
tensor = v.value if hasattr(v, "value") else v

if isinstance(tensor, torch.Tensor):
t = tensor.detach().to("cpu")
elif hasattr(tensor, "numpy"):
t = torch.tensor(tensor.numpy())
elif hasattr(tensor, "__array__"):
t = torch.tensor(tensor)
else:
t = tensor

if hasattr(t, "contiguous"):
t = t.contiguous()

weights_dict_torch[k] = t

# Handle Tied Weights (GPT-2)
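# safetensors refuses to save tensors that share memory, so clone the tied lm_head before writing.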
if (
"lm_head.weight" in weights_dict_torch
and "transformer.wte.weight" in weights_dict_torch
):
wte = weights_dict_torch["transformer.wte.weight"]
lm = weights_dict_torch["lm_head.weight"]
if wte.data_ptr() == lm.data_ptr():
weights_dict_torch["lm_head.weight"] = lm.clone().contiguous()

save_file(weights_dict_torch, weights_path, metadata={"format": "pt"})

elif backend == "tensorflow":
from safetensors.tensorflow import save_file

@@ -91,31 +132,39 @@ def export_tokenizer(tokenizer, path):
path: str. Path to save the exported tokenizer.
"""
os.makedirs(path, exist_ok=True)

# Save tokenizer assets
tokenizer.save_assets(path)

# Export tokenizer config
tokenizer_type = tokenizer.__class__.__name__
if tokenizer_type not in MODEL_TOKENIZER_CONFIGS:
raise ValueError(
"Export to Transformers format not implemented for {tokenizer_type}"
f"Export to Transformer format not implemented for {tokenizer_type}"
)
get_tokenizer_config_fn = MODEL_TOKENIZER_CONFIGS[tokenizer_type]
tokenizer_config = get_tokenizer_config_fn(tokenizer)
tokenizer_config_path = os.path.join(path, "tokenizer_config.json")
with open(tokenizer_config_path, "w") as f:
json.dump(tokenizer_config, f, indent=4)
# Rename vocabulary file
vocab_spm_path = os.path.join(path, "vocabulary.spm")
tokenizer_model_path = os.path.join(path, "tokenizer.model")
if os.path.exists(vocab_spm_path):
shutil.move(vocab_spm_path, tokenizer_model_path)
else:
warnings.warn(
f"{vocab_spm_path} not found. Tokenizer may not load "
"correctly. Ensure that the tokenizer configuration "
"is correct and that the vocabulary file is present "
"in the original model."
)

# 2. Rename files to match Hugging Face expectations
if tokenizer_type == "GemmaTokenizer":
vocab_spm_path = os.path.join(path, "vocabulary.spm")
tokenizer_model_path = os.path.join(path, "tokenizer.model")
if os.path.exists(vocab_spm_path):
shutil.move(vocab_spm_path, tokenizer_model_path)
else:
warnings.warn(f"{vocab_spm_path} not found.")

elif tokenizer_type == "GPT2Tokenizer":
# Rename vocabulary.json -> vocab.json
vocab_json_path = os.path.join(path, "vocabulary.json")
vocab_hf_path = os.path.join(path, "vocab.json")
if os.path.exists(vocab_json_path):
shutil.move(vocab_json_path, vocab_hf_path)
else:
warnings.warn(f"{vocab_json_path} not found.")


def export_to_safetensors(keras_model, path):