# Hugging Face Space: runs on ZeroGPU ("Running on Zero") hardware.
import os

# Enable hf_transfer for faster downloads from the Hub; set before
# huggingface_hub / transformers are imported so they pick it up.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import json
import logging
from typing import Literal, Tuple

import gradio as gr
import spaces
import torch
from cachetools.func import ttl_cache
from huggingface_hub import DatasetCard, ModelCard, dataset_info, model_info
from transformers import AutoModelForCausalLM, AutoTokenizer
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
MODEL_NAME = "davanstrien/Smol-Hub-tldr"
model = None
tokenizer = None
device = None
CACHE_TTL = 6 * 60 * 60  # 6 hours in seconds
CACHE_MAXSIZE = 100

def load_model():
    global model, tokenizer, device
    logger.info("Loading model and tokenizer...")
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        model = model.to(device)
        model.eval()
        return True
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return False

def get_card_info(
    hub_id: str, repo_type: Literal["auto", "model", "dataset"] = "auto"
) -> Tuple[str, str]:
    """Get card information from a Hugging Face hub_id."""
    model_exists = False
    dataset_exists = False
    model_text = None
    dataset_text = None

    # Handle based on repo type
    if repo_type == "auto":
        # Try getting the model card
        try:
            model_info(hub_id)  # existence check
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.debug(f"No model card found for {hub_id}: {e}")
        # Try getting the dataset card
        try:
            dataset_info(hub_id)  # existence check
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.debug(f"No dataset card found for {hub_id}: {e}")
    elif repo_type == "model":
        try:
            model_info(hub_id)
            card = ModelCard.load(hub_id)
            model_exists = True
            model_text = card.text
        except Exception as e:
            logger.error(f"Failed to get model card for {hub_id}: {e}")
            raise ValueError(f"Could not find model with id {hub_id}")
    elif repo_type == "dataset":
        try:
            dataset_info(hub_id)
            card = DatasetCard.load(hub_id)
            dataset_exists = True
            dataset_text = card.text
        except Exception as e:
            logger.error(f"Failed to get dataset card for {hub_id}: {e}")
            raise ValueError(f"Could not find dataset with id {hub_id}")
    else:
        raise ValueError(
            f"Invalid repo_type: {repo_type}. Must be 'auto', 'model', or 'dataset'"
        )

    # Handle the different cases
    if model_exists and dataset_exists:
        return "both", (model_text, dataset_text)
    elif model_exists:
        return "model", model_text
    elif dataset_exists:
        return "dataset", dataset_text
    else:
        raise ValueError(f"Could not find model or dataset with id {hub_id}")

# Request a ZeroGPU slot for the duration of the call (the Space runs on ZeroGPU).
@spaces.GPU
def _generate_summary_gpu(card_text: str, card_type: str) -> str:
    """Internal function that runs on GPU."""
    # Determine prefix based on card type
    prefix = "<MODEL_CARD>" if card_type == "model" else "<DATASET_CARD>"

    # Format input according to the chat template
    messages = [{"role": "user", "content": f"{prefix}{card_text[:5000]}"}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate with optimized settings
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=60,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.4,
            do_sample=True,
            use_cache=True,
        )

    # Extract and clean up the summary
    input_length = inputs.shape[1]
    response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=False)

    # Keep just the text between the <CARD_SUMMARY> tags
    try:
        summary = (
            response.split("<CARD_SUMMARY>")[-1].split("</CARD_SUMMARY>")[0].strip()
        )
    except IndexError:
        summary = response.strip()

    return summary

@ttl_cache(maxsize=CACHE_MAXSIZE, ttl=CACHE_TTL)
def generate_summary(card_text: str, card_type: str) -> str:
    """Cached wrapper around _generate_summary_gpu with TTL."""
    return _generate_summary_gpu(card_text, card_type)

def summarize(hub_id: str = "", repo_type: str = "auto") -> str:
    """Interface function for Gradio. Returns JSON format."""
    try:
        if hub_id:
            # Fetch card information with the specified repo_type
            card_type, card_text = get_card_info(hub_id, repo_type)
            if card_type == "both":
                model_text, dataset_text = card_text
                model_summary = generate_summary(model_text, "model")
                dataset_summary = generate_summary(dataset_text, "dataset")
                return json.dumps(
                    {
                        "type": "both",
                        "hub_id": hub_id,
                        "model_summary": model_summary,
                        "dataset_summary": dataset_summary,
                    }
                )
            else:
                summary = generate_summary(card_text, card_type)
                return json.dumps(
                    {"summary": summary, "type": card_type, "hub_id": hub_id}
                )
        else:
            return json.dumps({"error": "Hub ID must be provided"})
    except Exception as e:
        return json.dumps({"error": str(e)})

def create_interface():
    interface = gr.Interface(
        fn=summarize,
        inputs=[
            gr.Textbox(label="Hub ID", placeholder="e.g., huggingface/llama-7b"),
            gr.Radio(
                choices=["auto", "model", "dataset"],
                value="auto",
                label="Repository Type",
                info="Choose 'auto' to detect automatically, or specify the repository type",
            ),
        ],
        outputs=gr.JSON(label="Output"),
        title="Hugging Face Hub TLDR Generator",
        description="Generate concise summaries of model and dataset cards from the Hugging Face Hub.",
    )
    return interface

if __name__ == "__main__":
    if load_model():
        interface = create_interface()
        interface.launch()
    else:
        print("Failed to load model. Please check the logs for details.")