baconseason commited on
Commit
4a2ff96
·
1 Parent(s): a90ff45
Files changed (4) hide show
  1. .gitignore +1 -0
  2. handler.py +2 -2
  3. requirements.txt +3 -0
  4. server.py +53 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ sam_vit_l_0b3195.pth
handler.py CHANGED
@@ -23,9 +23,9 @@ class EndpointHandler():
23
  for index, mask in enumerate(masks):
24
  cv_image = np.array(raw_image)
25
  mask_image = np.zeros(cv_image.shape[:3], np.uint8)
26
- mask_image[mask == False] = 255
27
  retval, buffer = imencode('.png', mask_image)
28
- encoded_mask = b64encode(buffer)
29
  data.append({
30
  "score": outputs["scores"][index].item(),
31
  "mask": encoded_mask,
 
23
  for index, mask in enumerate(masks):
24
  cv_image = np.array(raw_image)
25
  mask_image = np.zeros(cv_image.shape[:3], np.uint8)
26
+ mask_image[mask == True] = 255
27
  retval, buffer = imencode('.png', mask_image)
28
+ encoded_mask = b64encode(buffer).decode("ascii")
29
  data.append({
30
  "score": outputs["scores"][index].item(),
31
  "mask": encoded_mask,
requirements.txt CHANGED
@@ -1,6 +1,9 @@
1
  torch
 
2
  transformers
3
  pillow
4
  numpy
5
  requests
6
  opencv-python
 
 
 
1
  torch
2
+ torchvision
3
  transformers
4
  pillow
5
  numpy
6
  requests
7
  opencv-python
8
+ segment_anything
9
+ flask
server.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from flask import Flask, request, jsonify
3
+ import numpy as np
4
+ from PIL import Image
5
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
6
+ from cv2 import imencode
7
+ from base64 import b64encode
8
+ import requests
9
+ import time
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ print(f"Using device: {device}")
13
+
14
+ sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
15
+ sam.to(device=device)
16
+ mask_generator = SamAutomaticMaskGenerator(sam)
17
+ print("Loaded model")
18
+
19
+ app = Flask(__name__)
20
+
21
+ @app.route('/', methods=['POST'])
22
+ def index():
23
+ app.logger.info('Got request !')
24
+ start = time.time()
25
+ input = request.json
26
+ url = input.get('url')
27
+ app.logger.info('Got request for url %s', url)
28
+
29
+ image = np.array(Image.open(requests.get(url, stream=True).raw).convert("RGB"))
30
+ masks = mask_generator.generate(image)
31
+
32
+ data = []
33
+ for mask in masks:
34
+ mask_image = np.zeros(image.shape[:3], np.uint8)
35
+ mask_image[mask["segmentation"] == True] = 255
36
+ retval, buffer = imencode('.png', mask_image)
37
+ encoded_mask = b64encode(buffer).decode("ascii")
38
+ data.append({
39
+ "label": "",
40
+ "mask": encoded_mask,
41
+ "score": mask["predicted_iou"]
42
+ })
43
+
44
+ end = time.time()
45
+
46
+ return jsonify({ "data": data, "time": end - start })
47
+
48
+ @app.route('/health', methods=['GET'])
49
+ def health():
50
+ return jsonify({ "success": True })
51
+
52
+ if __name__ == '__main__':
53
+ app.run(host='0.0.0.0', port=8000)