import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants
DEFAULT_WIDTH = 384
DEFAULT_HEIGHT = 384
PARALLEL_SIZE = 5
PATCH_SIZE = 16
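# With 384x384 output and 16-pixel patches, the generator emits a 24x24 grid
# of image tokens: (384 // 16) ** 2 = 576. Deriving the count here replaces
# the magic number 576 otherwise used in the sampling loop below.
IMAGE_TOKENS = (DEFAULT_WIDTH // PATCH_SIZE) * (DEFAULT_HEIGHT // PATCH_SIZE)  # 576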
# Load model and processor with error handling
def load_model():
    try:
        model_path = "deepseek-ai/Janus-Pro-7B"
        config = AutoConfig.from_pretrained(model_path)
        language_config = config.language_config
        language_config._attn_implementation = 'eager'
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading model on device: {device}")
        vl_gpt = AutoModelForCausalLM.from_pretrained(
            model_path,
            language_config=language_config,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
        ).to(device)
        vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
        return vl_gpt, vl_chat_processor, device
    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        # Chain the original exception so the root cause stays in the traceback
        raise RuntimeError("Failed to load model. Please check the model path and dependencies.") from e
# Load once at module import; the bare try/except that merely re-raised the
# RuntimeError added nothing, so load_model() is called directly.
vl_gpt, vl_chat_processor, device = load_model()
tokenizer = vl_chat_processor.tokenizer
# Helper functions with improved memory management
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, progress=None):
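    # Classifier-free guidance (CFG) runs conditional and unconditional
    # sequences in a single batch of size parallel_size * 2: even rows keep
    # the full prompt, odd rows have everything except the first token and the
    # trailing image-start token masked to pad_id, serving as the
    # unconditional branch. width and height are accepted but unused here;
    # the token count is fixed by IMAGE_TOKENS.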
    try:
        torch.cuda.empty_cache()
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
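        # Autoregressive sampling: past_key_values caches attention state so
        # each step feeds only the embedding of the newly sampled token rather
        # than re-encoding the whole prefix.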
        with torch.no_grad():
            inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
            generated_tokens = torch.zeros((parallel_size, IMAGE_TOKENS), dtype=torch.int, device=device)
            pkv = None
            total_steps = IMAGE_TOKENS
            for i in range(total_steps):
                if progress is not None:
                    progress((i + 1) / total_steps, desc="Generating image tokens")
                outputs = vl_gpt.language_model.model(
                    inputs_embeds=inputs_embeds,
                    use_cache=True,
                    past_key_values=pkv
                )
                pkv = outputs.past_key_values
                hidden_states = outputs.last_hidden_state
                logits = vl_gpt.gen_head(hidden_states[:, -1, :])
                # CFG mix: logits = uncond + cfg_weight * (cond - uncond)
                logit_cond = logits[0::2, :]
                logit_uncond = logits[1::2, :]
                logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_tokens[:, i] = next_token.squeeze(dim=-1)
                # Duplicate each sampled token for the cond/uncond rows and feed
                # its image embedding as the sole input of the next step.
                next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
                img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
                inputs_embeds = img_embeds.unsqueeze(dim=1)
        return generated_tokens
    except RuntimeError as e:
        logger.error(f"Generation error: {str(e)}")
        # Only blame memory when the error actually is an OOM; re-raise the rest
        if "out of memory" in str(e).lower():
            raise RuntimeError("Generation ran out of GPU memory. Try reducing the parallel size.") from e
        raise
    finally:
        torch.cuda.empty_cache()
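# The decoder output is float imagery in [-1, 1] laid out as (N, C, H, W);
# unpack rescales it to [0, 255] uint8 and converts each array to a PIL image.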
def unpack(patches, width, height, parallel_size=5):
    # width, height and parallel_size are accepted for call-site symmetry but
    # are currently unused.
    try:
        patches = patches.detach().to(device='cpu', dtype=torch.float32).numpy()
        patches = patches.transpose(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        patches = np.clip((patches + 1) / 2 * 255, 0, 255)
        return [Image.fromarray(patch.astype(np.uint8)) for patch in patches]
    except Exception as e:
        logger.error(f"Unpacking error: {str(e)}")
        raise RuntimeError("Failed to process generated image data.") from e
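# Assumption: this Space targets ZeroGPU (the `spaces` import is otherwise
# unused). @spaces.GPU requests a GPU for the duration of each call; on
# ZeroGPU, no GPU is attached to undecorated functions.
@spaces.GPU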
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress()):
    try:
        if not prompt.strip():
            raise gr.Error("Please enter a valid prompt.")
        if progress is not None:
            progress(0, desc="Initializing...")
        torch.cuda.empty_cache()
        # Seed management: seed both the CPU and CUDA RNGs so runs with the
        # same seed are reproducible
        if seed is None:
            seed = torch.seed()
        else:
            seed = int(seed)
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed(seed)
        messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        ) + vl_chat_processor.image_start_tag
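        # Appending image_start_tag marks where the model should switch from
        # text generation to image-token generation.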
        input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)
        if progress is not None:
            progress(0.1, desc="Generating image tokens...")
        generated_tokens = generate(
            input_ids,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            cfg_weight=guidance,
            temperature=t2i_temperature,
            parallel_size=PARALLEL_SIZE,
            progress=progress
        )
        if progress is not None:
            progress(0.9, desc="Processing images...")
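        # The decode shape follows the Janus reference demo:
        # (batch, latent channels, height_in_patches, width_in_patches).
        # The 8 appears to be the VQ decoder's fixed latent channel count; it
        # is a model property, not derived from the constants above.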
        patches = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
        )
        images = unpack(patches, DEFAULT_WIDTH, DEFAULT_HEIGHT, PARALLEL_SIZE)
        return images
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}", exc_info=True)
        if "index out of range" in str(e).lower():
            raise gr.Error("Image generation failed due to an internal error. Please try again with different parameters.")
        else:
            raise gr.Error(f"Image generation failed: {str(e)}")
def create_interface():
    with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Text-to-Image Generation with Janus-Pro-7B
        **Generate high-quality images from text prompts using DeepSeek's advanced multimodal AI model.**
        """)
        with gr.Row():
            with gr.Column(scale=3):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3
                )
                generate_btn = gr.Button("Generate Images", variant="primary")
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        seed_input = gr.Number(
                            label="Seed",
                            value=None,
                            precision=0,
                            info="Leave empty for a random seed"
                        )
                        guidance_slider = gr.Slider(
                            label="CFG Guidance Weight",
                            minimum=3,
                            maximum=10,
                            value=5,
                            step=0.5,
                            info="Higher values = more prompt adherence; lower values = more creativity"
                        )
                        temp_slider = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                            info="Higher values = more randomness; lower values = more deterministic"
                        )
            with gr.Column(scale=2):
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=2,
                    height=600,
                    preview=True
                )
                status = gr.Textbox(
                    label="Status",
                    interactive=False
                )
        gr.Examples(
            examples=[
                ["A futuristic cityscape at sunset with flying cars and holographic advertisements"],
                ["An astronaut riding a horse in photorealistic style"],
                ["A cute robotic cat sitting on a stack of ancient books, digital art"]
            ],
            inputs=prompt_input
        )
        gr.Markdown("""
        ## Model Information
        - **Model:** [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B)
        - **Output Resolution:** 384x384 pixels
        - **Parallel Generation:** 5 images per request
        """)

        # Footer Section
        gr.Markdown("""
        <hr style="margin-top: 2em; margin-bottom: 1em;">
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #2563eb; text-decoration: none;">bilsimaging.com</a>
        </div>
        """)
        # Visitor Badge
        gr.HTML("""
        <div style="text-align: center; margin-top: 1em;">
            <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F">
                <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759"
                     alt="Visitor Badge"
                     style="display: inline-block; margin: 0 auto;">
            </a>
        </div>
        """)
        generate_btn.click(
            generate_image,
            inputs=[prompt_input, seed_input, guidance_slider, temp_slider],
            outputs=output_gallery,
            api_name="generate"
        )
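        # api_name="generate" also exposes the event over the Gradio API.
        # Hypothetical client sketch (Space id taken from the badge URL above):
        #   from gradio_client import Client
        #   client = Client("Bils/DeepseekJanusPro")
        #   images = client.predict("a red fox in the snow", None, 5, 1.0, api_name="/generate")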
        demo.load(
            fn=lambda: f"Device Status: {'GPU ✅' if device.type == 'cuda' else 'CPU ⚠️'}",
            outputs=status,
            queue=False
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)