import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_WIDTH = 384
DEFAULT_HEIGHT = 384
PARALLEL_SIZE = 5
PATCH_SIZE = 16

# Load model and processor with error handling
def load_model():
    try:
        model_path = "deepseek-ai/Janus-Pro-7B"
        config = AutoConfig.from_pretrained(model_path)
        language_config = config.language_config
        language_config._attn_implementation = 'eager'
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading model on device: {device}")
        vl_gpt = AutoModelForCausalLM.from_pretrained(
            model_path,
            language_config=language_config,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
        ).to(device)
        vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
        return vl_gpt, vl_chat_processor, device
    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise RuntimeError("Failed to load model. Please check the model path and dependencies.")

# Load once at import time; a failure here should stop the app from starting.
vl_gpt, vl_chat_processor, device = load_model()
tokenizer = vl_chat_processor.tokenizer

# Helper functions with improved memory management
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, progress=None):
    try:
        torch.cuda.empty_cache()
        # Build an interleaved batch for classifier-free guidance: even rows keep
        # the full prompt (conditional), odd rows have their prompt tokens padded out
        # (unconditional).
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
        with torch.no_grad():
            inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
            generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int, device=device)
            pkv = None
            total_steps = 576  # (DEFAULT_WIDTH // PATCH_SIZE) * (DEFAULT_HEIGHT // PATCH_SIZE) image tokens
            for i in range(total_steps):
                if progress is not None:
                    progress((i + 1) / total_steps, desc="Generating image tokens")
                outputs = vl_gpt.language_model.model(
                    inputs_embeds=inputs_embeds,
                    use_cache=True,
                    past_key_values=pkv
                )
                pkv = outputs.past_key_values
                hidden_states = outputs.last_hidden_state
                logits = vl_gpt.gen_head(hidden_states[:, -1, :])
                logit_cond = logits[0::2, :]
                logit_uncond = logits[1::2, :]
                logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_tokens[:, i] = next_token.squeeze(dim=-1)
                # Duplicate each sampled token so both the conditional and
                # unconditional streams advance with the same token.
                next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
                img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
                inputs_embeds = img_embeds.unsqueeze(dim=1)
        return generated_tokens
    except RuntimeError as e:
        logger.error(f"Generation error: {str(e)}")
        raise RuntimeError(
            "Generation failed due to memory constraints. Try reducing the parallel size."
        ) from e
    finally:
        torch.cuda.empty_cache()
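# Note on the sampling loop above: classifier-free guidance (CFG) runs each of the
# `parallel_size` images as a conditional/unconditional pair in one batch, then mixes
# the two logit streams before sampling. A minimal standalone sketch of that mixing
# step (illustrative only; the helper name and shapes are assumptions, not part of
# the Janus API):
def _cfg_mix_sketch(logits: torch.Tensor, cfg_weight: float) -> torch.Tensor:
    # logits: (2 * parallel_size, vocab) with conditional rows at even indices
    logit_cond = logits[0::2, :]    # rows that saw the full prompt
    logit_uncond = logits[1::2, :]  # rows whose prompt tokens were padded out
    # cfg_weight == 1 recovers the conditional logits; larger values push samples
    # harder toward the prompt at the cost of diversity.
    return logit_uncond + cfg_weight * (logit_cond - logit_uncond)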
def unpack(patches, width, height, parallel_size=5):
    try:
        # Move decoded patches to CPU and rescale from [-1, 1] to [0, 255].
        patches = patches.detach().to(device='cpu', dtype=torch.float32).numpy()
        patches = patches.transpose(0, 2, 3, 1)
        patches = np.clip((patches + 1) / 2 * 255, 0, 255)
        return [Image.fromarray(patch.astype(np.uint8)) for patch in patches]
    except Exception as e:
        logger.error(f"Unpacking error: {str(e)}")
        raise RuntimeError("Failed to process generated image data.")

@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress()):
    try:
        if not prompt.strip():
            raise gr.Error("Please enter a valid prompt.")
        if progress is not None:
            progress(0, desc="Initializing...")
        torch.cuda.empty_cache()

        # Seed management
        if seed is None:
            seed = torch.seed()
        else:
            seed = int(seed)
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed(seed)

        messages = [
            {'role': '<|User|>', 'content': prompt},
            {'role': '<|Assistant|>', 'content': ''}
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        ) + vl_chat_processor.image_start_tag
        input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)

        if progress is not None:
            progress(0.1, desc="Generating image tokens...")
        generated_tokens = generate(
            input_ids,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            cfg_weight=guidance,
            temperature=t2i_temperature,
            parallel_size=PARALLEL_SIZE,
            progress=progress
        )

        if progress is not None:
            progress(0.9, desc="Processing images...")
        # Decode the discrete image tokens back to pixel space with the VQ decoder.
        patches = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
        )
        images = unpack(patches, DEFAULT_WIDTH, DEFAULT_HEIGHT, PARALLEL_SIZE)
        return images
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}", exc_info=True)
        if "index out of range" in str(e).lower():
            raise gr.Error(
                "Image generation failed due to an internal error. Please try again with different parameters."
            )
        else:
            raise gr.Error(f"Image generation failed: {str(e)}")
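# Optional local smoke test (a sketch, not part of the Space UI). The prompt,
# seed, and output filenames are illustrative; this assumes the model loaded above
# and that the `spaces.GPU` decorator passes through when run outside HF Spaces.
def _smoke_test(prompt="a red bicycle leaning against a brick wall", seed=42):
    images = generate_image(prompt, seed=seed, guidance=5, t2i_temperature=1.0, progress=None)
    for i, img in enumerate(images):
        img.save(f"sample_{i}.png")  # writes PARALLEL_SIZE images to the working directory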
def create_interface():
    with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Text-to-Image Generation with Janus-Pro-7B
        **Generate high-quality images from text prompts using DeepSeek's advanced multimodal AI model.**
        """)

        with gr.Row():
            with gr.Column(scale=3):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3
                )
                generate_btn = gr.Button("Generate Images", variant="primary")

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        seed_input = gr.Number(
                            label="Seed",
                            value=None,
                            precision=0,
                            info="Leave empty for random seed"
                        )
                        guidance_slider = gr.Slider(
                            label="CFG Guidance Weight",
                            minimum=3,
                            maximum=10,
                            value=5,
                            step=0.5,
                            info="Higher values = more prompt adherence, lower values = more creativity"
                        )
                        temp_slider = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                            info="Higher values = more randomness, lower values = more deterministic"
                        )

            with gr.Column(scale=2):
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=2,
                    height=600,
                    preview=True
                )
                status = gr.Textbox(
                    label="Status",
                    interactive=False
                )

        gr.Examples(
            examples=[
                ["A futuristic cityscape at sunset with flying cars and holographic advertisements"],
                ["An astronaut riding a horse in photorealistic style"],
                ["A cute robotic cat sitting on a stack of ancient books, digital art"]
            ],
            inputs=prompt_input
        )

        gr.Markdown("""
        ## Model Information
        - **Model:** [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B)
        - **Output Resolution:** 384x384 pixels
        - **Parallel Generation:** 5 images per request
        """)

        # Footer Section
        gr.Markdown("""