Spaces:
Sleeping
Sleeping
import gc
import os
import traceback

# Environment for caches and tokenizers. These assignments must run BEFORE the
# third-party imports below: huggingface_hub (imported by gradio, diffusers and
# transformers) resolves HF_HOME into its cache paths when it is first
# imported, so setting it afterwards has no effect.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['TORCH_HOME'] = '/tmp/torch'
os.environ['HF_HOME'] = '/tmp/huggingface'

import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from diffusers import StableDiffusionInpaintPipeline  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation  # noqa: E402
class RafayyVirtualTryOn:
    """Text-guided virtual try-on: clothes segmentation plus Stable Diffusion inpainting.

    Only the constructor is shown in this excerpt; the ``try_on`` method that the
    Gradio interface calls is defined elsewhere.

    Attributes:
        inpaint_model: ``StableDiffusionInpaintPipeline`` for
            ``runwayml/stable-diffusion-inpainting`` (fp16 on CUDA, fp32 on CPU).
        segmenter: ``SegformerForSemanticSegmentation`` for
            ``mattmdjaga/segformer_b2_clothes`` (fp32, left on the CPU).
        processor: ``SegformerImageProcessor`` paired with ``segmenter``.
    """

    def __init__(self):
        """Download (into ``/tmp/models``) and load both models.

        Raises:
            Exception: any loading error is printed and re-raised unchanged, so
                a caller such as a retry loop can decide whether to try again.
        """
        use_cuda = torch.cuda.is_available()
        try:
            # Collect unreachable objects first so empty_cache() can hand their
            # CUDA blocks back to the driver; emptying first frees less.
            gc.collect()
            if use_cuda:
                torch.cuda.empty_cache()

            self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                # Half precision only when the pipeline will run on the GPU.
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                safety_checker=None,
                cache_dir='/tmp/models',
            )
            if use_cuda:
                self.inpaint_model.to("cuda")
            # Compute attention in slices to lower peak memory at some speed cost.
            self.inpaint_model.enable_attention_slicing()

            try:
                # Loaded in fp32 on the CPU regardless of GPU availability: fp16
                # weights that stay on the CPU mismatch the float32 pixel_values
                # returned by SegformerImageProcessor. GPU hosts now get the same
                # segmenter as CPU-only hosts.
                # NOTE(review): if try_on moves processor outputs to CUDA, move
                # this model to CUDA as well.
                self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes",
                    torch_dtype=torch.float32,
                    cache_dir='/tmp/models',
                )
                self.processor = SegformerImageProcessor.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes",
                    cache_dir='/tmp/models',
                )
            except Exception as e:
                print(f"Segmentation model loading error: {str(e)}")
                raise
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise
# ... (remaining code, including the try_on method that the Gradio interface below calls, is omitted from this excerpt)
def initialize_model(max_retries=3, factory=None):
    """Construct the try-on model, retrying on failure.

    Args:
        max_retries: Total number of construction attempts; must be >= 1.
        factory: Zero-argument callable that builds the model. Defaults to
            ``RafayyVirtualTryOn``; injectable so the retry logic can be
            exercised without downloading models.

    Returns:
        Whatever the first successful ``factory()`` call returns.

    Raises:
        ValueError: if ``max_retries`` is less than 1 (no attempt would run).
        Exception: the final attempt's exception, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if factory is None:
        factory = RafayyVirtualTryOn

    for attempt in range(1, max_retries + 1):
        try:
            return factory()
        except Exception as e:
            if attempt == max_retries:
                print(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
                raise
            print(f"Attempt {attempt} failed ({e}), retrying...")
        # Cleanup must run after the except clause has exited: until then the
        # exception's traceback keeps the failed attempt's frames (and any
        # partially loaded models) alive, so collecting earlier frees nothing.
        gc.collect()
        # No-op when CUDA was never initialised.
        torch.cuda.empty_cache()
# Build the shared model once at import time. The interface below binds to it,
# so a failure here is reported and re-raised, aborting startup.
try:
    model = initialize_model()
except Exception as init_error:
    message = f"Model initialization failed: {str(init_error)}"
    print(message)
    raise
# Gradio UI: a photo, a clothing description and a strength slider go into
# model.try_on; its result is shown as a PIL image.
# NOTE(review): cache_examples=True makes Gradio run model.try_on on
# example1.jpg / example2.jpg when the app starts, so those files must ship
# next to the app — confirm they exist in the Space.
demo = gr.Interface(
    fn=model.try_on,
    inputs=[
        gr.Image(label="📸 Upload Your Photo", type="numpy"),
        gr.Textbox(
            label="🎨 Describe New Clothing",
            placeholder="e.g., 'elegant black suit', 'red dress'",
            lines=2,
        ),
        gr.Slider(
            label="Style Strength",
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
        ),
    ],
    outputs=gr.Image(label="✨ Result", type="pil"),
    title="👗 Rafayy's Virtual Try-On Studio 👗",
    description="""
<div style="text-align: center;">
<h3>Transform Your Style with AI</h3>
<p>Upload a photo and describe the new clothing you want to try on!</p>
</div>
""",
    examples=[
        ["example1.jpg", "black suit", 0.7],
        ["example2.jpg", "white dress", 0.7],
    ],
    allow_flagging="never",
    cache_examples=True,
)
# Script entry point: enable request queuing, then serve on all interfaces
# at port 7860. Launch errors are printed and re-raised.
if __name__ == "__main__":
    try:
        # queue() replaces launch(enable_queue=True), which Gradio 4 removed.
        # cache_examples is an Interface option (already set above), not a
        # launch() keyword; passing it to launch() raised TypeError.
        demo.queue().launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            max_threads=4,
        )
    except Exception as e:
        print(f"Launch failed: {str(e)}")
        raise