File size: 1,425 Bytes
24ec34e 2d7803c 24ec34e 42ff6ec 24ec34e 8f8fa7b 2d7803c 24ec34e cbd0404 24ec34e 2d7803c 8f8fa7b 24ec34e 125c44b 42ff6ec 125c44b 8f8fa7b 125c44b a90ff45 24ec34e 4a2ff96 24ec34e 4a2ff96 24ec34e 42ff6ec 24ec34e 42ff6ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
from transformers import pipeline
from transformers.utils import logging
from PIL import Image
import requests
import numpy as np
from cv2 import imencode
from base64 import b64encode
import time
# logging.set_verbosity_info()
class EndpointHandler():
    """Inference-endpoint handler for SAM (Segment Anything) mask generation.

    Loads the ``facebook/sam-vit-large`` mask-generation pipeline once at
    startup; each call downloads an image from a URL and returns every
    predicted mask as a base64-encoded PNG together with its score.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; the pipeline handles device placement.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger = logging.get_logger("transformers")
        # logger.warn() is deprecated -> warning(); lazy %-args skip string
        # formatting when the level is disabled.
        self.logger.warning("Using device: %s", self.device)
        self.generator = pipeline(
            "mask-generation", model="facebook/sam-vit-large", device=self.device
        )

    def __call__(self, data):
        """Run mask generation on the image referenced by the request.

        Parameters
        ----------
        data : dict
            Expected to carry an ``"inputs"`` key holding an image URL; if
            the key is absent, ``data`` itself is used as the URL
            (original fallback behavior, kept for compatibility).
            NOTE: ``pop`` mutates the caller's dict, as before.

        Returns
        -------
        dict
            ``{"data": [{"score", "mask", "label"}, ...], "time": seconds}``
            where ``"mask"`` is a base64-encoded PNG (same 3-channel layout
            as the original implementation).

        Raises
        ------
        requests.HTTPError
            If the image URL returns a non-2xx status.
        RuntimeError
            If OpenCV fails to PNG-encode a mask.
        """
        start = time.time()
        inputs = data.pop("inputs", data)
        self.logger.warning("got request for %s", inputs)
        # Bound the download so an unresponsive host cannot hang the
        # endpoint, and fail fast on HTTP errors instead of letting PIL
        # choke on an error page.
        response = requests.get(inputs, stream=True, timeout=30)
        response.raise_for_status()
        raw_image = Image.open(response.raw).convert("RGB")
        with torch.no_grad():
            outputs = self.generator(raw_image, points_per_batch=32)
        masks = outputs["masks"]
        scores = outputs["scores"]
        # Hoisted out of the loop: the conversion is loop-invariant. Keep
        # the full (H, W, 3) shape so the emitted PNGs keep the original
        # 3-channel layout.
        canvas_shape = np.array(raw_image).shape
        results = []  # renamed from `data` to stop shadowing the parameter
        for index, mask in enumerate(masks):
            mask_image = np.zeros(canvas_shape, np.uint8)
            # Boolean indexing; `mask == True` was a redundant comparison.
            # Broadcasts 255 across all channels where the mask is set.
            mask_image[mask] = 255
            ok, buffer = imencode('.png', mask_image)
            if not ok:
                # Previously the return flag was ignored, silently emitting
                # garbage on encoder failure.
                raise RuntimeError("cv2.imencode failed to encode mask as PNG")
            results.append({
                "score": scores[index].item(),
                "mask": b64encode(buffer).decode("ascii"),
                "label": ""
            })
        end = time.time()
        return { "data": results, "time": end - start }