Announcing New Hugging Face and KerasHub integration
The Hugging Face Hub is a vast repository, currently hosting 750K+ public models and offering a diverse range of pre-trained models for various machine learning frameworks. Of these, 346,268 (at the time of writing) are built with the popular Transformers library. The KerasHub library recently added an integration with the Hub, initially compatible with a first batch of 33 models.
In this first version, KerasHub users were limited to the KerasHub-based models available on the Hugging Face Hub.
from keras_hub.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2b-keras"
)
They were able to train/fine-tune the model and upload it back to the Hub (notice that the model is still a Keras model).
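import keras_hub
# Save the fine-tuned model locally, then upload the preset directory to the Hub
gemma_lm.save_to_preset("./gemma-2b-finetune")
keras_hub.upload_preset(
"hf://username/gemma-2b-finetune",
"./gemma-2b-finetune"
)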
They were missing out on the extensive collection of over 300K models created with the Transformers library. Figure 1 shows the 4K Gemma models available on the Hub.
Figure 1: Gemma models on the Hugging Face Hub (Source: https://huggingface.co/models?other=gemma)
However, what if we told you that you can now access and use these 300K+ models with KerasHub, significantly expanding your model selection and capabilities?
from keras_hub.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2b" # this is not a keras model!
)
We're thrilled to announce a significant step forward for the Hub community: Transformers and KerasHub now have a shared model save format. This means that Transformers models on the Hugging Face Hub can now be loaded directly into KerasHub, immediately making a huge range of fine-tuned models available to KerasHub users. Initially, this integration focuses on enabling the use of Gemma (1 and 2), Llama 3, and PaliGemma models, with plans to expand compatibility to a wider range of architectures in the near future.
Use a wider range of frameworks
Because KerasHub models can seamlessly run on TensorFlow, JAX, or PyTorch backends, a huge range of model checkpoints can now be loaded into any of these frameworks with a single line of code. Found a great checkpoint on Hugging Face, but wish you could deploy it to TFLite for serving or port it to JAX for research? Now you can!
How to use it
Using the integration requires updating your Keras installation:
$ pip install -U -q keras-hub
$ pip install -U "keras>=3.3.3"
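If you want to confirm from Python that the installed Keras meets the 3.3.3 minimum, a minimal standard-library sketch like the one below works. parse_version and meets_minimum are hypothetical helpers written for illustration; they compare only the numeric components of a version string, so the packaging library's packaging.version.Version is the more robust choice in real tooling.

```python
# Minimal sketch: check that the installed Keras satisfies ">=3.3.3".
# parse_version / meets_minimum are hypothetical helpers (not part of Keras).
from importlib import metadata


def parse_version(text: str) -> tuple[int, ...]:
    """Turn '3.4.1.dev0' into (3, 4, 1), keeping only leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break  # stop at a non-numeric component such as 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "3.3.3") -> bool:
    """True if `installed` >= `minimum`, comparing numeric components."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "3.3" compares like "3.3.0"
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    try:
        version = metadata.version("keras")
    except metadata.PackageNotFoundError:
        print("Keras is not installed; run the pip commands above.")
    else:
        status = "OK" if meets_minimum(version) else "too old, please upgrade"
        print(f"keras {version}: {status}")
```

Comparing tuples of integers rather than raw strings matters here: as strings, "3.10.0" would sort before "3.3.3".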
Once updated, trying out the integration is as simple as:
from keras_hub.models import Llama3CausalLM
# this model was not fine-tuned with Keras but can still be loaded
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
causal_lm.summary()
Under the Hood: How It Works
Transformers models are stored as a set of JSON config files, a tokenizer (usually also a .json file), and a set of safetensors weight files. The actual modeling code lives in the Transformers library itself. This means that cross-loading a Transformers checkpoint into KerasHub is relatively straightforward, as long as both libraries have modeling code for the relevant architecture: all we need to do is map config variables, weight names, and tokenizer vocabularies from one format to the other to create a KerasHub checkpoint from a Transformers checkpoint, or vice versa.
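To make the mapping step more concrete, here is a minimal, illustrative sketch of the first two pieces: translating a Llama-style config.json into backbone constructor arguments and renaming checkpoint weight keys. The Transformers-side keys (vocab_size, num_hidden_layers, rope_theta, model.layers.N.self_attn.q_proj.weight, and so on) follow the usual Llama checkpoint layout; the KerasHub-side argument names are assumptions based on the Llama backbone's constructor, and the Keras-side weight paths are entirely hypothetical. This is not KerasHub's actual converter.

```python
# Illustrative sketch of cross-loading a Llama-style Transformers checkpoint.
# Keras-side names are ASSUMPTIONS for illustration, not KerasHub internals.
import json
import re

# Transformers config key -> assumed KerasHub backbone argument
CONFIG_KEY_MAP = {
    "vocab_size": "vocabulary_size",
    "num_hidden_layers": "num_layers",
    "num_attention_heads": "num_query_heads",
    "num_key_value_heads": "num_key_value_heads",
    "hidden_size": "hidden_dim",
    "intermediate_size": "intermediate_dim",
    "rms_norm_eps": "layer_norm_epsilon",
    "rope_theta": "rope_max_wavelength",
}

# Regex over Transformers weight names -> hypothetical Keras-side path template
WEIGHT_NAME_RULES = [
    (r"^model\.embed_tokens\.weight$", "token_embedding/embeddings"),
    (r"^model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight$",
     r"transformer_layer_\1/attention/\2_proj/kernel"),
    (r"^model\.layers\.(\d+)\.input_layernorm\.weight$",
     r"transformer_layer_\1/pre_attention_norm/scale"),
    (r"^model\.norm\.weight$", "final_norm/scale"),
]


def convert_config(hf_config: dict) -> dict:
    """Keep only the keys we know how to map, renaming them on the way."""
    return {
        keras_key: hf_config[hf_key]
        for hf_key, keras_key in CONFIG_KEY_MAP.items()
        if hf_key in hf_config
    }


def convert_weight_name(hf_name: str) -> str:
    """Rename one weight key; raise KeyError if no rule covers it."""
    for pattern, template in WEIGHT_NAME_RULES:
        if re.match(pattern, hf_name):
            return re.sub(pattern, template, hf_name)
    raise KeyError(f"No mapping rule for weight {hf_name!r}")


if __name__ == "__main__":
    # A trimmed-down config.json, as it might appear in a Hub repository
    raw = '{"model_type": "llama", "vocab_size": 128256, "num_hidden_layers": 32}'
    print(convert_config(json.loads(raw)))
    print(convert_weight_name("model.layers.0.self_attn.q_proj.weight"))
```

Renaming keys is only part of the job. A real converter also has to reshape tensors; for example, PyTorch's nn.Linear stores its weight as (out_features, in_features), while a Keras Dense kernel is (in_features, out_features), so such weights need a transpose. The tokenizer vocabulary has to be carried over as well.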
All of this is handled internally for you, so you can focus on trying out the models rather than converting them!
Common Use Cases
Generation
One of the most common uses of language models is text generation. Here is an example that loads a Transformers model and generates new tokens using the .generate method from KerasHub.
from keras_hub.models import Llama3CausalLM
# Get the model
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
prompts = [
"""<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]
# Generate from the model
causal_lm.generate(prompts, max_length=200)[0]
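The prompt above is written in the ChatML format that Hermes-2-Pro expects: each turn is wrapped in <|im_start|>role ... <|im_end|> markers, and an open assistant turn is left at the end for the model to complete. If you build prompts programmatically, a small helper like the one below can assemble them. build_chatml_prompt is a hypothetical helper written for this post, not a KerasHub API.

```python
# Minimal sketch: assemble a ChatML-formatted prompt from role/content pairs.
# build_chatml_prompt is a hypothetical helper, not part of KerasHub.


def build_chatml_prompt(messages: list[dict[str, str]]) -> str:
    """Wrap each message in ChatML markers and open an assistant turn."""
    turns = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
        for m in messages
    ]
    # Leave the assistant turn open so the model writes the reply
    return "".join(turns) + "<|im_start|>assistant\n"


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what KerasHub does in one sentence."},
]
prompt = build_chatml_prompt(messages)
print(prompt)
```

The resulting string can be passed to causal_lm.generate in a list, just like the prompts list above. The trailing newline after "assistant" follows the common ChatML convention; the hand-written prompt above omits it, and both styles are widely used.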
Changing precision
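You can change the precision of your model with keras.config. Set the dtype policy before loading the model, because the global policy applies to layers created after it is set: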
import keras
keras.config.set_dtype_policy("bfloat16")
from keras_hub.models import Llama3CausalLM
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
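Why bother with bfloat16? A quick back-of-the-envelope calculation shows the effect on weight memory. The sketch below assumes roughly 8 billion parameters (taken from the "8B" in the checkpoint name) and counts only the storage for the weights themselves; activations, the KV cache used during generation, and framework overhead are ignored.

```python
# Back-of-the-envelope weight memory at different precisions.
# Assumes ~8e9 parameters; ignores activations, KV cache, and overhead.

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: ~{weight_memory_gb(8e9, dtype):.0f} GB")
# prints:
# float32: ~32 GB
# bfloat16: ~16 GB
```

Halving the bytes per parameter halves the weight footprint, which is often the difference between a model fitting on a single accelerator or not.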
Using the checkpoint with JAX backend
To test-drive a model with JAX, simply switch the Keras backend to JAX by setting the KERAS_BACKEND environment variable before Keras (or KerasHub) is imported. Here's how you can use the model in a JAX environment:
import os
os.environ["KERAS_BACKEND"] = "jax"
from keras_hub.models import Llama3CausalLM
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
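Because Keras reads KERAS_BACKEND when it is first imported, changing the environment variable after "import keras" (or "import keras_hub", which imports Keras) does not switch the backend of the already-loaded module. A small guard like the one below turns that silent mistake into an explicit error. select_keras_backend is a hypothetical helper, and it deliberately accepts only the three backends this post mentions, although Keras supports others.

```python
# Minimal sketch: set the Keras backend safely, before Keras is imported.
# select_keras_backend is a hypothetical helper, not part of Keras.
import os
import sys

# Limited to the backends discussed in this post
SUPPORTED_BACKENDS = {"jax", "tensorflow", "torch"}


def select_keras_backend(name: str) -> None:
    """Set KERAS_BACKEND, refusing if Keras is already loaded."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {sorted(SUPPORTED_BACKENDS)}"
        )
    if "keras" in sys.modules:
        raise RuntimeError(
            "Keras is already imported; set KERAS_BACKEND before importing "
            "keras or keras_hub (for example, in a fresh process)."
        )
    os.environ["KERAS_BACKEND"] = name


select_keras_backend("jax")
# Only import Keras / KerasHub after this point, e.g.:
# from keras_hub.models import Llama3CausalLM
```

In a notebook, this usually means putting the backend selection in the very first cell, or restarting the kernel before switching backends.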
Gemma 2
We are pleased to inform you that the Gemma 2 models are also compatible with this integration.
from keras_hub.models import GemmaCausalLM
causal_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2-9b" # This is Gemma 2!
)
PaliGemma
You can also use any PaliGemma safetensors checkpoint in your KerasHub pipeline.
from keras_hub.models import PaliGemmaCausalLM
pali_gemma_lm = PaliGemmaCausalLM.from_preset(
"hf://gokaygokay/sd3-long-captioner" # A finetuned version of PaliGemma
)
What's Next?
This is just the beginning. We envision expanding this integration to encompass an even wider range of Hugging Face models and architectures. Stay tuned for updates and be sure to explore the incredible potential that this collaboration unlocks!
I would like to take this opportunity to thank Matthew Carrigan and Matthew Watson for their help throughout the process.