MedImageInsight: Open-Source Medical Image Embedding Model
This repository provides a simplified implementation of MedImageInsight, an open-source medical imaging embedding model presented in the paper "MedImageInsight: An Open-Source Embedding Model for General Domain Medical Imaging" by Noel C. F. Codella et al. Microsoft's official guide for accessing the model is quite complicated, and it is debatable whether the model is truly open-source. This repository aims to make MedImageInsight easier to use for tasks such as zero-shot classification, image embedding, and text embedding.
What we have done:
- Downloaded the models from Azure
- Removed all unnecessary files
- Removed unnecessary MLflow code to make the implementation standalone
- Moved to uv for dependency management
- Added multi-label classification
- Added an example FastAPI service
Usage
- Clone the repository and navigate to the project directory.
Make sure you have git-lfs installed (https://git-lfs.com).
git lfs install
git clone https://huggingface.co/lion-ai/MedImageInsights
cd MedImageInsights
- Install the required dependencies. We use the uv package manager to simplify installation.
To create a virtual environment and install the dependencies, run:
uv sync
Or to run a single script, just run:
uv run example.py
That's it!
Examples
See the example.py file.
Zero-shot image classification
Here's an example of how to use the MedImageInsight class for zero-shot classification:
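import base64

# MedImageInsight is the wrapper class provided by this repository; import it as example.py does.

# Helper (assumed here; example.py may define its own): read an image file as raw bytes
def read_image(path):
    with open(path, "rb") as f:
        return f.read()

# Initialize classifier
classifier = MedImageInsight(
    model_dir="2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)
# Load model weights
classifier.load_model()
# Read the image and base64-encode it as a UTF-8 string
image = base64.encodebytes(read_image("image.png")).decode("utf-8")
# Zero-shot classification: score the image against each candidate label
images = [image]
labels = ["normal", "Pneumonia", "unclear"]
results = classifier.predict(images, labels)
print(results)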
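The README does not spell out how predict turns embeddings into label probabilities. Conceptually, CLIP-style zero-shot classification embeds the image and each label text into a shared space, scores each label by cosine similarity, and applies a softmax across labels. The sketch below shows that idea in plain Python with made-up toy vectors. The helper names and the `scale` value (borrowed from CLIP's convention of capping its logit scale at 100) are assumptions, not this repository's code.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_scores(image_emb, label_embs, labels, scale=100.0):
    """Score one image against candidate labels (CLIP-style sketch).

    `scale` is an assumed logit scale; the value MedImageInsight uses may differ.
    """
    logits = [scale * cosine_similarity(image_emb, t) for t in label_embs]
    return dict(zip(labels, softmax(logits)))


# Toy 3-dimensional embeddings, made up purely for illustration.
image_emb = [0.9, 0.1, 0.0]
label_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
labels = ["normal", "Pneumonia", "unclear"]
scores = zero_shot_scores(image_emb, label_embs, labels)
print(max(scores, key=scores.get))  # normal
```

Because the softmax normalizes across the label list, the scores always sum to 1, so adding or removing a label changes every other label's score.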
Multi-label zero-shot image classification
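For multi-label classification, pass multilabel=True. The final softmax is skipped, so the labels do not compete with each other and several of them can apply to the same image.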
# Multilabel classification example
images = [image]
labels = ["normal", "Pneumonia", "Fracture", "Tumor"]
results = classifier.predict(images, labels, multilabel=True)
print(results)
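Skipping the softmax matters because softmax forces label scores to compete and sum to 1: an image showing two findings gets roughly half the probability mass on each. The sketch below contrasts the two approaches on hypothetical per-label logits. Scoring each label independently with a sigmoid and a 0.5 threshold is an assumption made for illustration; the README only states that the softmax is omitted.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(z):
    """Logistic function; branches on the sign so math.exp never overflows."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


labels = ["normal", "Pneumonia", "Fracture", "Tumor"]
# Hypothetical logits for an image showing both pneumonia and a fracture.
logits = [-4.0, 3.0, 3.0, -2.0]

single = dict(zip(labels, softmax(logits)))           # labels compete
multi = {lab: sigmoid(z) for lab, z in zip(labels, logits)}  # scored independently

print({k: round(v, 2) for k, v in single.items()})
print([lab for lab, p in multi.items() if p > 0.5])  # ['Pneumonia', 'Fracture']
```

With softmax, Pneumonia and Fracture each land just under 0.5; with independent scoring, both clear the threshold on their own.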
Image embeddings
results = classifier.encode(images=images)
print(results["image_embeddings"])
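A common use of image embeddings is similarity search, for example finding previously seen scans that resemble a new one. The sketch below ranks a small gallery by cosine similarity to a query. It assumes each embedding is a flat list of floats (one per image); the toy vectors stand in for encode(images=...)["image_embeddings"], and the function names are hypothetical.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length so dot products equal cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def most_similar(query, gallery, k=1):
    """Return indices of the k gallery embeddings closest to `query` by cosine similarity."""
    q = l2_normalize(query)
    sims = [sum(a * b for a, b in zip(q, l2_normalize(g))) for g in gallery]
    return sorted(range(len(gallery)), key=lambda i: sims[i], reverse=True)[:k]


# Toy embeddings standing in for real image embeddings.
gallery = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
query = [0.5, 0.9]
print(most_similar(query, gallery, k=2))  # [1, 2]
```

For a large gallery you would normalize the stored embeddings once and use a vector index instead of this linear scan.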
Text embeddings
results = classifier.encode(texts=labels)
print(results["text_embeddings"])
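Because zero-shot classification compares text and image embeddings directly, the two presumably share one space and one dimensionality. Under that assumption, a text embedding can also rank images for a free-text query (text-to-image retrieval). In the sketch below the vectors are toy stand-ins for encode(texts=...) and encode(images=...) outputs, and the function names are hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def rank_images_for_text(text_emb, image_embs):
    """Indices of images sorted from most to least similar to a text embedding."""
    scores = [cosine(text_emb, img) for img in image_embs]
    return sorted(range(len(image_embs)), key=scores.__getitem__, reverse=True)


# Toy vectors, e.g. text_emb could stand for the embedding of "Pneumonia".
text_emb = [0.0, 1.0, 0.2]
image_embs = [[1.0, 0.0, 0.0], [0.1, 0.9, 0.1], [0.0, 0.5, 0.5]]
print(rank_images_for_text(text_emb, image_embs))  # [1, 2, 0]
```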
FastAPI server
uv run fastapi_app.py
Open localhost:8000/docs to view the Swagger UI.
The application provides endpoints for classification and image embeddings. Images must be base64-encoded.
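A minimal client sketch using only the Python standard library. The /predict path and the images/labels JSON fields are hypothetical, chosen to mirror classifier.predict(images, labels); check the Swagger page for the routes and request schemas that fastapi_app.py actually exposes.

```python
import base64
import json
import urllib.request

# Hypothetical route: replace with the path listed at http://localhost:8000/docs.
BASE_URL = "http://localhost:8000"
ENDPOINT = "/predict"


def encode_image(image_bytes: bytes) -> str:
    """Base64-encode raw image bytes into an ASCII string for a JSON body."""
    return base64.b64encode(image_bytes).decode("ascii")


def build_request(image_bytes: bytes, labels: list[str]) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for a classification call.

    The "images" and "labels" field names are assumptions, not the server's schema.
    """
    payload = {"images": [encode_image(image_bytes)], "labels": labels}
    return urllib.request.Request(
        BASE_URL + ENDPOINT,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```

To send it, start the server with uv run fastapi_app.py and pass the request to urllib.request.urlopen. Note that base64.b64encode produces a single line, whereas base64.encodebytes in the zero-shot example inserts a newline every 76 characters.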
Roadmap
- Basic implementation
- Multilabel classification
- FastAPI service
- Hugging Face-compatible API (from_pretrained())
- Explainability
Acknowledgments
This repository is based on the work presented in the paper "MedImageInsight: An Open-Source Embedding Model for General Domain Medical Imaging" by Noel C. F. Codella et al. (arXiv:2410.06542).