import os
import urllib.request
import warnings
from typing import Tuple

import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from segment_anything import sam_model_registry
from segment_anything.modeling import Sam
from segment_anything.utils.amg import calculate_stability_score
from torch.nn import functional as F

CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export script
    for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.model.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding

    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(
        self, masks: torch.Tensor, orig_im_size: torch.Tensor
    ) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
            torch.int64
        )
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Choose between the single-click and multi-click mask based on the number
        # of points. The score reweighting replaces an explicit branch so the
        # selection remains traceable for ONNX export.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks


def load_model(
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_name: str = CHECKPOINT_NAME,
    checkpoint_url: str = CHECKPOINT_URL,
    model_type: str = MODEL_TYPE,
) -> Sam:
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpoint = os.path.join(checkpoint_path, checkpoint_name)
    if not os.path.exists(checkpoint):
        print("Downloading the model weights...")
        urllib.request.urlretrieve(checkpoint_url, checkpoint)
        print(f"Model weights saved to {checkpoint}")
    print(f"Loading the model weights from {checkpoint}")
    return sam_model_registry[model_type](checkpoint=checkpoint)


if __name__ == "__main__":
    sam = load_model()
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(
            low=0, high=1024, size=(1, 5, 2), dtype=torch.float
        ),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            "sam_decoder.onnx",
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    quantize_dynamic(
        model_input="sam_decoder.onnx",
        model_output="sam_decoder_uint8.onnx",
        optimize_model=True,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )

    onnx.checker.check_model("sam_decoder_uint8.onnx")
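
    # Optional sanity check (a minimal sketch, not part of the original script):
    # load the quantized decoder with ONNX Runtime and run it once on the same
    # dummy inputs used for export, to confirm the graph loads and produces the
    # three expected outputs. Assumes onnxruntime is importable, which the
    # quantization step above already requires.
    import onnxruntime as ort

    session = ort.InferenceSession(
        "sam_decoder_uint8.onnx", providers=["CPUExecutionProvider"]
    )
    # Input names match the keys of dummy_inputs because they were passed as
    # input_names to torch.onnx.export above.
    ort_inputs = {name: tensor.numpy() for name, tensor in dummy_inputs.items()}
    masks_out, iou_out, low_res_out = session.run(None, ort_inputs)
    print(f"masks: {masks_out.shape}, iou_predictions: {iou_out.shape}")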