import os
import gradio as gr
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import urllib.request
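# Evaluation-time preprocessing: resize the shorter side to 350 px,
# center-crop to 256x256, then apply ImageNet normalization.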
test_transforms = A.Compose(
    [
        A.SmallestMaxSize(max_size=350),
        A.CenterCrop(height=256, width=256),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)
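# Collect the sample images bundled with the Space; they are passed to
# the Gradio interface as clickable examples.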
print(os.getcwd())
print(os.listdir(os.getcwd()))

img_samples = os.listdir('./sample/')
img_samples = [os.path.join(os.getcwd(), 'sample', img) for img in img_samples]
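# Download the trained checkpoint from the Hugging Face Hub to a local path.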
MODEL_URL = "https://huggingface.co/caisarl76/HI_motorcycle_trunk_cls_model/resolve/main/best_model.pth"
MODEL_PATH = "/tmp/best_model.pth"
urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
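# Run a single image through the model and return per-class probabilities.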
def predict(img):
    # Gradio provides the image as a NumPy array (H, W, C); apply the same
    # preprocessing pipeline defined in test_transforms above.
    img = img.astype('uint8')
    tensor = test_transforms(image=img)["image"].unsqueeze(0)
    with torch.no_grad():
        pred = torch.nn.functional.softmax(model(tensor)[0], dim=0)
    return {labels[i]: float(pred[i]) for i in range(2)}
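# Binary classifier: ResNet-50 backbone with a dropout + linear two-class head.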
labels = ['no_trunk', 'trunk']

model = models.resnet50(weights=None)  # weights come from the checkpoint loaded below
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 2)
)
device = torch.device('cpu')
try:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    print('model load complete')
except Exception as e:
    print(f'CANNOT load model weight: {e}')
model.eval()
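# Inference only: freeze all model parameters.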
for p in model.parameters():
    p.requires_grad = False
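# Build the Gradio demo: image input, top-1 label output, sample images as examples.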
inputs = gr.Image()
outputs = gr.Label(num_top_classes=1)

gr.Interface(fn=predict,
             inputs=inputs,
             outputs=outputs,
             examples=img_samples).launch()