# Author: DemahAlmutairi
# Commit message: Update app.py
# Commit: aa6c44a (verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import spaces
import os
import gc
import torch
# Make sure Gradio's example-cache directory exists before the UI is built.
# NOTE(review): the hard-coded "17" looks like an internal Gradio id for this
# app's cached examples; presumably a workaround for a cache-path error — confirm.
_EXAMPLE_CACHE_DIR = '.gradio/cached_examples/17'
os.makedirs(_EXAMPLE_CACHE_DIR, exist_ok=True)
def get_model_name(language):
    """Return the Hugging Face model id to use for *language*.

    "English" selects Phi-3-mini; "Arabic" and any value not listed fall
    back to the ALLaM Arabic instruct model.
    """
    fallback = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    models_by_language = dict(
        English="microsoft/Phi-3-mini-4k-instruct",
        Arabic=fallback,
    )
    # Unknown languages default to the Arabic model.
    return models_by_language.get(language, fallback)
def load_model(model_name, *, max_new_tokens=500, do_sample=True, top_k=50, top_p=0.95):
    """Build a text-generation pipeline for *model_name*.

    The model is placed on CUDA when available, otherwise on CPU, and its
    dtype is picked automatically (``torch_dtype="auto"``).

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
        max_new_tokens: Upper bound on tokens generated per call.
        do_sample: Sample instead of greedy decoding (more varied stories).
        top_k: When sampling, keep only the ``top_k`` most likely tokens.
        top_p: Nucleus-sampling cumulative probability cutoff.

    Returns:
        A ``transformers`` text-generation pipeline that returns only the
        newly generated text (``return_full_text=False``).
    """
    # NOTE(review): weights are downloaded/loaded on every call; consider
    # caching if per-request latency matters.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device,
        torch_dtype="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # No explicit `del model` / `del tokenizer`: the pipeline is handed both
    # objects and keeps references to them, so deleting the local names frees
    # nothing. Memory is released once the returned pipeline is dropped.
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
    )
@spaces.GPU
def generate_kids_story(character, setting, language):
    """Generate a short children's story using the model for *language*.

    Args:
        character: Name of the story's main character.
        setting: Where the adventure takes place.
        language: "English" uses the English prompt; any other value uses the
            Arabic prompt (and, via ``get_model_name``, the Arabic model).

    Returns:
        The generated story text (``output[0]["generated_text"]``).
    """
    if language == "English":
        prompt = (f"Write a short story for kids about a character named {character} who goes on an adventure in {setting}. "
                  "Make it fun, engaging, and suitable for children.")
    else:
        prompt = (f"اكتب قصة قصيرة للأطفال عن شخصية اسمها {character} التي تذهب في مغامرة في {setting}. "
                  "اجعلها ممتعة وجذابة ومناسبة للأطفال.")
    messages = [{"role": "user", "content": prompt}]

    generator = load_model(get_model_name(language))
    try:
        output = generator(messages)
    finally:
        # Release the freshly loaded model even if generation raises, so a
        # failed request does not leave GPU memory allocated for the next one.
        del generator
        gc.collect()
        torch.cuda.empty_cache()
    return output[0]["generated_text"]
# Custom CSS passed to gr.Interface(css=...): hot-links a remote background
# image (ArtStation URL), scales it to cover the page body, centers it, and
# sets white Arial text as the page default.
css_style = """
body {
background-image: url('https://cdna.artstation.com/p/assets/images/images/074/776/904/large/pietro-chiovaro-r1-castle-chp.jpg?1712916847');
background-size: cover;
background-position: center;
color: #fff; /* General text color */
font-family: 'Arial', sans-serif;
}"""
# Assemble the Gradio UI: character, setting and language inputs feed
# generate_kids_story, whose story is shown in a single text box.
# Components are created in the same order as the inputs list they form.
_character_input = gr.Textbox(placeholder="Enter a character name (e.g., Benny the Bunny)...", label="Character Name")
_setting_input = gr.Textbox(placeholder="Enter a setting (e.g., a magical forest)...", label="Setting")
_language_input = gr.Dropdown(
    choices=["English", "Arabic"],
    label="Choose Language",
    value="English",  # English is preselected
)
# One row per example: [character, setting, language].
_story_examples = [
    ["Benny the Bunny", "a magical forest", "English"],
    ["علي البطل", "غابة سحرية", "Arabic"],
    ["Lila the Ladybug", "a garden full of flowers", "English"],
    ["ليلى الجنية", "حديقة مليئة بالأزهار", "Arabic"],
]
demo = gr.Interface(
    fn=generate_kids_story,
    inputs=[_character_input, _setting_input, _language_input],
    outputs=gr.Textbox(label="Kids' Story"),
    title="📖 AI Kids' Story Generator - English & Arabic 📖",
    description="Enter a character name and a setting, and AI will generate a fun short story for kids in English or Arabic.",
    examples=_story_examples,
    css=css_style,
)
# Launch the Gradio app only when this file is run as a script
# (e.g. `python app.py`); importing the module just defines `demo`.
if __name__ == "__main__":
    demo.launch()