|
import torch |
|
from transformers import pipeline |
|
from transformers.utils import logging |
|
from PIL import Image |
|
import requests |
|
import numpy as np |
|
from cv2 import imencode |
|
from base64 import b64encode |
|
import time |
|
|
|
|
|
|
|
class EndpointHandler():
    """Inference-endpoint handler running SAM ("Segment Anything") mask
    generation on an image fetched from a URL.

    Constructed once per worker.  Each request is serviced by ``__call__``
    with a payload of the form ``{"inputs": <image URL>}`` and answered with
    base64-encoded PNG masks plus their confidence scores.
    """

    def __init__(self, path=""):
        # Prefer GPU when one is available; the pipeline accepts the
        # device as a plain string ("cuda" / "cpu").
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger = logging.get_logger("transformers")
        # logger.warn is deprecated; logger.warning is the supported API.
        self.logger.warning(f"Using device: {self.device}")
        self.generator = pipeline(
            "mask-generation", model="facebook/sam-vit-large", device=self.device
        )

    def __call__(self, data):
        """Run mask generation on the image URL in ``data["inputs"]``.

        Args:
            data: request payload; ``data["inputs"]`` is expected to be an
                image URL (falls back to ``data`` itself if the key is absent,
                mirroring the standard handler convention).

        Returns:
            ``{"data": [{"score": float, "mask": <base64 PNG>, "label": ""},
            ...], "time": <elapsed seconds>}``
        """
        start = time.time()
        inputs = data.pop("inputs", data)
        self.logger.warning(f"got request for {inputs}")

        # A timeout prevents a hung download from pinning the worker
        # forever, and raise_for_status surfaces HTTP errors instead of
        # handing an error page to PIL.
        response = requests.get(inputs, stream=True, timeout=30)
        response.raise_for_status()
        raw_image = Image.open(response.raw).convert("RGB")

        with torch.no_grad():
            outputs = self.generator(raw_image, points_per_batch=32)

        masks = outputs["masks"]
        # The image->array conversion is loop-invariant: only its shape is
        # needed, and it is the same for every mask — compute it once.
        # NOTE: keeping all 3 channels (H, W, 3) preserves the original
        # RGB-PNG output format; the 2-D boolean mask broadcasts across
        # the channel axis when used as an index.
        image_shape = np.array(raw_image).shape[:3]
        results = []  # renamed: the original shadowed the `data` parameter
        for index, mask in enumerate(masks):
            # Render the boolean mask as white-on-black.
            mask_image = np.zeros(image_shape, np.uint8)
            mask_image[mask == True] = 255
            ok, buffer = imencode('.png', mask_image)
            if not ok:
                # imencode signals failure via its first return value; the
                # original ignored it and would crash in b64encode.  Skip
                # the unencodable mask instead of returning junk.
                self.logger.warning("PNG encoding failed for mask %d", index)
                continue
            results.append({
                "score": outputs["scores"][index].item(),
                "mask": b64encode(buffer).decode("ascii"),
                "label": ""
            })

        end = time.time()
        return {"data": results, "time": end - start}