# Commit 2a50db8 by MNCJihun: "add sample images"
import os
import urllib.request

import albumentations as A
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from albumentations.pytorch import ToTensorV2
from PIL import Image
# Inference-time preprocessing, applied in order to an HWC image array.
_test_steps = [
    # Rescale so the shorter side is 350 px (aspect ratio preserved).
    A.SmallestMaxSize(max_size=350),
    # Take the central 256x256 patch.
    A.CenterCrop(height=256, width=256),
    # Per-channel normalisation with the standard ImageNet mean / std.
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    # HWC numpy array -> CHW torch tensor.
    ToTensorV2(),
]
test_transforms = A.Compose(_test_steps)
# Example images offered under the Gradio demo. The listing is sorted because
# os.listdir order is arbitrary. Hidden entries (e.g. .DS_Store, .gitkeep) and
# sub-directories are skipped so they are never presented as examples.
SAMPLE_DIR = os.path.join(os.getcwd(), 'sample')
img_samples = sorted(
    os.path.join(SAMPLE_DIR, name)
    for name in os.listdir(SAMPLE_DIR)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(SAMPLE_DIR, name))
)

# Classifier weights hosted on the Hugging Face Hub. They are fetched into
# /tmp unconditionally each time this module runs.
MODEL_URL = "https://huggingface.co/caisarl76/HI_motorcycle_trunk_cls_model/resolve/main/best_model.pth"
MODEL_PATH = "/tmp/best_model.pth"
urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
def predict(img):
    """Classify one image with the trunk / no-trunk model.

    Args:
        img: Image array handed over by the Gradio ``Image`` input
            (presumably an H x W [x C] numeric array -- TODO confirm for the
            installed Gradio version).

    Returns:
        dict mapping every name in ``labels`` to its softmax probability,
        which is what the ``Label`` output component displays.
    """
    # Normalise to exactly three RGB channels. Grayscale or RGBA uploads would
    # otherwise break Normalize, which is configured with three means / stds.
    rgb = np.asarray(Image.fromarray(img.astype('uint8')).convert('RGB'))
    # Run the shared test_transforms pipeline (resize, center-crop, ImageNet
    # normalisation, CHW tensor) instead of a bare ToTensor(). This assumes the
    # weights were trained with this normalisation -- confirm against training.
    batch = test_transforms(image=rgb)['image'].unsqueeze(0)
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair names with logits positionally rather than hard-coding the count.
    return {label: float(p) for label, p in zip(labels, probs)}
# Output classes, in the order predict() pairs them with the model's logits.
labels = ['no_trunk', 'trunk']

# ResNet-50 whose final layer is replaced by dropout + a 2-way linear head.
# Weights are loaded from MODEL_PATH below rather than from ImageNet.
model = models.resnet50(pretrained=False)
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 2)
)
device = torch.device('cpu')
try:
    # NOTE(security): torch.load unpickles its input, and MODEL_PATH was just
    # downloaded over the network. Consider weights_only=True if the installed
    # torch version supports it.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
except Exception as err:
    # Best-effort: keep the app up, but say why loading failed. Catching
    # Exception (not everything) lets KeyboardInterrupt / SystemExit through.
    # If this branch runs, the classification head keeps its random init.
    print(f'CANNOT load model weight: {err}')
else:
    print('model load complete')
model.eval()  # inference mode: disables the Dropout in the head
model.requires_grad_(False)  # no parameter needs gradients for serving
# Gradio front end: an image upload goes in, and the top-1 label with its
# probability comes out. The sample images are offered as clickable examples.
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=1)
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=img_samples,
)
demo.launch()