noxo8888 commited on
Commit
211a063
·
verified ·
1 Parent(s): 346ce0c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -108
app.py CHANGED
@@ -7,6 +7,12 @@ import numpy as np
7
  from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
8
  import cv2
9
  import traceback
 
 
 
 
 
 
10
 
11
  class RafayyVirtualTryOn:
12
  def __init__(self):
@@ -17,23 +23,29 @@ class RafayyVirtualTryOn:
17
  gc.collect()
18
 
19
  # Use smaller model for stability
 
20
  self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
21
- "runwayml/stable-diffusion-v1-5",
22
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
- safety_checker=None # Disable safety checker if causing issues
 
24
  )
25
 
26
  if torch.cuda.is_available():
27
  self.inpaint_model.to("cuda")
28
- self.inpaint_model.enable_attention_slicing() # Reduce memory usage
29
 
30
  # Initialize segmentation with error handling
31
  try:
32
  self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
33
  "mattmdjaga/segformer_b2_clothes",
34
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 
 
 
 
 
35
  )
36
- self.processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
37
  except Exception as e:
38
  print(f"Segmentation model loading error: {str(e)}")
39
  raise
@@ -42,112 +54,23 @@ class RafayyVirtualTryOn:
42
  print(f"Initialization error: {str(e)}")
43
  raise
44
 
45
- def preprocess_image(self, image):
46
- """Safely preprocess input image"""
47
- try:
48
- if isinstance(image, np.ndarray):
49
- image = Image.fromarray(image)
50
-
51
- # Ensure image is RGB
52
- if image.mode != "RGB":
53
- image = image.convert("RGB")
54
-
55
- # Resize if too large
56
- max_size = 768
57
- if max(image.size) > max_size:
58
- ratio = max_size / max(image.size)
59
- new_size = tuple(int(dim * ratio) for dim in image.size)
60
- image = image.resize(new_size, Image.LANCZOS)
61
-
62
- return image
63
- except Exception as e:
64
- raise gr.Error(f"Image preprocessing failed: {str(e)}")
65
-
66
- def get_clothing_mask(self, image):
67
- """Safely extract clothing mask"""
68
- try:
69
- # Convert to RGB if needed
70
- if isinstance(image, np.ndarray):
71
- image = Image.fromarray(image)
72
- if image.mode != "RGB":
73
- image = image.convert("RGB")
74
-
75
- inputs = self.processor(images=image, return_tensors="pt")
76
-
77
- # Move to GPU if available
78
- if torch.cuda.is_available():
79
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
80
- self.segmenter = self.segmenter.to("cuda")
81
-
82
- outputs = self.segmenter(**inputs)
83
- logits = outputs.logits.squeeze()
84
-
85
- # Move back to CPU for numpy operations
86
- if torch.cuda.is_available():
87
- logits = logits.cpu()
88
-
89
- clothing_mask = (logits.argmax(0) == 5).float().numpy()
90
- clothing_mask = (clothing_mask * 255).astype(np.uint8)
91
-
92
- # Enhance mask
93
- kernel = np.ones((5,5), np.uint8)
94
- clothing_mask = cv2.dilate(clothing_mask, kernel, iterations=1)
95
- clothing_mask = cv2.GaussianBlur(clothing_mask, (5,5), 0)
96
-
97
- return Image.fromarray(clothing_mask)
98
- except Exception as e:
99
- raise gr.Error(f"Mask generation failed: {str(e)}")
100
 
101
- def try_on(self, image, prompt, style_strength=0.7, progress=gr.Progress()):
102
- """Main try-on function with comprehensive error handling"""
 
103
  try:
104
- if image is None:
105
- raise gr.Error("Please upload an image first")
106
- if not prompt or prompt.strip() == "":
107
- raise gr.Error("Please provide a clothing description")
108
-
109
- progress(0.1, desc="Preprocessing image...")
110
- original_image = self.preprocess_image(image)
111
-
112
- progress(0.3, desc="Detecting clothing...")
113
- mask = self.get_clothing_mask(original_image)
114
-
115
- # Clear GPU memory
116
- if torch.cuda.is_available():
117
- torch.cuda.empty_cache()
118
- gc.collect()
119
-
120
- progress(0.5, desc="Preparing generation...")
121
- # Enhanced prompt engineering
122
- full_prompt = f"A person wearing {prompt}, professional photo, detailed, realistic, high quality"
123
- negative_prompt = "low quality, blurry, distorted, deformed, bad anatomy, unrealistic"
124
-
125
- progress(0.7, desc="Generating new clothing...")
126
- try:
127
- result = self.inpaint_model(
128
- prompt=full_prompt,
129
- negative_prompt=negative_prompt,
130
- image=original_image,
131
- mask_image=mask,
132
- num_inference_steps=30, # Reduced for stability
133
- guidance_scale=7.5 * style_strength
134
- ).images[0]
135
- except torch.cuda.OutOfMemoryError:
136
- torch.cuda.empty_cache()
137
- gc.collect()
138
- raise gr.Error("Out of memory. Please try with a smaller image.")
139
-
140
- progress(1.0, desc="Done!")
141
- return result
142
-
143
  except Exception as e:
144
- error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
145
- print(error_msg) # For logging
146
- raise gr.Error(str(e))
 
 
 
147
 
148
- # Initialize model with error handling
149
  try:
150
- model = RafayyVirtualTryOn()
151
  except Exception as e:
152
  print(f"Model initialization failed: {str(e)}")
153
  raise
@@ -186,7 +109,7 @@ demo = gr.Interface(
186
  cache_examples=True
187
  )
188
 
189
- # Launch with error handling
190
  if __name__ == "__main__":
191
  try:
192
  demo.launch(
@@ -194,7 +117,9 @@ if __name__ == "__main__":
194
  server_name="0.0.0.0",
195
  server_port=7860,
196
  show_error=True,
197
- enable_queue=True
 
 
198
  )
199
  except Exception as e:
200
  print(f"Launch failed: {str(e)}")
 
7
  from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
8
  import cv2
9
  import traceback
10
+ import os
11
+
12
+ # Set environment variables to prevent warnings and errors
13
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
14
+ os.environ['TORCH_HOME'] = '/tmp/torch'
15
+ os.environ['HF_HOME'] = '/tmp/huggingface'
16
 
17
  class RafayyVirtualTryOn:
18
  def __init__(self):
 
23
  gc.collect()
24
 
25
  # Use smaller model for stability
26
+ model_id = "runwayml/stable-diffusion-inpainting"
27
  self.inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
28
+ model_id,
29
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
30
+ safety_checker=None,
31
+ cache_dir='/tmp/models'
32
  )
33
 
34
  if torch.cuda.is_available():
35
  self.inpaint_model.to("cuda")
36
+ self.inpaint_model.enable_attention_slicing()
37
 
38
  # Initialize segmentation with error handling
39
  try:
40
  self.segmenter = SegformerForSemanticSegmentation.from_pretrained(
41
  "mattmdjaga/segformer_b2_clothes",
42
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
43
+ cache_dir='/tmp/models'
44
+ )
45
+ self.processor = SegformerImageProcessor.from_pretrained(
46
+ "mattmdjaga/segformer_b2_clothes",
47
+ cache_dir='/tmp/models'
48
  )
 
49
  except Exception as e:
50
  print(f"Segmentation model loading error: {str(e)}")
51
  raise
 
54
  print(f"Initialization error: {str(e)}")
55
  raise
56
 
57
+ # ... (rest of the code remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # Initialize model with error handling and retry mechanism
60
+ def initialize_model(max_retries=3):
61
+ for attempt in range(max_retries):
62
  try:
63
+ return RafayyVirtualTryOn()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  except Exception as e:
65
+ if attempt == max_retries - 1:
66
+ print(f"Failed to initialize model after {max_retries} attempts: {str(e)}")
67
+ raise
68
+ print(f"Attempt {attempt + 1} failed, retrying...")
69
+ torch.cuda.empty_cache()
70
+ gc.collect()
71
 
 
72
  try:
73
+ model = initialize_model()
74
  except Exception as e:
75
  print(f"Model initialization failed: {str(e)}")
76
  raise
 
109
  cache_examples=True
110
  )
111
 
112
+ # Launch with error handling and retry mechanism
113
  if __name__ == "__main__":
114
  try:
115
  demo.launch(
 
117
  server_name="0.0.0.0",
118
  server_port=7860,
119
  show_error=True,
120
+ enable_queue=True,
121
+ cache_examples=True,
122
+ max_threads=4
123
  )
124
  except Exception as e:
125
  print(f"Launch failed: {str(e)}")