# Realtime-FLUX / app.py
import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
import os
from diffusers import DiffusionPipeline
from custom_pipeline import FLUXPipelineWithIntermediateOutputs
from transformers import pipeline
# Translation model loading
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_INFERENCE_STEPS = 1
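# FLUX.1-schnell is timestep-distilled, so 1-4 inference steps are sufficient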
# Device and model setup
dtype = torch.float16
pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to("cuda")
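# Free CUDA memory cached by PyTorch during model loading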
torch.cuda.empty_cache()
# Menu labels dictionary
english_labels = {
    "Generated Image": "Generated Image",
    "Prompt": "Prompt",
    "Enhance Image": "Enhance Image",
    "Advanced Options": "Advanced Options",
    "Seed": "Seed",
    "Randomize Seed": "Randomize Seed",
    "Width": "Width",
    "Height": "Height",
    "Inference Steps": "Inference Steps",
    "Inspiration Gallery": "Inspiration Gallery"
}
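
# Returns an English translation when the prompt contains Hangul jamo (U+3131-U+3163)
# or syllables (U+AC00-U+D7A3); otherwise the prompt is passed through unchanged.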
def translate_if_korean(text):
    if any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text):
        return translator(text)[0]['translation_text']
    return text
# Modified inference function to always use random seed for examples
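# @spaces.GPU reserves a ZeroGPU device for up to 25 seconds per call on Hugging Face Spaces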
@spaces.GPU(duration=25)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
    prompt = translate_if_korean(prompt)

    # Always generate a random seed if none provided or randomize_seed is True
    if seed is None or randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)
    start_time = time.time()
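
    # Stream each intermediate image from the custom pipeline so the UI updates as denoising progresses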
    for img in pipe.generate_images(
        prompt=prompt,
        guidance_scale=0,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ):
        latency = f"Processing Time: {(time.time()-start_time):.2f} seconds"
        yield img, seed, latency
# Function specifically for examples that always uses random seeds
def generate_example_image(prompt):
    # Re-yield so Gradio streams results; yield only image and seed to match the gr.Examples outputs
    for img, used_seed, _ in generate_image(prompt, randomize_seed=True):
        yield img, used_seed
# Example prompts
examples = [
    "비너 슈니첼의 애니메이션 일러스트레이션",  # Korean: "An animated illustration of a Wiener schnitzel" (exercises the translation path)
    "A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
    "A floating island made of books with waterfalls of knowledge cascading down",
    "A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
    "An ancient temple being reclaimed by nature, with robots performing archaeology",
    "A cosmic coffee shop where baristas are constellations serving drinks made of stardust"
]
css = """
footer {
visibility: hidden;
}
"""
# --- Gradio UI ---
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(scale=3):
                result = gr.Image(label=english_labels["Generated Image"], show_label=False, interactive=False)
            with gr.Column(scale=1):
                prompt = gr.Text(
                    label=english_labels["Prompt"],
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                enhanceBtn = gr.Button(f"🚀 {english_labels['Enhance Image']}")
                # gr.Column takes no label; use an Accordion so "Advanced Options" is actually displayed
                with gr.Accordion(english_labels["Advanced Options"]):
                    with gr.Row():
                        latency = gr.Text(show_label=False)
                    with gr.Row():
                        seed = gr.Number(label=english_labels["Seed"], value=42, precision=0)
                        randomize_seed = gr.Checkbox(label=english_labels["Randomize Seed"], value=True)
                    with gr.Row():
                        width = gr.Slider(label=english_labels["Width"], minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
                        height = gr.Slider(label=english_labels["Height"], minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
                        num_inference_steps = gr.Slider(label=english_labels["Inference Steps"], minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)
        with gr.Row():
            gr.Markdown(f"### 🌟 {english_labels['Inspiration Gallery']}")
        with gr.Row():
            gr.Examples(
                examples=examples,
                fn=generate_example_image,  # Use the example-specific function
                inputs=[prompt],
                outputs=[result, seed],
                cache_examples=False  # Disable caching to ensure new generation each time
            )
    # Event handling
    enhanceBtn.click(
        fn=generate_image,
        inputs=[prompt, seed, width, height],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        queue=False
    )
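
    # Regenerate whenever the prompt or any slider changes; "always_last" runs only the
    # most recent pending change once the in-flight generation finishes.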
    gr.on(
        triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
        fn=generate_image,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
        queue=False
    )
demo.launch()