File size: 7,323 Bytes
9c39d26
 
4fafe9d
 
9c39d26
 
 
 
4fafe9d
9c39d26
 
 
4fafe9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c39d26
4fafe9d
 
 
 
 
 
 
 
 
 
 
 
 
 
9c39d26
4fafe9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c39d26
4fafe9d
 
9c39d26
4fafe9d
 
 
 
 
 
 
9c39d26
4fafe9d
 
 
 
 
 
 
9c39d26
4fafe9d
 
 
 
 
 
9c39d26
4fafe9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c39d26
4fafe9d
 
9c39d26
4fafe9d
 
 
 
9c39d26
4fafe9d
 
 
 
9c39d26
4fafe9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c39d26
 
 
4fafe9d
 
 
9c39d26
4fafe9d
 
 
 
 
 
9c39d26
4fafe9d
9c39d26
 
 
 
 
 
4fafe9d
9c39d26
 
 
 
 
 
 
 
 
 
4fafe9d
 
9c39d26
4fafe9d
 
 
9c39d26
 
 
4fafe9d
 
9c39d26
4fafe9d
 
9c39d26
 
4fafe9d
9c39d26
4fafe9d
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import gradio as gr
import torch
import gc
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import cv2
import traceback

class RafayyVirtualTryOn:
    """Text-guided clothing replacement: segment the garments, then inpaint them.

    A SegFormer clothes-parsing model produces a garment mask for the input
    photo, and a Stable Diffusion inpainting pipeline repaints the masked
    region from a free-text clothing description.
    """

    # Segmenter class ids treated as "clothing". Per the
    # mattmdjaga/segformer_b2_clothes label map: 4 = Upper-clothes, 5 = Skirt,
    # 6 = Pants, 7 = Dress. NOTE(review): re-check against
    # self.segmenter.config.id2label if the checkpoint is changed.
    CLOTHING_LABELS = (4, 5, 6, 7)

    # Longest side (pixels) kept by preprocess_image; larger photos are downscaled.
    MAX_SIZE = 768

    # Stable Diffusion requires height/width divisible by 8 (VAE downsampling factor).
    SIZE_MULTIPLE = 8

    def __init__(self):
        """Load the inpainting pipeline and the clothes segmenter.

        On CUDA both models are loaded in float16 and moved to the GPU once.

        Raises:
            Exception: whatever ``from_pretrained`` raises (download / IO errors);
                the message is printed before re-raising.
        """
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32
        try:
            if use_cuda:
                # Release memory left over from earlier work in this process.
                torch.cuda.empty_cache()
                gc.collect()

            # NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo may no longer
            # be available; a mirror is "stable-diffusion-v1-5/stable-diffusion-v1-5".
            # Confirm the id resolves in the deployment environment.
            self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=dtype,
                safety_checker=None,  # Disable safety checker if causing issues
            )
            if use_cuda:
                self.inpaint_model.to("cuda")
                self.inpaint_model.enable_attention_slicing()  # Reduce memory usage

            try:
                self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes",
                    torch_dtype=dtype,
                )
                self.processor = SegformerImageProcessor.from_pretrained(
                    "mattmdjaga/segformer_b2_clothes"
                )
            except Exception as e:
                print(f"Segmentation model loading error: {str(e)}")
                raise

            if use_cuda:
                # Move once here rather than on every get_clothing_mask call.
                self.segmenter.to("cuda")

        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise

    def preprocess_image(self, image):
        """Return *image* as an RGB PIL image sized for the inpainting pipeline.

        Args:
            image: a PIL image or a numpy array (as produced by
                ``gr.Image(type="numpy")``).

        Returns:
            An RGB PIL image whose longest side is at most MAX_SIZE (aspect ratio
            kept) and whose width and height are both multiples of SIZE_MULTIPLE
            (rounded down, never below one multiple).

        Raises:
            gr.Error: if conversion or resizing fails.
        """
        try:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)

            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            width, height = image.size
            # Only ever shrink; small images keep their scale.
            scale = min(1.0, self.MAX_SIZE / max(width, height))
            step = self.SIZE_MULTIPLE
            new_size = tuple(
                max(step, int(dim * scale) // step * step) for dim in (width, height)
            )
            if new_size != image.size:
                image = image.resize(new_size, Image.LANCZOS)

            return image
        except Exception as e:
            raise gr.Error(f"Image preprocessing failed: {str(e)}")

    def get_clothing_mask(self, image):
        """Return a clothing mask aligned pixel-for-pixel with *image*.

        Args:
            image: a PIL image or numpy array.

        Returns:
            An "L"-mode PIL image the same size as *image*: 255 where a
            CLOTHING_LABELS class was predicted (slightly dilated and feathered),
            0 elsewhere.

        Raises:
            gr.Error: if segmentation fails.
        """
        try:
            # Convert to RGB if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            if image.mode != "RGB":
                image = image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")
            # The processor emits float32 pixels; match the model's device and dtype
            # (float16 on CUDA) or the first convolution fails on a dtype mismatch.
            pixel_values = inputs["pixel_values"].to(
                device=self.segmenter.device, dtype=self.segmenter.dtype
            )

            with torch.no_grad():  # inference only: don't build an autograd graph
                logits = self.segmenter(pixel_values=pixel_values).logits
                # SegFormer predicts at 1/4 of the processor's (resized) input
                # resolution; upsample to the original image size so the mask
                # lines up with the image handed to the inpainting pipeline.
                logits = torch.nn.functional.interpolate(
                    logits.float(),
                    size=(image.height, image.width),
                    mode="bilinear",
                    align_corners=False,
                )
            labels = logits.argmax(dim=1)[0].cpu().numpy()

            clothing_mask = np.where(
                np.isin(labels, self.CLOTHING_LABELS), 255, 0
            ).astype(np.uint8)

            # Grow the mask slightly and feather its edge so the new garment
            # blends into the surrounding pixels.
            kernel = np.ones((5, 5), np.uint8)
            clothing_mask = cv2.dilate(clothing_mask, kernel, iterations=1)
            clothing_mask = cv2.GaussianBlur(clothing_mask, (5, 5), 0)

            return Image.fromarray(clothing_mask)
        except Exception as e:
            raise gr.Error(f"Mask generation failed: {str(e)}")

    def try_on(self, image, prompt, style_strength=0.7, progress=gr.Progress()):
        """Replace the clothing in *image* with garments described by *prompt*.

        Args:
            image: uploaded photo (numpy array from the Gradio component, or PIL).
            prompt: free-text clothing description, e.g. "red dress".
            style_strength: multiplier applied to a base guidance scale of 7.5.
                NOTE: in diffusers, guidance_scale <= 1 (here style_strength
                below ~0.14) disables classifier-free guidance, so the negative
                prompt is then ignored.
            progress: Gradio progress tracker, injected by the UI.

        Returns:
            The inpainted PIL image, at the preprocessed input's size.

        Raises:
            gr.Error: for missing input, no detected clothing, out-of-memory, or
                any other failure (the latter is also logged with a traceback).
        """
        try:
            if image is None:
                raise gr.Error("Please upload an image first")
            if not prompt or not prompt.strip():
                raise gr.Error("Please provide a clothing description")

            progress(0.1, desc="Preprocessing image...")
            original_image = self.preprocess_image(image)

            progress(0.3, desc="Detecting clothing...")
            mask = self.get_clothing_mask(original_image)
            if mask.getbbox() is None:
                # An all-black mask would make inpainting a silent no-op.
                raise gr.Error(
                    "No clothing detected in the image. Please try another photo."
                )

            # Clear GPU memory
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                gc.collect()

            progress(0.5, desc="Preparing generation...")
            # Enhanced prompt engineering
            full_prompt = f"A person wearing {prompt}, professional photo, detailed, realistic, high quality"
            negative_prompt = "low quality, blurry, distorted, deformed, bad anatomy, unrealistic"

            progress(0.7, desc="Generating new clothing...")
            try:
                result = self.inpaint_model(
                    prompt=full_prompt,
                    negative_prompt=negative_prompt,
                    image=original_image,
                    mask_image=mask,
                    # Generate at the input's own size (preprocess_image guarantees
                    # multiples of 8) instead of the pipeline's default output size.
                    height=original_image.height,
                    width=original_image.width,
                    num_inference_steps=30,  # Reduced for stability
                    guidance_scale=7.5 * style_strength,
                ).images[0]
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                gc.collect()
                raise gr.Error("Out of memory. Please try with a smaller image.")

            progress(1.0, desc="Done!")
            return result

        except gr.Error:
            # Already a user-facing message: pass it through unchanged instead of
            # logging a traceback and re-wrapping it in another gr.Error.
            raise
        except Exception as e:
            error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
            print(error_msg)  # For logging
            raise gr.Error(str(e))

# Build the single shared try-on model at import time. Any failure is reported
# on stdout and then propagated, so the app never starts half-initialized.
try:
    model = RafayyVirtualTryOn()
except Exception as init_error:
    print(f"Model initialization failed: {init_error!s}")
    raise

# Gradio UI: photo + clothing description + style strength -> inpainted photo.
# NOTE(review): with cache_examples=True Gradio runs model.try_on on each example
# to cache its output, so example1.jpg / example2.jpg must exist next to this
# script or startup fails — confirm they are shipped with the app.
demo = gr.Interface(
    fn=model.try_on,
    inputs=[
        gr.Image(label="📸 Upload Your Photo", type="numpy"),
        gr.Textbox(
            label="🎨 Describe New Clothing",
            placeholder="e.g., 'elegant black suit', 'red dress'",
            lines=2
        ),
        gr.Slider(
            label="Style Strength",
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1
        )
    ],
    outputs=gr.Image(label="✨ Result", type="pil"),
    title="🌟 Rafayy's Virtual Try-On Studio 🌟",
    description="""
    <div style="text-align: center;">
        <h3>Transform Your Style with AI</h3>
        <p>Upload a photo and describe the new clothing you want to try on!</p>
    </div>
    """,
    examples=[
        ["example1.jpg", "black suit", 0.7],
        ["example2.jpg", "white dress", 0.7]
    ],
    allow_flagging="never",
    cache_examples=True
)

# Launch with error handling
if __name__ == "__main__":
    try:
        # demo.queue() replaces launch(enable_queue=True): that keyword was removed
        # in Gradio 4 (TypeError there), while queue() works across versions and
        # keeps queuing on, which gr.Progress updates depend on.
        demo.queue().launch(
            share=False,
            server_name="0.0.0.0",  # listen on all network interfaces
            server_port=7860,
            show_error=True,
        )
    except Exception as e:
        print(f"Launch failed: {str(e)}")
        raise