"""Classify images and save Grad-CAM style explanation figures.

For every class in ``classes`` the script renders the Grad-CAM overlay, the
guided-backpropagation map and their combination, and stores the grid as a
PNG. Optionally the input is first cropped to the first box found by a YOLO
detector.
"""

import os
import random
import time
import uuid
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from pytorch_grad_cam import GradCAM, GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import deprocess_image, show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from ultralytics import YOLO

from model.load_model import get_model

# from rembg import remove

# Static variables
model_path = "efficientnet-b0-best.pth"
model_name = "efficientnet_b0"
YOLO_MODEL_WEIGHTS = "yolo-v11-best.pt"
classes = ["Healthy", "Resistant", "Susceptible"]
resizing_transforms = transforms.Compose([transforms.CenterCrop(224)])


# Function definitions
def reproduce(seed=42):
    """Seed Python, NumPy and torch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_grad_cam_results(image, transformed_image, class_index=0):
    """Compute the Grad-CAM map for one class and overlay it on ``image``.

    Args:
        image: Image convertible with ``np.array``; it is scaled by 1/255
            before the overlay, so 8-bit pixel values are assumed.
        transformed_image: Model input tensor without the batch dimension.
        class_index: Index of the class whose activation is explained.

    Returns:
        Tuple ``(visualization, grayscale_cam)``: the overlay produced by
        ``show_cam_on_image`` and the CAM of the single batch item.
    """
    with GradCAM(model=model, target_layers=target_layers) as cam:
        targets = [ClassifierOutputTarget(class_index)]
        grayscale_cam = cam(
            input_tensor=transformed_image.unsqueeze(0), targets=targets
        )
        # The batch holds exactly one image; keep only its map.
        grayscale_cam = grayscale_cam[0, :]
        visualization = show_cam_on_image(
            np.array(image) / 255.0, grayscale_cam, use_rgb=True
        )
    return visualization, grayscale_cam


def get_backpropagation_results(transformed_image, class_index=0):
    """Run guided backpropagation for one class.

    Args:
        transformed_image: Model input tensor without the batch dimension.
        class_index: Index of the class whose gradient is computed.

    Returns:
        Tuple ``(backpropagation, bp_deprocessed)``: the raw output of the
        guided-backprop model and the same data after ``deprocess_image``.
    """
    transformed_image = transformed_image.unsqueeze(0)
    backpropagation = gbp_model(transformed_image, target_category=class_index)
    bp_deprocessed = deprocess_image(backpropagation)
    return backpropagation, bp_deprocessed


def get_guided_gradcam(image, cam_grayscale, bp):
    """Combine a Grad-CAM map with guided backpropagation and overlay it.

    Args:
        image: Image convertible with ``np.array`` (scaled by 1/255 here).
        cam_grayscale: Grad-CAM map; a trailing axis is added and repeated
            three times so it can be multiplied with ``bp``.
        bp: Raw guided-backpropagation output.

    Returns:
        The result of ``show_cam_on_image`` for the combined map.
    """
    cam_mask = np.expand_dims(cam_grayscale, axis=-1)
    cam_mask = np.repeat(cam_mask, 3, axis=-1)
    img = show_cam_on_image(
        np.array(image) / 255.0, deprocess_image(cam_mask * bp), use_rgb=False
    )
    return img


def explain_results(image, class_index=0):
    """Produce all three explanation images of ``image`` for one class.

    Args:
        image: Input image accepted by ``image_transform``.
        class_index: Index of the class to explain.

    Returns:
        Tuple ``(gradcam_overlay, deprocessed_backprop, guided_gradcam)``.
    """
    transformed_image = image_transform(image)
    # Centre-crop the display image to 224x224.
    # NOTE(review): presumably this matches the size produced by
    # image_transform so the overlays line up - confirm.
    image = resizing_transforms(image)

    visualization, cam_mask = get_grad_cam_results(
        image, transformed_image, class_index
    )
    backpropagation, bp_deprocessed = get_backpropagation_results(
        transformed_image, class_index
    )
    guided_gradcam = get_guided_gradcam(image, cam_mask, backpropagation)
    return visualization, bp_deprocessed, guided_gradcam


def make_prediction_and_explain(image):
    """Classify ``image`` and build explanations for every class.

    Args:
        image: Input image accepted by ``image_transform``.

    Returns:
        Dict keyed by class name. Each value holds ``original_image``,
        ``prediction`` (``"<class> (<percent>%)"``), ``gradcam``,
        ``backpropagation`` and ``guided_gradcam``.
    """
    batch = image_transform(image).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(batch), dim=1)

    # Scale first, then round: rounding before multiplying by 100 leaks
    # float noise such as 56.779999999999994 into the displayed label.
    predictions = [round(p * 100, 2) for p in probabilities[0].tolist()]

    results = {}
    for class_index, class_name in enumerate(classes):
        gradcam, bp_deprocessed, guided_gradcam = explain_results(
            image, class_index=class_index
        )
        results[class_name] = {
            "original_image": image,
            "prediction": f"{class_name} ({predictions[class_index]}%)",
            "gradcam": gradcam,
            "backpropagation": bp_deprocessed,
            "guided_gradcam": guided_gradcam,
        }
    return results


def save_explanation_results(res, path):
    """Render the explanation grid (one row per class) and save it.

    Args:
        res: Mapping returned by ``make_prediction_and_explain``.
        path: Destination file passed to ``plt.savefig``.
    """
    # squeeze=False keeps ``ax`` two-dimensional for any number of rows.
    fig, ax = plt.subplots(len(res), 4, figsize=(15, 15), squeeze=False)

    for row, entry in enumerate(res.values()):
        panels = (
            (entry["original_image"],
             f"Original Image (class: {entry['prediction']})"),
            (entry["gradcam"], "GradCAM"),
            (entry["backpropagation"], "Backpropagation"),
            (entry["guided_gradcam"], "Guided GradCAM"),
        )
        for col, (picture, title) in enumerate(panels):
            ax[row, col].imshow(picture)
            ax[row, col].set_title(title)
            ax[row, col].axis("off")

    plt.tight_layout()
    plt.savefig(path, bbox_inches="tight")
    plt.close(fig)


def _load_models():
    """Build the classifier, its explainers and the YOLO detector on CPU.

    Returns:
        Tuple ``(model, image_transform, target_layers, gbp_model,
        yolo_model)``.
    """
    classifier, transform = get_model(model_name)
    # map_location lets weights saved on a GPU load on a CPU-only host.
    classifier.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"))
    )
    # NOTE(review): the model is left in train mode here while
    # make_prediction_and_explain switches it to eval() - confirm intent.
    classifier.train()
    layers = [classifier.conv_head]
    guided = GuidedBackpropReLUModel(model=classifier, device="cpu")
    detector = YOLO(YOLO_MODEL_WEIGHTS)
    return classifier, transform, layers, guided, detector


def _to_rgb(image):
    """Return ``image`` in RGB mode; the overlays assume three channels."""
    return image if image.mode == "RGB" else image.convert("RGB")


def _crop_first_box(image, result):
    """Crop ``image`` to the first box of a YOLO ``result``.

    Returns ``None`` when the result has no boxes, so callers can skip it
    instead of failing with an ``IndexError``.
    """
    boxes = result.boxes
    if boxes is None or len(boxes) == 0:
        return None
    x1, y1, x2, y2 = boxes.xyxy[0].cpu().numpy().astype(int)
    bbox_image = image.crop((x1, y1, x2, y2))
    # Disabled background removal, kept for reference:
    # bbox_image = remove(bbox_image).convert("RGB")
    # bbox_image = Image.fromarray(
    #     np.where(
    #         np.array(bbox_image) == [0, 0, 0],
    #         [255, 255, 255],
    #         np.array(bbox_image),
    #     ).astype(np.uint8)
    # )
    return bbox_image


def _new_result_path():
    """Return a unique PNG path under ``/tmp`` for one explanation figure."""
    return f"/tmp/with-bg-result-{uuid.uuid4().hex}.png"


model, image_transform, target_layers, gbp_model, yolo_model = _load_models()


def get_results(img_path=None, img_for_testing=None, od=False):
    """Explain an image and return the paths of the saved figures.

    Args:
        img_path: Path of the image to open.
        img_for_testing: Image array passed to ``Image.fromarray``. Takes
            precedence over ``img_path`` when both are supplied.
        od: When true, run the YOLO detector and explain the crop of the
            first box of each result instead of the whole image.

    Returns:
        List of saved PNG paths. It is empty when ``od`` is true and the
        detector finds no box.

    Raises:
        ValueError: If neither ``img_path`` nor ``img_for_testing`` is given.
    """
    if img_path is None and img_for_testing is None:
        raise ValueError("Either img_path or img_for_testing should be provided.")

    # Use one source for both the detector and the crop so the box
    # coordinates always refer to the image being cropped.
    if img_for_testing is not None:
        source = img_for_testing
        image = Image.fromarray(img_for_testing)
    else:
        source = img_path
        image = Image.open(img_path)
    image = _to_rgb(image)

    if not od:
        save_path = _new_result_path()
        save_explanation_results(make_prediction_and_explain(image), save_path)
        return [save_path]

    result_paths = []
    # NOTE(review): ultralytics documents numpy inputs as BGR, whereas
    # Image.fromarray reads the array as RGB - confirm the channel order.
    for result in yolo_model(source):
        bbox_image = _crop_first_box(image, result)
        if bbox_image is None:
            continue
        save_path = _new_result_path()
        save_explanation_results(
            make_prediction_and_explain(bbox_image), save_path
        )
        result_paths.append(save_path)
    return result_paths


if __name__ == "__main__":
    # Actual logic
    reproduce()

    # Rebuild the models after seeding, as the script did originally.
    model, image_transform, target_layers, gbp_model, yolo_model = _load_models()

    # savefig does not create missing directories.
    os.makedirs("sample-results", exist_ok=True)

    for IMAGE_PATH in glob("samples/*"):
        start = time.perf_counter()

        results = yolo_model(IMAGE_PATH)
        image = _to_rgb(Image.open(IMAGE_PATH))

        for i, result in enumerate(results):
            bbox_image = _crop_first_box(image, result)
            if bbox_image is None:
                continue
            save_path = IMAGE_PATH.replace(
                "samples/", f"sample-results/with-white-bg-result-{i:02d}-"
            )
            res = make_prediction_and_explain(bbox_image)
            save_explanation_results(res, save_path)

        elapsed = time.perf_counter() - start
        print(f"Completed in {elapsed}s")