Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import gc | |
from diffusers import StableDiffusionInpaintPipeline | |
from PIL import Image | |
import numpy as np | |
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
import cv2 | |
import traceback | |
class RafayyVirtualTryOn:
    """Virtual try-on: segment the clothing in a photo and inpaint new clothing.

    A SegFormer clothes-parsing model produces a mask of the garment pixels and a
    Stable Diffusion inpainting pipeline repaints the masked region from a text
    description.
    """

    # Garment classes in the "mattmdjaga/segformer_b2_clothes" label map:
    # 4 = Upper-clothes, 5 = Skirt, 6 = Pants, 7 = Dress.
    CLOTHING_LABEL_IDS = (4, 5, 6, 7)

    def __init__(self):
        """Load the inpainting pipeline and the segmentation model.

        Both models are placed on the GPU (in float16) when CUDA is available,
        otherwise they stay on the CPU in float32.

        Raises:
            Exception: Any error raised while loading the models; it is printed
                and re-raised unchanged.
        """
        try:
            use_cuda = torch.cuda.is_available()
            dtype = torch.float16 if use_cuda else torch.float32

            # Start from a clean slate so model loading has as much memory as possible.
            if use_cuda:
                torch.cuda.empty_cache()
            gc.collect()

            # Use smaller model for stability
            # NOTE(review): this repo id names the base text-to-image checkpoint,
            # not an inpainting-specific one — confirm this is intended.
            self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=dtype,
                safety_checker=None,  # Disable safety checker if causing issues
            )
            if use_cuda:
                self.inpaint_model.to("cuda")
            self.inpaint_model.enable_attention_slicing()  # Reduce memory usage

            # Initialize segmentation with error handling
            try:
                self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes",
                    torch_dtype=dtype,
                )
                self.processor = SegformerImageProcessor.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes"
                )
                # Place the model once here instead of on every request.
                if use_cuda:
                    self.segmenter = self.segmenter.to("cuda")
                self.segmenter.eval()
            except Exception as e:
                print(f"Segmentation model loading error: {str(e)}")
                raise
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise

    def preprocess_image(self, image, max_size=768):
        """Convert the input to an RGB PIL image no larger than ``max_size``.

        Args:
            image: A ``numpy`` array or a PIL image.
            max_size: Upper bound, in pixels, for the longest side. Larger
                images are downscaled with the aspect ratio preserved.

        Returns:
            The RGB PIL image, resized if it exceeded ``max_size``.

        Raises:
            gr.Error: If the image cannot be converted or resized.
        """
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Resize if too large
            longest_side = max(image.size)
            if longest_side > max_size:
                ratio = max_size / longest_side
                # Clamp to 1 px so an extreme aspect ratio cannot produce a
                # zero-sized dimension, which PIL rejects.
                new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
                image = image.resize(new_size, Image.LANCZOS)

            return image
        except Exception as e:
            raise gr.Error(f"Image preprocessing failed: {str(e)}") from e

    def get_clothing_mask(self, image):
        """Build an inpainting mask covering the clothing in ``image``.

        Args:
            image: A ``numpy`` array or a PIL image.

        Returns:
            A single-channel PIL image with the same width and height as
            ``image``; clothing pixels are bright (255 before the edge blur)
            and everything else is 0.

        Raises:
            gr.Error: If segmentation or mask post-processing fails.
        """
        try:
            # Convert to RGB if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != "RGB":
                image = image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")

            # Match the model's device AND dtype: the processor emits float32,
            # which a float16 model rejects.
            reference = next(self.segmenter.parameters())
            inputs = {
                name: tensor.to(
                    device=reference.device,
                    dtype=reference.dtype if tensor.is_floating_point() else tensor.dtype,
                )
                for name, tensor in inputs.items()
            }

            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                logits = self.segmenter(**inputs).logits
                # The logits come out at a lower resolution than the input, so
                # upsample them to the image size before picking labels.
                logits = torch.nn.functional.interpolate(
                    logits.float().cpu(),
                    size=image.size[::-1],  # PIL gives (width, height); torch wants (height, width)
                    mode="bilinear",
                    align_corners=False,
                )
                labels = logits.argmax(dim=1)[0].numpy()

            is_clothing = np.isin(labels, self.CLOTHING_LABEL_IDS)
            clothing_mask = (is_clothing * 255).astype(np.uint8)

            # Enhance mask: grow it slightly and soften the edge so the
            # inpainted region blends into its surroundings.
            kernel = np.ones((5, 5), np.uint8)
            clothing_mask = cv2.dilate(clothing_mask, kernel, iterations=1)
            clothing_mask = cv2.GaussianBlur(clothing_mask, (5, 5), 0)

            return Image.fromarray(clothing_mask)
        except Exception as e:
            raise gr.Error(f"Mask generation failed: {str(e)}") from e

    def try_on(self, image, prompt, style_strength=0.7, progress=gr.Progress()):
        """Replace the clothing in ``image`` with clothing described by ``prompt``.

        Args:
            image: The uploaded photo (``numpy`` array or PIL image).
            prompt: Text description of the new clothing.
            style_strength: Multiplier applied to the base guidance scale of 7.5.
            progress: Gradio progress tracker used to report each stage.

        Returns:
            The first image produced by the inpainting pipeline.

        Raises:
            gr.Error: On missing input, out-of-memory, or any other failure.
        """
        try:
            if image is None:
                raise gr.Error("Please upload an image first")
            if not prompt or not prompt.strip():
                raise gr.Error("Please provide a clothing description")

            progress(0.1, desc="Preprocessing image...")
            original_image = self.preprocess_image(image)

            progress(0.3, desc="Detecting clothing...")
            mask = self.get_clothing_mask(original_image)

            # Clear GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            progress(0.5, desc="Preparing generation...")
            # Enhanced prompt engineering
            full_prompt = f"A person wearing {prompt}, professional photo, detailed, realistic, high quality"
            negative_prompt = "low quality, blurry, distorted, deformed, bad anatomy, unrealistic"

            progress(0.7, desc="Generating new clothing...")
            try:
                result = self.inpaint_model(
                    prompt=full_prompt,
                    negative_prompt=negative_prompt,
                    image=original_image,
                    mask_image=mask,
                    num_inference_steps=30,  # Reduced for stability
                    guidance_scale=7.5 * style_strength,
                ).images[0]
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                gc.collect()
                raise gr.Error("Out of memory. Please try with a smaller image.")

            progress(1.0, desc="Done!")
            return result
        except gr.Error:
            # Already a user-facing error: pass it through untouched instead of
            # logging a traceback and wrapping it a second time.
            raise
        except Exception as e:
            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)  # For logging
            raise gr.Error(str(e)) from e
# Build the shared pipeline once at import time. A failure is reported on
# stdout and then re-raised so the app does not start without a model.
try:
    model = RafayyVirtualTryOn()
except Exception as init_error:
    print(f"Model initialization failed: {init_error}")
    raise
# Gradio UI: photo + clothing description + strength slider in, generated image out.
demo = gr.Interface(
    fn=model.try_on,
    inputs=[
        gr.Image(label="📸 Upload Your Photo", type="numpy"),
        gr.Textbox(
            label="🎨 Describe New Clothing",
            placeholder="e.g., 'elegant black suit', 'red dress'",
            lines=2,
        ),
        gr.Slider(
            label="Style Strength",
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
        ),
    ],
    outputs=gr.Image(label="✨ Result", type="pil"),
    # NOTE(review): the characters at both ends of the title look like a
    # mis-encoded emoji; the original glyph cannot be recovered from this file,
    # so the string is left as-is — restore the intended emoji.
    title="π Rafayy's Virtual Try-On Studio π",
    description="""
    <div style="text-align: center;">
    <h3>Transform Your Style with AI</h3>
    <p>Upload a photo and describe the new clothing you want to try on!</p>
    </div>
    """,
    # NOTE(review): with cache_examples=True the examples are presumably run
    # through the model when the interface is built — confirm example1.jpg and
    # example2.jpg ship alongside this script.
    examples=[
        ["example1.jpg", "black suit", 0.7],
        ["example2.jpg", "white dress", 0.7],
    ],
    allow_flagging="never",
    cache_examples=True,
)
# Start the web server when run as a script; a launch failure is printed and re-raised.
if __name__ == "__main__":
    try:
        # Request queueing is switched on via queue(); launch() no longer
        # accepts an `enable_queue` argument in current Gradio releases.
        demo.queue()
        demo.launch(
            share=False,
            server_name="0.0.0.0",  # listen on all interfaces
            server_port=7860,
            show_error=True,
        )
    except Exception as e:
        print(f"Launch failed: {str(e)}")
        raise