noxo8888's picture
Create app.py
211a063 verified
raw
history blame
4.09 kB
import gradio as gr
import torch
import gc
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import cv2
import traceback
import os
# Process-wide environment tweaks applied when this module is imported:
# disable Hugging Face tokenizers parallelism and point the torch and
# Hugging Face home directories under /tmp.
# NOTE(review): diffusers/transformers are imported above this point, and
# huggingface_hub may read HF_HOME only at its own import time — confirm
# whether this assignment still affects its default cache. The model loaders
# in this file pass cache_dir='/tmp/models' explicitly in any case.
_RUNTIME_ENV = {
    'TOKENIZERS_PARALLELISM': 'false',
    'TORCH_HOME': '/tmp/torch',
    'HF_HOME': '/tmp/huggingface',
}
os.environ.update(_RUNTIME_ENV)
class RafayyVirtualTryOn:
    """Text-guided virtual try-on: segment the clothing, then inpaint it from a prompt.

    Two Hugging Face models are loaded at construction time:

    * ``runwayml/stable-diffusion-inpainting`` repaints the masked region
      according to the user's text description.
    * ``mattmdjaga/segformer_b2_clothes`` produces the per-pixel labels used
      to build the inpainting mask.

    Both are loaded in float16 when CUDA is available and in float32 otherwise.
    """

    # Segformer class ids treated as "clothing" when building the mask.
    # NOTE(review): ids assumed from the mattmdjaga/segformer_b2_clothes label
    # map (4=Upper-clothes, 5=Skirt, 6=Pants, 7=Dress) — confirm against
    # ``self.segmenter.config.id2label`` if the checkpoint changes.
    CLOTHING_LABELS = (4, 5, 6, 7)
    # Working resolution for the inpainting pipeline (SD 1.x native size);
    # the result is resized back to the input photo's size.
    INPAINT_SIZE = (512, 512)
    # Side length (pixels) of the square kernel used to grow the clothing mask
    # so the border between old and new clothing is repainted too.
    MASK_DILATION = 15

    def __init__(self):
        """Load the inpainting pipeline and the clothing segmenter.

        Raises:
            Exception: whatever the underlying ``from_pretrained`` / ``to``
                calls raise. The message is printed once, then the exception
                is re-raised so the caller's retry logic can act on it.
        """
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        # Half precision roughly halves GPU memory; CPU inference uses float32.
        self.dtype = torch.float16 if use_cuda else torch.float32

        # Release memory that a previous (failed) attempt may have left behind.
        if use_cuda:
            torch.cuda.empty_cache()
        gc.collect()

        try:
            self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                torch_dtype=self.dtype,
                safety_checker=None,
                cache_dir='/tmp/models',
            )
            self.inpaint_model.to(self.device)
            if use_cuda:
                self.inpaint_model.enable_attention_slicing()
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise

        try:
            self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
                "mattmdjaga/segformer_b2_clothes",
                torch_dtype=self.dtype,
                cache_dir='/tmp/models',
            )
            # The weights must live on the same device as the inputs; they were
            # previously left on the CPU even when loaded as float16 for CUDA.
            self.segmenter.to(self.device)
            self.segmenter.eval()
            self.processor = SegformerImageProcessor.from_pretrained(
                "mattmdjaga/segformer_b2_clothes",
                cache_dir='/tmp/models',
            )
        except Exception as e:
            print(f"Segmentation model loading error: {str(e)}")
            raise

    @classmethod
    def _clothing_mask(cls, label_map):
        """Return a uint8 mask (255 = clothing to replace, 0 = keep).

        Args:
            label_map: integer array of per-pixel Segformer class ids.
        """
        return np.where(np.isin(label_map, cls.CLOTHING_LABELS), 255, 0).astype(np.uint8)

    def _segment_clothing(self, photo):
        """Run the segmenter on an RGB PIL ``photo``; return its full-size clothing mask."""
        inputs = self.processor(images=photo, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device, dtype=self.dtype)
        with torch.no_grad():
            logits = self.segmenter(pixel_values=pixel_values).logits
        # Segformer logits come out at a reduced resolution: upsample to the
        # photo size (PIL size is (W, H); torch wants (H, W)) before argmax.
        upsampled = torch.nn.functional.interpolate(
            logits.float(), size=photo.size[::-1], mode="bilinear", align_corners=False
        )
        label_map = upsampled.argmax(dim=1)[0].cpu().numpy()
        return self._clothing_mask(label_map)

    def try_on(self, image, prompt, strength=0.7):
        """Repaint the clothing in ``image`` to match ``prompt``.

        Args:
            image: the uploaded photo as an array (the Gradio input is
                ``gr.Image(type="numpy")``).
            prompt: free-text description of the new clothing.
            strength: inpainting strength forwarded to the diffusion
                pipeline (the UI slider ranges from 0.1 to 1.0).

        Returns:
            A ``PIL.Image`` with the same size as the input photo.

        Raises:
            gr.Error: on missing input, when no clothing is detected, or when
                generation fails. Gradio displays the message to the user.
        """
        if image is None:
            raise gr.Error("Please upload a photo first.")
        if not prompt or not prompt.strip():
            raise gr.Error("Please describe the clothing you want to try on.")

        photo = Image.fromarray(np.asarray(image, dtype=np.uint8)).convert("RGB")

        try:
            mask = self._segment_clothing(photo)
        except Exception as e:
            traceback.print_exc()
            raise gr.Error(f"Segmentation failed: {e}") from e
        if not mask.any():
            raise gr.Error("No clothing detected in the photo; try a clearer, full-body shot.")
        kernel = np.ones((self.MASK_DILATION, self.MASK_DILATION), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=1)

        try:
            # The pipeline runs at a fixed working size (aspect ratio is
            # temporarily distorted); the output is resized back below.
            result = self.inpaint_model(
                prompt=prompt.strip(),
                image=photo.resize(self.INPAINT_SIZE),
                mask_image=Image.fromarray(mask).resize(self.INPAINT_SIZE),
                strength=float(strength),
            ).images[0]
        except Exception as e:
            traceback.print_exc()
            raise gr.Error(f"Generation failed: {e}") from e
        return result.resize(photo.size)
# Initialize model with error handling and retry mechanism
def initialize_model(max_retries=3, retry_delay=0.0):
    """Construct ``RafayyVirtualTryOn``, retrying on failure.

    Args:
        max_retries: total number of construction attempts; must be >= 1.
        retry_delay: seconds to wait between attempts. The default of 0
            retries immediately; a short pause can help with flaky model
            downloads.

    Returns:
        The initialized ``RafayyVirtualTryOn`` instance.

    Raises:
        ValueError: if ``max_retries`` is less than 1. Without this check the
            loop would never run and the function would return ``None``.
        Exception: the error from the final attempt, re-raised unchanged.
    """
    import time

    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    for attempt in range(1, max_retries + 1):
        try:
            return RafayyVirtualTryOn()
        except Exception as e:
            if attempt == max_retries:
                print(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
                raise
            # Log the actual cause; it used to be discarded for non-final attempts.
            print(f"Attempt {attempt} failed ({e}), retrying...")
            # Release whatever the failed attempt allocated before trying again.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            if retry_delay > 0:
                time.sleep(retry_delay)
# Build the shared try-on model once, at import time. The Gradio interface
# binds to its ``try_on`` method, so any failure is reported and re-raised
# instead of letting the app continue without a model.
try:
    model = initialize_model()
except Exception as init_error:
    print(f"Model initialization failed: {str(init_error)}")
    raise
# Create the Gradio interface around the shared model's ``try_on`` method.
#
# Example rows offered under the form. Only rows whose image file exists are
# passed to Gradio. Presumably a missing file would make Gradio fail when it
# loads or caches the examples; with cache_examples=True it pre-computes
# outputs by calling the model on each example.
_EXAMPLES = [
    ["example1.jpg", "black suit", 0.7],
    ["example2.jpg", "white dress", 0.7],
]
_available_examples = [row for row in _EXAMPLES if os.path.isfile(row[0])]

demo = gr.Interface(
    fn=model.try_on,
    inputs=[
        # Label strings restored from UTF-8 emoji that had been mis-decoded.
        gr.Image(label="📸 Upload Your Photo", type="numpy"),
        gr.Textbox(
            label="🎨 Describe New Clothing",
            placeholder="e.g., 'elegant black suit', 'red dress'",
            lines=2,
        ),
        gr.Slider(
            label="Style Strength",
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
        ),
    ],
    outputs=gr.Image(label="✨ Result", type="pil"),
    title="🌟 Rafayy's Virtual Try-On Studio 🌟",
    description="""
<div style="text-align: center;">
<h3>Transform Your Style with AI</h3>
<p>Upload a photo and describe the new clothing you want to try on!</p>
</div>
""",
    examples=_available_examples or None,
    allow_flagging="never",
    # Caching only makes sense when there is at least one example to cache.
    cache_examples=bool(_available_examples),
)
# Launch the web server when run as a script.
# Request queueing is enabled with ``demo.queue()``. ``launch()`` does not
# accept ``enable_queue`` (removed in Gradio 4) or ``cache_examples`` (an
# ``Interface`` option), and passing either raises a TypeError.
if __name__ == "__main__":
    try:
        demo.queue().launch(
            share=False,
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
            max_threads=4,
        )
    except Exception as e:
        print(f"Launch failed: {str(e)}")
        raise