|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import json |
|
import gradio as gr |
|
from model import resnet50 |
|
|
|
|
|
def load_class_labels(label_path):
    """Load the class-index -> label mapping from a JSON file.

    Args:
        label_path: Path to a JSON file holding either an object whose keys
            are class indices as strings, or a plain array whose position is
            the class index.

    Returns:
        A dict keyed by the class index as a string ("0", "1", ...). An
        object-form file is returned as parsed; an array-form file is
        converted to that same keying.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which would mis-decode non-ASCII labels on some systems.
    with open(label_path, 'r', encoding='utf-8') as f:
        class_labels = json.load(f)
    # A bare list of labels is indexed by position; key it by the stringified
    # index so str(index) lookups work for both file layouts.
    if isinstance(class_labels, list):
        class_labels = {str(i): label for i, label in enumerate(class_labels)}
    return class_labels
|
|
|
|
|
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    Checkpoints saved from a model wrapped in ``torch.nn.DataParallel`` name
    parameters ``module.<name>``, while the bare model expects ``<name>``.
    Only a *leading* prefix is removed, so a key that merely contains the
    prefix elsewhere is left untouched. An empty dict yields an empty dict.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(model_path):
    """Build the ResNet-50 and load trained weights from *model_path*.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``; assumed
            to contain a plain ``state_dict`` mapping (TODO confirm for any
            new checkpoint format).

    Returns:
        The model returned by ``resnet50()`` with the weights loaded and
        ``model.eval()`` applied.
    """
    model = resnet50()
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (torch >= 1.13 accepts weights_only=True to restrict what is unpickled).
    state_dict = torch.load(model_path, map_location='cpu')
    # Strip the DataParallel prefix key-by-key; a no-op for unwrapped checkpoints.
    model.load_state_dict(_strip_module_prefix(state_dict))
    model.eval()
    return model
|
|
|
|
|
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched tensor for the model.

    The image is forced to three RGB channels, resized so its shorter side is
    256 px, center-cropped to 224x224, converted to a tensor, normalized per
    channel, and given a leading batch dimension of size 1.
    """
    # Per-channel statistics passed to Normalize (the usual ImageNet values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])
    rgb = image.convert("RGB")
    # Add the batch axis: the model is called on a batch holding one image.
    return pipeline(rgb).unsqueeze(0)
|
|
|
|
|
def _label_name(labels, index):
    """Return a hashable display name for class *index* from *labels*.

    *labels* is looked up by ``str(index)``; missing entries become
    "Unknown class". Some label files (e.g. the widely used
    ``imagenet_class_index.json``) store ``[wnid, name]`` lists, which cannot
    be dict keys, so the last element (the readable name) is used.
    """
    entry = labels.get(str(index), "Unknown class")
    if isinstance(entry, (list, tuple)):
        entry = entry[-1] if entry else "Unknown class"
    return str(entry)


def _top_k_results(probabilities, labels, k=5):
    """Map the *k* most probable classes to ``{display_name: probability}``.

    Args:
        probabilities: Tensor indexed as ``[0][class]`` (batch of one).
        labels: Mapping from stringified class index to label entry.
        k: Number of classes to report; clamped to the number of classes.

    Returns:
        A dict in descending probability order, as gr.Label expects.
    """
    k = min(k, probabilities.shape[-1])
    top_probs, top_indices = probabilities.topk(k)
    results = {}
    for prob, index in zip(top_probs[0].tolist(), top_indices[0].tolist()):
        name = _label_name(labels, index)
        # Distinct classes can share a display name (ImageNet has two "crane"
        # classes) or both fall back to "Unknown class"; suffix the index so a
        # later prediction doesn't silently overwrite an earlier one.
        if name in results:
            name = f"{name} ({index})"
        results[name] = prob
    return results


def predict(image):
    """Classify *image* and return its top-5 classes for the Gradio Label output.

    Uses the module-level ``model`` and ``class_labels``.

    Args:
        image: A PIL image, or None when no image was provided.

    Returns:
        ``{label: probability}`` for up to 5 classes; empty dict for no image.
    """
    # gr.Image delivers None when the user submits without an upload.
    if image is None:
        return {}
    image_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(image_tensor)
    probabilities = torch.nn.functional.softmax(output, dim=1)
    return _top_k_results(probabilities, class_labels, k=5)
|
|
|
|
|
# Checkpoint and label-file paths; relative paths resolve against the current
# working directory the app is started from.
model_path = 'model.pt'
label_path = 'imagenet_class_index.json'

# Loaded once at import time and read by every predict() call.
model = load_model(model_path)
class_labels = load_class_labels(label_path)

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Classification using ResNet-50 Model",
    description="Upload an image to get the top-5 predictions from the ResNet-50 model.",
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests) does not block on iface.launch().
if __name__ == "__main__":
    iface.launch()