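"""Gradio demo: text-to-image generation with deepseek-ai/Janus-Pro-7B.

Samples five 384x384 images per prompt by autoregressively generating image
tokens under classifier-free guidance, then decoding them with the model's
VQ vision decoder.
"""
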
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces  # Hugging Face Spaces ZeroGPU helper (provides @spaces.GPU)

# Load the model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True
)
# Use bfloat16 on GPU; fall back to float16 when no CUDA device is available
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)
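# Note: bfloat16 is fastest on Ampere-or-newer GPUs; the float16 CPU
# fallback will run, but very slowly for a 7B-parameter model.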
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Helper functions
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, patch_size=16):
    torch.cuda.empty_cache()
    # Duplicate every prompt row for classifier-free guidance: even rows keep
    # the full prompt (conditional), odd rows are padded out (unconditional).
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    # Autoregressively sample 576 image tokens per image (a 24x24 token grid
    # for a 384x384 output at patch_size 16).
    generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int).to(cuda_device)
    pkv = None
    for i in range(576):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv
            )
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: push the conditional logits away from
            # the unconditional ones by cfg_weight.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed the sampled token back in for both the conditional and
            # unconditional rows.
            next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
    # Decode the sampled token grid into image patches with the VQ decoder.
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size]
    )
    return patches


def unpack(patches, width, height, parallel_size=5):
    # Map decoder output from [-1, 1] to [0, 255] and convert NCHW -> NHWC.
    patches = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    patches = np.clip((patches + 1) / 2 * 255, 0, 255)
    images = [Image.fromarray(patches[i].astype(np.uint8)) for i in range(parallel_size)]
    return images


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    torch.cuda.empty_cache()
    # Seed all RNGs so a fixed seed reproduces the same images.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    width, height, parallel_size = 384, 384, 5
    # Build the chat-formatted prompt, then append the image-start tag so the
    # model begins emitting image tokens.
    messages = [
        {'role': '<|User|>', 'content': prompt},
        {'role': '<|Assistant|>', 'content': ''}
    ]
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages, sft_format=vl_chat_processor.sft_format, system_prompt=''
    )
    text += vl_chat_processor.image_start_tag
    input_ids = torch.LongTensor(tokenizer.encode(text))
    patches = generate(input_ids, width, height, cfg_weight=guidance,
                       temperature=t2i_temperature, parallel_size=parallel_size)
    return unpack(patches, width, height, parallel_size)
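
# Example (hypothetical, for quick local testing without the UI;
# @spaces.GPU is a no-op outside Hugging Face Spaces):
#   images = generate_image("a red bicycle leaning on a brick wall", seed=42)
#   images[0].save("sample.png")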

# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generation")
        prompt_input = gr.Textbox(label="Prompt (describe the image)")
        seed_input = gr.Number(label="Seed (Optional)", value=12345, precision=0)
        guidance_slider = gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5)
        temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, value=1.0, step=0.05)
        generate_button = gr.Button("Generate Images")
        output_gallery = gr.Gallery(label="Generated Images", columns=2, height=300)
        generate_button.click(
            generate_image,
            inputs=[prompt_input, seed_input, guidance_slider, temperature_slider],
            outputs=output_gallery
        )
    return demo


demo = create_interface()
demo.launch(share=True)