import gradio as gr
import torch
import gc
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import cv2
import traceback


class RafayyVirtualTryOn:
    def __init__(self):
        try:
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            # Use smaller model for stability
            self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None  # Disable safety checker if causing issues
            )
            if torch.cuda.is_available():
                self.inpaint_model.to("cuda")
                self.inpaint_model.enable_attention_slicing()  # Reduce memory usage

            # Initialize segmentation with error handling
            try:
                self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes",
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
                )
                self.processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
            except Exception as e:
                print(f"Segmentation model loading error: {str(e)}")
                raise
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise

    def preprocess_image(self, image):
        """Safely preprocess input image"""
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Resize if too large
            max_size = 768
            if max(image.size) > max_size:
                ratio = max_size / max(image.size)
                new_size = tuple(int(dim * ratio) for dim in image.size)
                image = image.resize(new_size, Image.LANCZOS)

            return image
        except Exception as e:
            raise gr.Error(f"Image preprocessing failed: {str(e)}")

    def get_clothing_mask(self, image):
        """Safely extract clothing mask"""
        try:
            # Convert to RGB if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != "RGB":
                image = image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")

            # Move to GPU if available, matching the segmenter's dtype (it may be fp16)
            if torch.cuda.is_available():
                self.segmenter = self.segmenter.to("cuda")
                inputs = {k: v.to("cuda", dtype=self.segmenter.dtype) for k, v in inputs.items()}

            outputs = self.segmenter(**inputs)
            logits = outputs.logits.squeeze()

            # Move back to CPU for numpy operations
            if torch.cuda.is_available():
                logits = logits.cpu()

            # Garment classes in this model's label map:
            # 4 = Upper-clothes, 5 = Skirt, 6 = Pants, 7 = Dress
            pred = logits.argmax(0).numpy()
            clothing_mask = np.isin(pred, [4, 5, 6, 7]).astype(np.float32)
            clothing_mask = (clothing_mask * 255).astype(np.uint8)

            # Enhance mask
            kernel = np.ones((5, 5), np.uint8)
            clothing_mask = cv2.dilate(clothing_mask, kernel, iterations=1)
            clothing_mask = cv2.GaussianBlur(clothing_mask, (5, 5), 0)

            return Image.fromarray(clothing_mask)
        except Exception as e:
            raise gr.Error(f"Mask generation failed: {str(e)}")

    def try_on(self, image, prompt, style_strength=0.7, progress=gr.Progress()):
        """Main try-on function with comprehensive error handling"""
        try:
            if image is None:
                raise gr.Error("Please upload an image first")
            if not prompt or prompt.strip() == "":
                raise gr.Error("Please provide a clothing description")

            progress(0.1, desc="Preprocessing image...")
            original_image = self.preprocess_image(image)

            progress(0.3, desc="Detecting clothing...")
            mask = self.get_clothing_mask(original_image)

            # Clear GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

            progress(0.5, desc="Preparing generation...")
            # Enhanced prompt engineering
            full_prompt = f"A person wearing {prompt}, professional photo, detailed, realistic, high quality"
            negative_prompt = "low quality, blurry, distorted, deformed, bad anatomy, unrealistic"
            progress(0.7, desc="Generating new clothing...")
            try:
                result = self.inpaint_model(
                    prompt=full_prompt,
                    negative_prompt=negative_prompt,
                    image=original_image,
                    mask_image=mask,
                    num_inference_steps=30,  # Reduced for stability
                    guidance_scale=7.5 * style_strength
                ).images[0]
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                gc.collect()
                raise gr.Error("Out of memory. Please try with a smaller image.")

            progress(1.0, desc="Done!")
            return result

        except Exception as e:
            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)  # For logging
            raise gr.Error(str(e))


# Initialize model with error handling
try:
    model = RafayyVirtualTryOn()
except Exception as e:
    print(f"Model initialization failed: {str(e)}")
    raise

# Create Gradio interface with error handling
demo = gr.Interface(
    fn=model.try_on,
    inputs=[
        gr.Image(label="📸 Upload Your Photo", type="numpy"),
        gr.Textbox(
            label="🎨 Describe New Clothing",
            placeholder="e.g., 'elegant black suit', 'red dress'",
            lines=2
        ),
        gr.Slider(
            label="Style Strength",
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1
        )
    ],
    outputs=gr.Image(label="✨ Result", type="pil"),
    title="🌟 Rafayy's Virtual Try-On Studio 🌟",
    description="""
Upload a photo and describe the new clothing you want to try on!