Spaces:
Running
Running
import os | |
import onnxruntime | |
from insightface.model_zoo import SCRFD | |
from insightface.model_zoo import ArcFaceONNX | |
from insightface.app.common import Face | |
import numpy as np | |
import cv2 | |
import json | |
from typing import Annotated, List | |
from dataclasses import dataclass, field | |
from fastapi import FastAPI, status, HTTPException, Depends, UploadFile | |
from fastapi.middleware.gzip import GZipMiddleware | |
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | |
# API keys accepted as bearer tokens, parsed once at import time from the
# JSON-encoded KEYS environment variable.
_keys_json = os.getenv('KEYS')
if _keys_json is None:
    # json.loads(None) would fail with an opaque TypeError; fail fast with a
    # message that names the missing setting instead.
    raise RuntimeError("KEYS environment variable is not set; expected a JSON list of API keys")
keys = json.loads(_keys_json)
if isinstance(keys, str):
    # A bare JSON string would turn `credentials in keys` into a substring
    # test that accepts any fragment of the key; treat it as one key.
    keys = [keys]
def verify_credentials(credentials: str, valid_keys=None):
    """Raise 401 Unauthorized unless *credentials* matches a configured API key.

    Args:
        credentials: Bearer token taken from the request's Authorization header.
        valid_keys: Accepted keys; defaults to the module-level ``keys`` parsed
            from the ``KEYS`` environment variable. A single string is treated
            as one key rather than as a sequence of characters. Non-string
            entries never match (a str token is never equal to them).

    Raises:
        HTTPException: With status 401 when no key matches.
    """
    import secrets  # constant-time comparison for secret values

    if valid_keys is None:
        valid_keys = keys
    if isinstance(valid_keys, str):
        # `credentials in "<key>"` would be a substring test that accepts any
        # fragment of the key.
        valid_keys = [valid_keys]
    presented = credentials.encode()
    # Compare against every key with compare_digest and never stop early, so
    # response timing does not reveal how close a guessed token was.
    matched = False
    for key in valid_keys:
        if isinstance(key, str) and secrets.compare_digest(presented, key.encode()):
            matched = True
    if not matched:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
# Both ONNX model files are resolved against the process's working directory
# (abspath of a relative name joins it onto os.getcwd()).
detection_model_file = os.path.abspath("det.onnx")
recognition_model_file = os.path.abspath("rec.onnx")
# A negative ctx_id makes insightface run the ONNX sessions on CPU.
_CPU_CTX_ID = -1
# Square input resolution the SCRFD detector is prepared with.
_DETECTION_INPUT_SIZE = (640, 640)

# Face detector (bounding boxes + keypoints), loaded at import time.
detection = SCRFD(model_file=detection_model_file)
detection.prepare(input_size=_DETECTION_INPUT_SIZE, ctx_id=_CPU_CTX_ID)

# Face embedding model, loaded at import time.
recognition = ArcFaceONNX(model_file=recognition_model_file)
recognition.prepare(ctx_id=_CPU_CTX_ID)
def read_image_from_spooled_temporary_file(temp_file):
    """Decode an uploaded file object into an OpenCV color (BGR) image.

    The stream is rewound first, so the whole file is decoded regardless of
    any earlier reads.

    Args:
        temp_file: A readable, seekable binary file object, e.g. the
            spooled temporary file behind a FastAPI ``UploadFile``.

    Returns:
        The decoded image as a NumPy array, or ``None`` when the file is empty
        or its bytes cannot be decoded as an image.
    """
    temp_file.seek(0)
    data = temp_file.read()
    if not data:
        # cv2.imdecode raises on an empty buffer instead of returning None;
        # normalise so callers only need a single "not an image" check.
        return None
    return cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
# ASGI application; responses may be gzip-compressed for clients that accept
# it (subject to GZipMiddleware's default size threshold).
app = FastAPI(version="1.0.0", title="Memoir Face")
app.add_middleware(GZipMiddleware)

# Bearer-token extractor injected into the route handlers via Depends();
# auto_error=True rejects requests lacking a bearer Authorization header
# before the handler body runs.
security = HTTPBearer(auto_error=True)
@dataclass
class Point:
    """A 2-D point, serialized by the handlers as ``{"x": ..., "y": ...}``."""

    x: float
    y: float
@dataclass
class BoundingBox:
    """Box described by its top-left and bottom-right corner points."""

    topLeft: Point
    bottomRight: Point
@dataclass
class FaceDetection:
    """One detected face: detector score, bounding box and keypoints.

    Field names match the JSON keys emitted by the detection handler.
    """

    detectionScore: float
    boundingBox: BoundingBox
    keyPoints: List[Point]
@dataclass
class FaceRecognition(FaceDetection):
    """A detected face plus the embedding vector from the recognition model."""

    embedding: List[float]
@dataclass
class FaceDetectionResult:
    """Detection response body: ``{"faces": [...]}``."""

    faces: List[FaceDetection]
@dataclass
class FaceRecognitionResult:
    """Recognition response body: ``{"faces": [...]}`` with embeddings."""

    faces: List[FaceRecognition]
# NOTE(review): no route decorator is visible on this handler; confirm it is
# registered on `app`.
async def v1_detection(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
    photo: UploadFile
):
    """Detect faces in an uploaded photo.

    Requires a valid bearer token. Returns ``{"faces": [...]}`` where each
    entry has ``detectionScore``, ``boundingBox`` (``topLeft`` /
    ``bottomRight`` corners) and ``keyPoints`` (empty when the detector
    produced no keypoints). Responds 400 when the upload is not a decodable
    image and 503 when the detection model is unavailable.
    """
    # NOTE(review): detection.detect() is synchronous CPU work inside an async
    # handler, so it blocks the event loop while it runs.
    verify_credentials(credentials.credentials)
    if not detection:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
    image = read_image_from_spooled_temporary_file(photo.file)
    if image is None:
        # cv2.imdecode yields None for bytes it cannot decode; without this
        # check the detector fails on None and the client sees a 500.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded photo could not be decoded as an image",
        )
    bboxes, kpss = detection.detect(image)
    faces = []
    # Each bbox row is [x1, y1, x2, y2, score]; an empty result yields [].
    for i, row in enumerate(bboxes):
        x1, y1, x2, y2, det_score = row[:5].tolist()
        # kpss may be None (no keypoints from the model); report an empty
        # list rather than failing while iterating None.
        if kpss is None:
            key_points = []
        else:
            key_points = [{"x": pt[0], "y": pt[1]} for pt in kpss[i].tolist()]
        faces.append({
            "detectionScore": det_score,
            "boundingBox": {
                "topLeft": {"x": x1, "y": y1},
                "bottomRight": {"x": x2, "y": y2},
            },
            "keyPoints": key_points,
        })
    return {"faces": faces}
# NOTE(review): no route decorator is visible on this handler; confirm it is
# registered on `app`.
async def v1_recognition(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
    photo: UploadFile
):
    """Detect faces in an uploaded photo and compute an embedding for each.

    Requires a valid bearer token. Returns ``{"faces": [...]}``; every entry
    carries ``boundingBox``, ``detectionScore``, ``keyPoints`` and the raw
    ``embedding`` vector from the recognition model. Responds 400 when the
    upload is not a decodable image and 503 when a model is unavailable.
    """
    # NOTE(review): detection and embedding run synchronously inside this
    # async handler, so they block the event loop while they run.
    verify_credentials(credentials.credentials)
    if not detection or not recognition:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
    image = read_image_from_spooled_temporary_file(photo.file)
    if image is None:
        # cv2.imdecode yields None for bytes it cannot decode; reject the
        # upload as a client error instead of crashing in the detector.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded photo could not be decoded as an image",
        )
    bboxes, kpss = detection.detect(image)
    faces = []
    # Each bbox row is [x1, y1, x2, y2, score]; an empty result yields [].
    for i, row in enumerate(bboxes):
        face = Face(
            bbox=row[:4],
            kps=None if kpss is None else kpss[i],
            det_score=row[4],
        )
        # recognition.get() populates face.embedding (read below); it
        # presumably relies on face.kps to align the crop — confirm before
        # using a detector that returns no keypoints.
        recognition.get(image, face)
        x1, y1, x2, y2 = face.bbox.tolist()
        faces.append({
            "boundingBox": {
                "topLeft": {"x": x1, "y": y1},
                "bottomRight": {"x": x2, "y": y2},
            },
            "detectionScore": face.det_score.item(),
            "keyPoints": [{"x": pt[0], "y": pt[1]} for pt in face.kps.tolist()],
            # tolist() yields the same Python floats as per-element .item().
            "embedding": face.embedding.ravel().tolist(),
        })
    return {"faces": faces}