File size: 1,425 Bytes
24ec34e
 
2d7803c
24ec34e
 
 
 
 
42ff6ec
24ec34e
8f8fa7b
2d7803c
24ec34e
cbd0404
24ec34e
2d7803c
8f8fa7b
24ec34e
 
125c44b
42ff6ec
125c44b
8f8fa7b
125c44b
a90ff45
 
 
 
24ec34e
 
 
 
 
4a2ff96
24ec34e
4a2ff96
24ec34e
 
 
 
 
42ff6ec
 
24ec34e
42ff6ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from transformers import pipeline
from transformers.utils import logging
from PIL import Image
import requests
import numpy as np
from cv2 import imencode
from base64 import b64encode
import time

# logging.set_verbosity_info()

class EndpointHandler():
  """Inference handler that runs SAM mask generation on an image fetched by URL.

  The handler is constructed once and then invoked per request via
  ``__call__``. Each request downloads an image, runs the
  ``facebook/sam-vit-large`` mask-generation pipeline on it and returns every
  mask as a base64-encoded PNG together with its score.
  """

  # Seconds to wait on the image server (connect and each read) before giving
  # up; without a timeout a stalled server would block the worker forever.
  REQUEST_TIMEOUT = 30

  def __init__(self, path=""):
    """Load the mask-generation pipeline on the GPU when one is available.

    Args:
      path: Accepted for interface compatibility; not used, the model is
        always loaded by its hub name.
    """
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.logger = logging.get_logger("transformers")
    # Logged at warning level so it is visible at the default verbosity.
    self.logger.warning("Using device: %s", self.device)
    self.generator = pipeline("mask-generation", model="facebook/sam-vit-large", device=self.device)

  def __call__(self, data):
    """Generate segmentation masks for the image referenced by the request.

    Args:
      data: Request payload. The image URL is read from (and removed from)
        its ``"inputs"`` key.

    Returns:
      A dict with ``"data"``, a list of ``{"score", "mask", "label"}``
      entries (``"mask"`` is a base64-encoded PNG, ``"label"`` is always
      empty), and ``"time"``, the elapsed wall time in seconds.

    Raises:
      ValueError: If the payload carries no non-empty string URL.
      requests.RequestException: If the download fails, times out or the
        server answers with an HTTP error status.
      RuntimeError: If a mask cannot be encoded as PNG.
    """
    # Monotonic clock: immune to system clock adjustments while timing.
    start = time.perf_counter()
    inputs = data.pop("inputs", data)
    if not isinstance(inputs, str) or not inputs:
      raise ValueError('request payload must provide an image URL string under "inputs"')
    self.logger.warning("got request for %s", inputs)

    # The context manager releases the connection even when decoding fails.
    with requests.get(inputs, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
      # Fail on 4xx/5xx instead of handing an error page to the image decoder.
      response.raise_for_status()
      # Let urllib3 undo gzip/deflate transfer encoding on the raw stream.
      response.raw.decode_content = True
      # convert() forces the pixel data to be read while the stream is open.
      raw_image = Image.open(response.raw).convert("RGB")

    with torch.no_grad():
      outputs = self.generator(raw_image, points_per_batch=32)

    # Every mask image has the same shape as the input image, so derive it
    # once instead of copying the image into a new array for each mask.
    mask_shape = np.asarray(raw_image).shape[:3]

    results = []
    for mask, score in zip(outputs["masks"], outputs["scores"]):
      mask_image = np.zeros(mask_shape, np.uint8)
      # Paint the selected pixels white on all channels.
      # NOTE(review): assumes each mask is a 2-D boolean array matching the
      # image height and width - confirm against the pipeline output.
      mask_image[np.asarray(mask, dtype=bool)] = 255
      encoded_ok, buffer = imencode('.png', mask_image)
      if not encoded_ok:
        raise RuntimeError("failed to encode mask as PNG")
      results.append({
          "score": score.item(),
          "mask": b64encode(buffer).decode("ascii"),
          "label": ""
      })

    elapsed = time.perf_counter() - start

    return { "data": results, "time": elapsed }