import gc
import os
import traceback

# Set environment variables to prevent warnings and errors.
# These are assigned BEFORE the torch / diffusers / transformers imports below:
# huggingface_hub resolves HF_HOME into its cache paths when it is first
# imported, so setting it after those imports has no effect on them.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['TORCH_HOME'] = '/tmp/torch'
os.environ['HF_HOME'] = '/tmp/huggingface'

import gradio as gr  # noqa: E402
import torch  # noqa: E402
from diffusers import StableDiffusionInpaintPipeline  # noqa: E402
from PIL import Image  # noqa: E402
import numpy as np  # noqa: E402
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation  # noqa: E402
import cv2  # noqa: E402


class RafayyVirtualTryOn:
    """Hold the models used by the virtual try-on demo.

    Attributes set by ``__init__``:
        inpaint_model: a ``StableDiffusionInpaintPipeline``.
        segmenter: a ``SegformerForSemanticSegmentation`` clothes segmenter.
        processor: the matching ``SegformerImageProcessor``.
    """

    # Hugging Face Hub model ids and the local download cache they share.
    # NOTE(review): the runwayml/* repos were reportedly removed from the Hub;
    # if downloads fail, "stable-diffusion-v1-5/stable-diffusion-inpainting" is
    # the commonly used mirror — confirm before switching.
    INPAINT_MODEL_ID = "runwayml/stable-diffusion-inpainting"
    SEGMENTATION_MODEL_ID = "mattmdjaga/segformer_b2_clothes"
    CACHE_DIR = '/tmp/models'

    def __init__(self):
        """Load the inpainting pipeline and the clothes-segmentation model.

        Uses float16 weights when CUDA is available (float32 otherwise). Only
        the inpainting pipeline is moved to CUDA here.

        Raises:
            Exception: any error raised while loading a model is printed and
                re-raised unchanged.
        """
        try:
            use_cuda = torch.cuda.is_available()
            dtype = torch.float16 if use_cuda else torch.float32

            # Release memory possibly left over from an earlier (failed) load.
            if use_cuda:
                torch.cuda.empty_cache()
            gc.collect()

            self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
                self.INPAINT_MODEL_ID,
                torch_dtype=dtype,
                safety_checker=None,
                # Disabling the checker without this flag makes diffusers log
                # a warning on every load.
                requires_safety_checker=False,
                cache_dir=self.CACHE_DIR,
            )
            if use_cuda:
                self.inpaint_model.to("cuda")
                # Trades some speed for lower peak GPU memory.
                self.inpaint_model.enable_attention_slicing()

            # Initialize segmentation with its own error message.
            # NOTE(review): on CUDA machines the segmenter gets float16 weights
            # but is never moved to CUDA in this method, so it stays on CPU in
            # half precision; confirm try_on moves/casts its inputs to match
            # (processor outputs are typically float32).
            try:
                self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
                    self.SEGMENTATION_MODEL_ID,
                    torch_dtype=dtype,
                    cache_dir=self.CACHE_DIR,
                )
                self.processor = SegformerImageProcessor.from_pretrained(
                    self.SEGMENTATION_MODEL_ID,
                    cache_dir=self.CACHE_DIR,
                )
            except Exception as e:
                print(f"Segmentation model loading error: {str(e)}")
                raise
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise

    # ...
    # NOTE(review): the remaining methods were elided from this paste —
    # including ``try_on``, which the Gradio interface passes as ``fn`` —
    # and must be restored for the app to run.
(rest of the code remains the same) # Initialize model with error handling and retry mechanism def initialize_model(max_retries=3): for attempt in range(max_retries): try: return RafayyVirtualTryOn() except Exception as e: if attempt == max_retries - 1: print(f"Failed to initialize model after {max_retries} attempts: {str(e)}") raise print(f"Attempt {attempt + 1} failed, retrying...") torch.cuda.empty_cache() gc.collect() try: model = initialize_model() except Exception as e: print(f"Model initialization failed: {str(e)}") raise # Create Gradio interface with error handling demo = gr.Interface( fn=model.try_on, inputs=[ gr.Image(label="📸 Upload Your Photo", type="numpy"), gr.Textbox( label="🎨 Describe New Clothing", placeholder="e.g., 'elegant black suit', 'red dress'", lines=2 ), gr.Slider( label="Style Strength", minimum=0.1, maximum=1.0, value=0.7, step=0.1 ) ], outputs=gr.Image(label="✨ Result", type="pil"), title="🌟 Rafayy's Virtual Try-On Studio 🌟", description="""
Upload a photo and describe the new clothing you want to try on!