noxo8888 commited on
Commit
4fafe9d
·
verified ·
1 Parent(s): 9c39d26

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -160
app.py CHANGED
@@ -1,146 +1,165 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
 
4
  from PIL import Image
5
  import numpy as np
6
  from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
7
  import cv2
 
8
 
9
  class RafayyVirtualTryOn:
10
  def __init__(self):
11
- # Initialize SDXL for better quality
12
- self.inpaint_model = StableDiffusionXLInpaintPipeline.from_pretrained(
13
- "stabilityai/stable-diffusion-xl-base-1.0",
14
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
15
- variant="fp16"
16
- )
17
- if torch.cuda.is_available():
18
- self.inpaint_model.to("cuda")
 
 
 
 
 
 
 
 
19
 
20
- # Initialize enhanced segmentation model
21
- self.segmenter = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
22
- self.processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
23
-
24
- def enhance_mask(self, mask):
25
- """Enhance the clothing mask for better results"""
26
- kernel = np.ones((5,5), np.uint8)
27
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
28
- mask = cv2.GaussianBlur(mask, (5,5), 0)
29
- return mask
 
 
 
 
30
 
31
- def get_clothing_mask(self, image):
32
- """Extract and enhance clothing mask"""
33
- inputs = self.processor(images=image, return_tensors="pt")
34
- outputs = self.segmenter(**inputs)
35
- logits = outputs.logits.squeeze()
36
-
37
- clothing_mask = (logits.argmax(0) == 5).float().numpy()
38
- clothing_mask = (clothing_mask * 255).astype(np.uint8)
39
- clothing_mask = self.enhance_mask(clothing_mask)
40
-
41
- return Image.fromarray(clothing_mask)
 
 
 
 
 
 
 
 
 
42
 
43
- def try_on(self, original_image, prompt, style_strength=0.7, progress=gr.Progress()):
44
- """Enhanced virtual try-on with style control"""
45
  try:
46
- progress(0, desc="Initializing...")
 
 
 
 
 
 
47
 
48
- # Image preprocessing
49
- original_image = Image.fromarray(original_image)
50
- if original_image.mode != "RGB":
51
- original_image = original_image.convert("RGB")
 
 
 
52
 
53
- progress(0.2, desc="Analyzing clothing...")
54
- mask = self.get_clothing_mask(original_image)
 
 
 
 
55
 
56
- # Enhanced prompt engineering
57
- style_prompts = {
58
- "quality": "ultra detailed, 8k uhd, high quality, professional photo",
59
- "lighting": "perfect lighting, studio lighting, professional photography",
60
- "realism": "hyperrealistic, photorealistic, highly detailed"
61
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- full_prompt = f"""
64
- A person wearing {prompt}, {style_prompts['quality']},
65
- {style_prompts['lighting']}, {style_prompts['realism']}
66
- """
67
 
68
- negative_prompt = """
69
- low quality, blurry, distorted, deformed, unrealistic,
70
- bad proportions, bad lighting, oversaturated, undersaturated
71
- """
72
 
73
- progress(0.4, desc="Generating new clothing...")
74
- result = self.inpaint_model(
75
- prompt=full_prompt,
76
- negative_prompt=negative_prompt,
77
- image=original_image,
78
- mask_image=mask,
79
- num_inference_steps=50,
80
- guidance_scale=7.5 * style_strength
81
- ).images[0]
82
 
83
- progress(1.0, desc="Finalizing...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return result
85
 
86
  except Exception as e:
87
- raise gr.Error(f"Processing Error: {str(e)}")
88
-
89
- # Initialize the model
90
- model = RafayyVirtualTryOn()
91
 
92
- # Custom CSS with professional styling
93
- custom_css = """
94
- .gradio-container {
95
- font-family: 'Poppins', sans-serif;
96
- max-width: 1200px !important;
97
- margin: auto !important;
98
- }
99
- #component-0 {
100
- max-width: 100% !important;
101
- margin-bottom: 20px !important;
102
- }
103
- .gr-button {
104
- background: linear-gradient(90deg, #2193b0, #6dd5ed) !important;
105
- border: none !important;
106
- color: white !important;
107
- font-weight: 600 !important;
108
- }
109
- .gr-button:hover {
110
- background: linear-gradient(90deg, #6dd5ed, #2193b0) !important;
111
- transform: translateY(-2px);
112
- box-shadow: 0 5px 15px rgba(33, 147, 176, 0.3) !important;
113
- transition: all 0.3s ease;
114
- }
115
- .gr-input {
116
- border: 2px solid #e0e0e0 !important;
117
- border-radius: 8px !important;
118
- padding: 12px !important;
119
- }
120
- .gr-input:focus {
121
- border-color: #2193b0 !important;
122
- box-shadow: 0 0 0 2px rgba(33, 147, 176, 0.2) !important;
123
- }
124
- .gr-panel {
125
- border-radius: 12px !important;
126
- box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1) !important;
127
- }
128
- .footer {
129
- background: linear-gradient(to right, #f8f9fa, #e9ecef);
130
- padding: 20px;
131
- border-radius: 10px;
132
- margin-top: 30px;
133
- }
134
- """
135
 
136
- # Create enhanced Gradio interface
137
  demo = gr.Interface(
138
  fn=model.try_on,
139
  inputs=[
140
  gr.Image(label="📸 Upload Your Photo", type="numpy"),
141
  gr.Textbox(
142
  label="🎨 Describe New Clothing",
143
- placeholder="e.g., 'elegant black suit with silk lapels', 'designer red dress with gold accents'",
144
  lines=2
145
  ),
146
  gr.Slider(
@@ -151,59 +170,32 @@ demo = gr.Interface(
151
  step=0.1
152
  )
153
  ],
154
- outputs=gr.Image(label="✨ Your New Look", type="pil"),
155
- title="🌟 Rafayy's Professional Virtual Try-On Studio 🌟",
156
  description="""
157
- <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
158
- <h3 style="color: #2193b0; font-size: 24px; margin-bottom: 10px;">Transform Your Style with AI</h3>
159
- <p style="color: #666; font-size: 16px;">Experience the future of fashion with our advanced virtual try-on technology.</p>
160
- <div style="margin-top: 20px; padding: 15px; background: rgba(33, 147, 176, 0.1); border-radius: 10px;">
161
- <p style="font-weight: bold; color: #2193b0;">Premium Features:</p>
162
- <p>✓ High-Resolution Output</p>
163
- <p>✓ Advanced Clothing Detection</p>
164
- <p>✓ Professional Style Enhancement</p>
165
- </div>
166
  </div>
167
  """,
168
  examples=[
169
- ["example1.jpg", "luxury black suit with silk details", 0.8],
170
- ["example2.jpg", "designer white dress with lace accents", 0.7],
171
- ["example3.jpg", "professional navy blazer with gold buttons", 0.9],
172
- ["example4.jpg", "casual denim jacket with vintage wash", 0.6]
173
  ],
174
- css=custom_css
 
175
  )
176
 
177
- # Enhanced footer with professional information
178
- demo.footer = """
179
- <div class="footer">
180
- <div style="text-align: center;">
181
- <h4 style="color: #2193b0; margin-bottom: 15px;">🎯 Professional Tips for Best Results</h4>
182
- <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 20px;">
183
- <div>
184
- <p style="font-weight: bold;">📸 Photo Guidelines</p>
185
- <p>• Use well-lit, front-facing photos</p>
186
- <p>• Ensure clear visibility of clothing</p>
187
- <p>• Avoid complex backgrounds</p>
188
- </div>
189
- <div>
190
- <p style="font-weight: bold;">✍️ Description Tips</p>
191
- <p>• Be specific about style details</p>
192
- <p>• Include color and material</p>
193
- <p>• Mention design elements</p>
194
- </div>
195
- <div>
196
- <p style="font-weight: bold;">⚙️ Settings</p>
197
- <p>• Adjust style strength as needed</p>
198
- <p>• Higher values for bold changes</p>
199
- <p>• Lower values for subtle effects</p>
200
- </div>
201
- </div>
202
- <p style="margin-top: 20px; color: #666;">© 2024 Rafayy AI Studio | Professional Virtual Try-On Service</p>
203
- </div>
204
- </div>
205
- """
206
-
207
- # Launch the app
208
  if __name__ == "__main__":
209
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ import gc
4
+ from diffusers import StableDiffusionInpaintPipeline
5
  from PIL import Image
6
  import numpy as np
7
  from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
8
  import cv2
9
+ import traceback
10
 
11
  class RafayyVirtualTryOn:
12
  def __init__(self):
13
+ try:
14
+ # Clear CUDA cache
15
+ if torch.cuda.is_available():
16
+ torch.cuda.empty_cache()
17
+ gc.collect()
18
+
19
+ # Use smaller model for stability
20
+ self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
21
+ "runwayml/stable-diffusion-v1-5",
22
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
+ safety_checker=None # Disable safety checker if causing issues
24
+ )
25
+
26
+ if torch.cuda.is_available():
27
+ self.inpaint_model.to("cuda")
28
+ self.inpaint_model.enable_attention_slicing() # Reduce memory usage
29
 
30
+ # Initialize segmentation with error handling
31
+ try:
32
+ self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
33
+ "mattmdjaga/segformer_b2_clothes",
34
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
35
+ )
36
+ self.processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
37
+ except Exception as e:
38
+ print(f"Segmentation model loading error: {str(e)}")
39
+ raise
40
+
41
+ except Exception as e:
42
+ print(f"Initialization error: {str(e)}")
43
+ raise
44
 
45
+ def preprocess_image(self, image):
46
+ """Safely preprocess input image"""
47
+ try:
48
+ if isinstance(image, np.ndarray):
49
+ image = Image.fromarray(image)
50
+
51
+ # Ensure image is RGB
52
+ if image.mode != "RGB":
53
+ image = image.convert("RGB")
54
+
55
+ # Resize if too large
56
+ max_size = 768
57
+ if max(image.size) > max_size:
58
+ ratio = max_size / max(image.size)
59
+ new_size = tuple(int(dim * ratio) for dim in image.size)
60
+ image = image.resize(new_size, Image.LANCZOS)
61
+
62
+ return image
63
+ except Exception as e:
64
+ raise gr.Error(f"Image preprocessing failed: {str(e)}")
65
 
66
+ def get_clothing_mask(self, image):
67
+ """Safely extract clothing mask"""
68
  try:
69
+ # Convert to RGB if needed
70
+ if isinstance(image, np.ndarray):
71
+ image = Image.fromarray(image)
72
+ if image.mode != "RGB":
73
+ image = image.convert("RGB")
74
+
75
+ inputs = self.processor(images=image, return_tensors="pt")
76
 
77
+ # Move to GPU if available
78
+ if torch.cuda.is_available():
79
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
80
+ self.segmenter = self.segmenter.to("cuda")
81
+
82
+ outputs = self.segmenter(**inputs)
83
+ logits = outputs.logits.squeeze()
84
 
85
+ # Move back to CPU for numpy operations
86
+ if torch.cuda.is_available():
87
+ logits = logits.cpu()
88
+
89
+ clothing_mask = (logits.argmax(0) == 5).float().numpy()
90
+ clothing_mask = (clothing_mask * 255).astype(np.uint8)
91
 
92
+ # Enhance mask
93
+ kernel = np.ones((5,5), np.uint8)
94
+ clothing_mask = cv2.dilate(clothing_mask, kernel, iterations=1)
95
+ clothing_mask = cv2.GaussianBlur(clothing_mask, (5,5), 0)
96
+
97
+ return Image.fromarray(clothing_mask)
98
+ except Exception as e:
99
+ raise gr.Error(f"Mask generation failed: {str(e)}")
100
+
101
+ def try_on(self, image, prompt, style_strength=0.7, progress=gr.Progress()):
102
+ """Main try-on function with comprehensive error handling"""
103
+ try:
104
+ if image is None:
105
+ raise gr.Error("Please upload an image first")
106
+ if not prompt or prompt.strip() == "":
107
+ raise gr.Error("Please provide a clothing description")
108
+
109
+ progress(0.1, desc="Preprocessing image...")
110
+ original_image = self.preprocess_image(image)
111
 
112
+ progress(0.3, desc="Detecting clothing...")
113
+ mask = self.get_clothing_mask(original_image)
 
 
114
 
115
+ # Clear GPU memory
116
+ if torch.cuda.is_available():
117
+ torch.cuda.empty_cache()
118
+ gc.collect()
119
 
120
+ progress(0.5, desc="Preparing generation...")
121
+ # Enhanced prompt engineering
122
+ full_prompt = f"A person wearing {prompt}, professional photo, detailed, realistic, high quality"
123
+ negative_prompt = "low quality, blurry, distorted, deformed, bad anatomy, unrealistic"
 
 
 
 
 
124
 
125
+ progress(0.7, desc="Generating new clothing...")
126
+ try:
127
+ result = self.inpaint_model(
128
+ prompt=full_prompt,
129
+ negative_prompt=negative_prompt,
130
+ image=original_image,
131
+ mask_image=mask,
132
+ num_inference_steps=30, # Reduced for stability
133
+ guidance_scale=7.5 * style_strength
134
+ ).images[0]
135
+ except torch.cuda.OutOfMemoryError:
136
+ torch.cuda.empty_cache()
137
+ gc.collect()
138
+ raise gr.Error("Out of memory. Please try with a smaller image.")
139
+
140
+ progress(1.0, desc="Done!")
141
  return result
142
 
143
  except Exception as e:
144
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
145
+ print(error_msg) # For logging
146
+ raise gr.Error(str(e))
 
147
 
148
+ # Initialize model with error handling
149
+ try:
150
+ model = RafayyVirtualTryOn()
151
+ except Exception as e:
152
+ print(f"Model initialization failed: {str(e)}")
153
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ # Create Gradio interface with error handling
156
  demo = gr.Interface(
157
  fn=model.try_on,
158
  inputs=[
159
  gr.Image(label="📸 Upload Your Photo", type="numpy"),
160
  gr.Textbox(
161
  label="🎨 Describe New Clothing",
162
+ placeholder="e.g., 'elegant black suit', 'red dress'",
163
  lines=2
164
  ),
165
  gr.Slider(
 
170
  step=0.1
171
  )
172
  ],
173
+ outputs=gr.Image(label="✨ Result", type="pil"),
174
+ title="🌟 Rafayy's Virtual Try-On Studio 🌟",
175
  description="""
176
+ <div style="text-align: center;">
177
+ <h3>Transform Your Style with AI</h3>
178
+ <p>Upload a photo and describe the new clothing you want to try on!</p>
 
 
 
 
 
 
179
  </div>
180
  """,
181
  examples=[
182
+ ["example1.jpg", "black suit", 0.7],
183
+ ["example2.jpg", "white dress", 0.7]
 
 
184
  ],
185
+ allow_flagging="never",
186
+ cache_examples=True
187
  )
188
 
189
+ # Launch with error handling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  if __name__ == "__main__":
191
+ try:
192
+ demo.launch(
193
+ share=False,
194
+ server_name="0.0.0.0",
195
+ server_port=7860,
196
+ show_error=True,
197
+ enable_queue=True
198
+ )
199
+ except Exception as e:
200
+ print(f"Launch failed: {str(e)}")
201
+ raise