"""Flask service that runs Segment Anything (SAM) automatic mask generation.

POST / with a JSON body ``{"url": "<image url>"}`` downloads the image, runs
``SamAutomaticMaskGenerator.generate`` on it and returns every mask as a
base64-encoded PNG together with its predicted IoU score.
GET /health is a liveness probe.
"""
import time
from base64 import b64encode
from io import BytesIO

import numpy as np
import requests
import torch
from cv2 import imencode
from flask import Flask, jsonify, request
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Seconds to wait on the remote image server (connect and read) before giving up.
FETCH_TIMEOUT_SECONDS = 30

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# The model is loaded once at import time and shared by every request.
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
print("Loaded model")

app = Flask(__name__)


def _fetch_image(url):
    """Download *url* and return it as an RGB ``uint8`` array of shape (H, W, 3).

    Raises:
        requests.RequestException: the download failed, timed out, or the
            server answered with an HTTP error status.
        OSError: the payload is not a decodable image (PIL's
            ``UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # NOTE(review): this fetches an arbitrary caller-supplied URL (SSRF risk);
    # restrict schemes/hosts if the service is reachable by untrusted clients.
    with requests.get(url, timeout=FETCH_TIMEOUT_SECONDS) as response:
        response.raise_for_status()
        # Decode from .content rather than .raw so that any Content-Encoding
        # (e.g. gzip) is undone by requests before PIL sees the bytes.
        with Image.open(BytesIO(response.content)) as img:
            return np.array(img.convert("RGB"))


def _encode_mask(segmentation, shape):
    """Return *segmentation* rendered as a base64 (ASCII) PNG string.

    Pixels inside the mask are 255, all others 0. The image is allocated with
    *shape*, which callers pass as the full (H, W, 3) input-image shape, so the
    PNG is 3-channel -- this matches the payload the service has always
    produced (a single-channel mask would be a third of the size).

    Raises:
        RuntimeError: OpenCV reported that PNG encoding failed.
    """
    mask_image = np.zeros(shape, np.uint8)
    # Boolean indexing with an (H, W) mask sets all channels of those pixels.
    mask_image[np.asarray(segmentation, dtype=bool)] = 255
    ok, buffer = imencode('.png', mask_image)
    if not ok:
        raise RuntimeError("PNG encoding of mask failed")
    return b64encode(buffer).decode("ascii")


@app.route('/', methods=['POST'])
def index():
    """Segment the image at the JSON body's ``url`` and return all masks.

    Success response::

        {"data": [{"label": "", "mask": <base64 PNG>, "score": <predicted IoU>}, ...],
         "time": <seconds spent handling the request>}

    Error responses carry ``{"error": <message>}``:
        400 -- body is not a JSON object with a non-empty string ``url``, or
               the downloaded bytes are not a decodable image.
        502 -- the image could not be downloaded.
    """
    app.logger.info('Got request !')
    start = time.perf_counter()

    # silent=True: a missing/invalid JSON body yields None instead of raising.
    payload = request.get_json(silent=True)
    url = payload.get('url') if isinstance(payload, dict) else None
    app.logger.info('Got request for url %s', url)
    if not isinstance(url, str) or not url:
        return jsonify({"error": "JSON body must contain a non-empty string 'url'"}), 400

    # requests.RequestException subclasses OSError, so it must be caught first.
    try:
        image = _fetch_image(url)
    except requests.RequestException as err:
        app.logger.warning('Could not download %s: %s', url, err)
        return jsonify({"error": f"could not download image: {err}"}), 502
    except OSError as err:
        app.logger.warning('Could not decode image from %s: %s', url, err)
        return jsonify({"error": f"could not decode image: {err}"}), 400

    masks = mask_generator.generate(image)
    data = [
        {
            "label": "",
            "mask": _encode_mask(mask["segmentation"], image.shape),
            "score": mask["predicted_iou"],
        }
        for mask in masks
    ]

    return jsonify({
        "data": data,
        "time": time.perf_counter() - start,
    })


@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always reports success once the app is serving."""
    return jsonify({
        "success": True
    })


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)