# sam-hf-endpoint / handler.py
# Source: commit 8f8fa7b ("logging"), author: baconseason.
import torch
from transformers import pipeline
from transformers.utils import logging
from PIL import Image
import requests
import numpy as np
from cv2 import imencode
from base64 import b64encode
import time
# logging.set_verbosity_info()
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for SAM automatic mask generation.

    Each request supplies an image URL. The response lists every mask produced by
    the ``"mask-generation"`` pipeline as a base64-encoded PNG together with its score.
    """

    # Checkpoint loaded by the "mask-generation" pipeline.
    MODEL_ID = "facebook/sam-vit-large"
    # Value forwarded to the pipeline's ``points_per_batch`` argument.
    POINTS_PER_BATCH = 32
    # Seconds requests waits on the image server (connect / read) before giving up.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Build the mask-generation pipeline on CUDA when available, else CPU.

        Args:
            path: Path handed in by the endpoint runtime. Unused: the checkpoint is
                fixed by ``MODEL_ID``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger = logging.get_logger("transformers")
        # WARNING level so the message is visible at transformers' default verbosity.
        self.logger.warning("Using device: %s", self.device)
        self.generator = pipeline("mask-generation", model=self.MODEL_ID, device=self.device)

    def __call__(self, data):
        """Generate masks for the image whose URL is in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` is popped from it (mutating ``data``)
                and used as the image URL; if the key is absent, ``data`` itself is used.

        Returns:
            ``{"data": [...], "time": seconds}``. Each entry of ``"data"`` is
            ``{"score": float, "mask": base64 PNG str, "label": ""}``. In each mask,
            mask pixels are 255 and all others are 0. ``"time"`` is the elapsed
            handling time in seconds.

        Raises:
            requests.HTTPError: the image URL answered with an error status.
            requests.Timeout: the image server exceeded ``REQUEST_TIMEOUT``.
            RuntimeError: OpenCV failed to PNG-encode a mask.
        """
        start = time.perf_counter()
        inputs = data.pop("inputs", data)
        self.logger.warning("got request for %s", inputs)
        raw_image = self._load_image(inputs)
        with torch.no_grad():
            outputs = self.generator(raw_image, points_per_batch=self.POINTS_PER_BATCH)
        # Shape of the RGB image as an array; computed once instead of per mask.
        image_shape = np.asarray(raw_image).shape
        results = [
            {
                "score": score.item(),
                "mask": self._encode_png(self._mask_to_image(mask, image_shape)),
                "label": "",
            }
            for mask, score in zip(outputs["masks"], outputs["scores"])
        ]
        return {"data": results, "time": time.perf_counter() - start}

    def _load_image(self, url):
        """Download ``url`` and decode it as an RGB PIL image."""
        with requests.get(url, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
            # Surface HTTP failures directly instead of letting PIL choke on an error body.
            response.raise_for_status()
            # Undo any Content-Encoding (e.g. gzip) before PIL reads the raw stream.
            response.raw.decode_content = True
            # convert() forces the full decode while the connection is still open.
            return Image.open(response.raw).convert("RGB")

    @staticmethod
    def _mask_to_image(mask, shape):
        """Return a zeroed uint8 array of ``shape`` set to 255 where ``mask`` is True."""
        mask_image = np.zeros(shape, np.uint8)
        # Explicit ``== True`` (rather than indexing with ``mask`` directly) keeps an
        # integer 0/1 mask boolean instead of turning it into fancy integer indices.
        mask_image[mask == True] = 255  # noqa: E712
        return mask_image

    @staticmethod
    def _encode_png(image):
        """PNG-encode ``image`` with OpenCV and return it as an ASCII base64 string."""
        ok, buffer = imencode(".png", image)
        if not ok:
            raise RuntimeError("failed to PNG-encode mask")
        return b64encode(buffer).decode("ascii")