import os
import tempfile

import gradio as gr
import torch
from transformers import AutoModel, BitsAndBytesConfig
from huggingface_hub import HfApi, list_models
from gradio_huggingfacehub_search import HuggingfaceHubSearch


def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
    # Gradio fills in the gr.OAuthProfile / gr.OAuthToken parameters
    # automatically; if the user is not logged in, profile is None.
    if profile is None:
        return "Hello!"
    return f"Hello {profile.name}! Welcome to the BitsAndBytes Space"

def check_model_exists(oauth_token: gr.OAuthToken | None, username, quantization_type, model_name, quantized_model_name):
    """Check if a model already exists in the user's Hugging Face repositories."""
    try:
        models = list_models(author=username, token=oauth_token.token)
        model_names = [model.id for model in models]
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-BNB-{quantization_type}"
        if repo_name in model_names:
            return f"Model '{repo_name}' already exists in your repository."
        else:
            return None  # Model does not exist
    except Exception as e:
        return f"Error checking model existence: {str(e)}"

def create_model_card(base_model_name, repo_name, quantization_type, threshold, quant_type_4, double_quant_4):
    # base_model in the card metadata should point at the original model,
    # while the Usage snippet loads the quantized repo.
    model_card = f"""---
base_model:
- {base_model_name}
---

# {repo_name} (Quantized)

## Description

This model is a quantized version of the original model `{base_model_name}`. It has been quantized using {quantization_type} quantization with bitsandbytes.

## Quantization Details

- **Quantization Type**: {quantization_type}
- **Threshold**: {threshold if quantization_type == "int8" else None}
- **bnb_4bit_quant_type**: {quant_type_4 if quantization_type == "int4" else None}
- **bnb_4bit_use_double_quant**: {double_quant_4 if quantization_type == "int4" else None}

## Usage

You can use this model in your applications by loading it directly from the Hugging Face Hub:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("{repo_name}")
```"""
    return model_card

def load_model(model_name, quantization_config, auth_token):
    # `use_auth_token` is deprecated in transformers; `token` is the current kwarg.
    # Note: depending on the bitsandbytes version installed, quantized loading
    # may require a CUDA device.
    return AutoModel.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="cpu",
        token=auth_token.token,
    )

def quantize_model(model_name, quantization_type, threshold, quant_type_4, double_quant_4, auth_token=None, username=None):
    print(f"Quantizing model: {quantization_type}")
    if quantization_type == "int4":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type_4,
            bnb_4bit_use_double_quant=(double_quant_4 == "True"),
        )
    else:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=threshold,
        )
    model = load_model(model_name, quantization_config=quantization_config, auth_token=auth_token)
    return model
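
# For reference, a minimal sketch (hypothetical values; assumes bitsandbytes is
# installed) of the two configurations quantize_model builds:
#
#   int8: BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
#   int4: BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
#                            bnb_4bit_use_double_quant=True)
#
# Passing either config to AutoModel.from_pretrained quantizes the weights at
# load time rather than as a separate conversion step.
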
def save_model(model, model_name, quantization_type, threshold, quant_type_4, double_quant_4, username=None, auth_token=None, quantized_model_name=None):
    print("Saving quantized model")
    with tempfile.TemporaryDirectory() as tmpdirname:
        # save_pretrained writes to a local directory, so no auth token is needed here
        model.save_pretrained(tmpdirname, safe_serialization=True)
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            repo_name = f"{username}/{model_name.split('/')[-1]}-BNB-{quantization_type}"
        model_card = create_model_card(model_name, repo_name, quantization_type, threshold, quant_type_4, double_quant_4)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)
        # Push the weights and the model card to the Hub
        api = HfApi(token=auth_token.token)
        api.create_repo(repo_name, exist_ok=True)
        api.upload_folder(
            folder_path=tmpdirname,
            repo_id=repo_name,
            repo_type="model",
        )
    return f'<h1> 🤗 DONE</h1><br/>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank" style="text-decoration:underline">{repo_name}</a>'

def is_float(value):
    try:
        float(value)
        return True
    except ValueError:
        return False

def quantize_and_save(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, model_name, quantization_type, threshold, quant_type_4, double_quant_4, quantized_model_name):
    if oauth_token is None or profile is None:
        return "Error: Please sign in to your Hugging Face account to use the quantizer"
    exists_message = check_model_exists(oauth_token, profile.username, quantization_type, model_name, quantized_model_name)
    if exists_message:
        return exists_message
    if not is_float(threshold):
        return "Threshold must be a float"
    threshold = float(threshold)
    try:
        quantized_model = quantize_model(model_name, quantization_type, threshold, quant_type_4, double_quant_4, oauth_token, profile.username)
        return save_model(quantized_model, model_name, quantization_type, threshold, quant_type_4, double_quant_4, profile.username, oauth_token, quantized_model_name)
    except Exception as e:
        print(e)
        return f"An error occurred: {str(e)}"
css="""/* Custom CSS to allow scrolling */ | |
.gradio-container {overflow-y: auto;} | |
.custom-radio { | |
margin-left: 20px; /* Adjust the value as needed */ | |
} | |
""" | |
with gr.Blocks(theme=gr.themes.Ocean(), css=css) as demo:
    gr.Markdown(
        """
        # 🤗 LLM Model BitsAndBytes Quantization App

        Quantize your favorite Hugging Face models using BitsAndBytes and save them to your profile!
        """
    )

    gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)

    m1 = gr.Markdown()
    demo.load(hello, inputs=None, outputs=m1)

    instructions = gr.Markdown(
        """
        ## Instructions

        1. Log in to your Hugging Face account.
        2. Enter the name of the Hugging Face LLM model you want to quantize (make sure you have access to it).
        3. Choose the quantization type.
        4. Optionally, adjust the quantization parameters (outlier threshold for int8; quant type and double quantization for int4).
        5. Optionally, choose a custom name for the quantized model.
        6. Click "Quantize and Save Model" to start the process.
        7. Once complete, you'll receive a link to the quantized model on Hugging Face.

        Note: This process may take some time depending on the model size and your hardware. You can check the container logs to see how far along you are!
        """,
        visible=False,
    )

    instructions_visible = gr.State(False)
    toggle_button = gr.Button("▼ Show Instructions", elem_id="toggle-button", elem_classes="toggle-button")

    def toggle_instructions(instructions_visible):
        new_visibility = not instructions_visible  # Flip the stored state
        new_label = "▲ Hide Instructions" if new_visibility else "▼ Show Instructions"
        return gr.update(visible=new_visibility), new_visibility, gr.update(value=new_label)

    toggle_button.click(toggle_instructions, instructions_visible, [instructions, instructions_visible, toggle_button])
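
    # Note: the gr.State above persists the visibility flag between clicks; the
    # handler returns the updated flag as its second output so the stored state
    # stays in sync with what the UI shows.
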
    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_name = HuggingfaceHubSearch(
                    label="Hub Model ID",
                    placeholder="Search for model id on Huggingface",
                    search_type="model",
                )
            with gr.Row():
                with gr.Column():
                    quantization_type = gr.Dropdown(
                        info="Quantization Type",
                        choices=["int4", "int8"],
                        value="int8",
                        filterable=False,
                        show_label=False,
                    )
                    threshold_8 = gr.Textbox(
                        info="Outlier threshold",
                        value="6",
                        interactive=True,
                        show_label=False,
                        visible=True,
                    )
                    quant_type_4 = gr.Dropdown(
                        info="The quantization data type in the bnb.nn.Linear4Bit layers",
                        choices=["fp4", "nf4"],
                        value="fp4",
                        visible=False,
                        show_label=False,
                    )
                    radio_4 = gr.Radio(["False", "True"], info="Use Double Quant", visible=False, value="False", elem_classes="custom-radio")

                    def update_visibility(quantization_type):
                        # Show the int8 threshold only for int8, and the 4-bit options only for int4
                        return (
                            gr.update(visible=(quantization_type == "int8")),
                            gr.update(visible=(quantization_type == "int4")),
                            gr.update(visible=(quantization_type == "int4")),
                        )

                    quantization_type.change(fn=update_visibility, inputs=quantization_type, outputs=[threshold_8, quant_type_4, radio_4])

                    quantized_model_name = gr.Textbox(
                        info="Model Name (optional: overrides the default)",
                        value="",
                        interactive=True,
                        show_label=False,
                    )
        with gr.Column():
            quantize_button = gr.Button("Quantize and Save Model", variant="primary")
            output_link = gr.Markdown(label="Quantized Model Link", container=True, min_height=80)

    # The gr.OAuthProfile / gr.OAuthToken parameters of quantize_and_save are
    # injected automatically by Gradio, so they are not listed as inputs here.
    quantize_button.click(
        fn=quantize_and_save,
        inputs=[model_name, quantization_type, threshold_8, quant_type_4, radio_4, quantized_model_name],
        outputs=[output_link],
    )

if __name__ == "__main__":
    demo.launch(share=True)
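
# A minimal sketch of consuming a repo produced by this Space (hypothetical
# repo id; assumes bitsandbytes and a compatible device are available):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("your-username/opt-350m-BNB-int8")
#
# The quantization settings are serialized into the checkpoint's config.json,
# so no BitsAndBytesConfig needs to be passed at load time.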