MilindChawre's picture
Adding code for hosting huggingface app for resnet50 image classifier
12dc8e3
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
import gradio as gr
from model import resnet50 # Import the model definition
# Load class index mapping
def load_class_labels(label_path):
with open(label_path, 'r') as f:
class_labels = json.load(f)
return class_labels
# Load the model
def load_model(model_path):
model = resnet50() # Create an instance of the ResNet-50 model
state_dict = torch.load(model_path, map_location='cpu') # Load the state dictionary
# Remove "module." prefix if present
if 'module.' in next(iter(state_dict.keys())):
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict) # Load the state dictionary
model.eval() # Set the model to evaluation mode
return model
# Preprocess the input image
def preprocess_image(image):
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image = image.convert("RGB") # Ensure image is in RGB format
image = preprocess(image) # Apply transformations
image = image.unsqueeze(0) # Add batch dimension
return image
# Prediction function
def predict(image):
image_tensor = preprocess_image(image)
with torch.no_grad():
output = model(image_tensor)
probabilities = torch.nn.functional.softmax(output, dim=1)
top5_probabilities, top5_indices = probabilities.topk(5)
results = {}
for i in range(5):
class_index = top5_indices[0][i].item()
class_label = class_labels.get(str(class_index), "Unknown class")
results[class_label] = top5_probabilities[0][i].item() # Store label and probability in a dictionary
return results # Return the results as a dictionary
# Load model and class labels
model_path = 'model.pt' # Path to the trained model
label_path = 'imagenet_class_index.json' # Path to the class index mapping
model = load_model(model_path)
class_labels = load_class_labels(label_path)
# Create Gradio interface
iface = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=gr.Label(num_top_classes=5),
title="Image Classification using ResNet-50 Model",
description="Upload an image to get the top-5 predictions from the ResNet-50 model."
)
# Launch the app
iface.launch()