ktllc commited on
Commit
608f6fc
·
1 Parent(s): d666f15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -17,34 +17,33 @@ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model, preprocess = clip.load("ViT-B/32", device=device)
19
 
20
-
21
  def find_similarity(base64_image, text_input):
22
- # Decode the base64 image to bytes
23
- image_bytes = base64.b64decode(base64_image)
24
-
25
- # Convert the bytes to a PIL image
26
- image = Image.open(BytesIO(image_bytes))
27
-
28
- # Preprocess the image
29
- image = preprocess(image).unsqueeze(0).to(device)
30
-
31
- # Prepare input text
32
- text_tokens = clip.tokenize([text_input]).to(device)
33
 
34
- # Encode image and text features
 
35
 
 
 
36
 
37
- with torch.no_grad():
38
- image_features = model.encode_image(image)
39
- text_features = model.encode_text(text_tokens)
40
 
41
- # Normalize features and calculate similarity
42
- image_features /= image_features.norm(dim=-1, keepdim=True)
43
- text_features /= text_features.norm(dim=-1, keepdim=True)
44
- similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
45
 
46
- return similarity
 
 
 
47
 
 
 
 
48
 
49
  # Define a function for image segmentation
50
  def segment_image(input_image, text_input):
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model, preprocess = clip.load("ViT-B/32", device=device)
19
 
 
20
  def find_similarity(base64_image, text_input):
21
+ try:
22
+ # Decode the base64 image to bytes
23
+ image_bytes = base64.b64decode(base64_image)
 
 
 
 
 
 
 
 
24
 
25
+ # Convert the bytes to a PIL image
26
+ image = Image.open(BytesIO(image_bytes))
27
 
28
+ # Preprocess the image
29
+ image = preprocess(image).unsqueeze(0).to(device)
30
 
31
+ # Prepare input text
32
+ text_tokens = clip.tokenize([text_input]).to(device)
 
33
 
34
+ # Encode image and text features
35
+ with torch.no_grad():
36
+ image_features = model.encode_image(image)
37
+ text_features = model.encode_text(text_tokens)
38
 
39
+ # Normalize features and calculate similarity
40
+ image_features /= image_features.norm(dim=-1, keepdim=True)
41
+ text_features /= text_features.norm(dim=-1, keepdim=True)
42
+ similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
43
 
44
+ return similarity
45
+ except Exception as e:
46
+ return str(e)
47
 
48
  # Define a function for image segmentation
49
  def segment_image(input_image, text_input):