NemesisAlm committed on
Commit 0b0d380 · 1 Parent(s): dd4fb66

1st commit

Files changed (8)
  1. 0.jpg +0 -0
  2. 1.jpg +0 -0
  3. 2.jpg +0 -0
  4. 3.jpg +0 -0
  5. app.py +87 -0
  6. favicon.ico +0 -0
  7. logo_gradio.png +0 -0
  8. requirements.txt +4 -0
0.jpg ADDED
1.jpg ADDED
2.jpg ADDED
3.jpg ADDED
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+
+ import torch
+ from PIL import Image
+ from transformers import CLIPModel, CLIPProcessor
+
+ LIST_LABELS = ['agricultural land', 'airplane', 'baseball diamond', 'beach', 'buildings', 'chaparral', 'dense residential area', 'forest', 'freeway', 'golf course', 'harbor', 'intersection', 'medium residential area', 'mobilehome park', 'overpass', 'parking lot', 'river', 'runway', 'sparse residential area', 'storage tanks', 'tennis court']
+
+ CLIP_LABELS = [f"A satellite image of {label}" for label in LIST_LABELS]
+
+ MODEL_NAME = "NemesisAlm/clip-fine-tuned-satellite"
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ fine_tuned_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
+ fine_tuned_processor = CLIPProcessor.from_pretrained(MODEL_NAME)
+
+
+ def classify(image_path, model_number):
+     if model_number == "CLIP":
+         processor = clip_processor
+         model = clip_model
+     else:
+         processor = fine_tuned_processor
+         model = fine_tuned_model
+     image = Image.open(image_path).convert('RGB')
+     inputs = processor(text=CLIP_LABELS, images=image, return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits_per_image = outputs.logits_per_image
+     prediction = logits_per_image.softmax(dim=1)
+     confidences = {LIST_LABELS[i]: float(prediction[0][i].item()) for i in range(len(LIST_LABELS))}
+     return confidences
+
+ DESCRIPTION = """
+ <div style="font-family: Arial, sans-serif; line-height: 1.6; margin: auto; text-align: center;">
+     <h2 style="color: #333;">CLIP Fine-Tuned Satellite Model Demo</h2>
+     <p>
+         This space demonstrates the capabilities of a <strong>fine-tuned CLIP-based model</strong>
+         in classifying satellite images. The model has been specifically trained on the
+         <em>UC Merced</em> satellite image dataset.
+     </p>
+     <p>
+         After just <strong>2 epochs of training</strong>, adjusting only 30% of the model parameters,
+         the model's accuracy in classifying satellite images has significantly improved, from an
+         initial accuracy of <strong>58.8%</strong> to <strong>96.9%</strong> on the test set.
+     </p>
+     <p>
+         Explore this space to see its performance and compare it with the initial CLIP model.
+     </p>
+ </div>
+ """
+
+ FOOTER = """
+ <div style="margin-top:50px">
+     Link to model: <a href='https://huggingface.co/NemesisAlm/clip-fine-tuned-satellite'>https://huggingface.co/NemesisAlm/clip-fine-tuned-satellite</a><br>
+     Link to dataset: <a href='https://huggingface.co/datasets/blanchon/UC_Merced'>https://huggingface.co/datasets/blanchon/UC_Merced</a>
+ </div>
+ """
+
+ with gr.Blocks(title="Satellite image classification", css="") as demo:
+     logo = gr.HTML("<img src='file/logo_gradio.png' style='margin:auto'/>")
+     description = gr.HTML(DESCRIPTION)
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type='filepath', label='Input image')
+             submit_btn = gr.Button("Submit", variant="primary")
+         with gr.Column():
+             title_1 = gr.HTML("<h1 style='text-align:center'>Original CLIP Model</h1>")
+             model_1 = gr.Textbox("CLIP", visible=False)
+             output_labels_clip = gr.Label(num_top_classes=10, label="Top 10 classes")
+         with gr.Column():
+             title_2 = gr.HTML("<h1 style='text-align:center'>Fine-tuned Model</h1>")
+             model_2 = gr.Textbox("Fine-tuned", visible=False)
+             output_labels_finetuned = gr.Label(num_top_classes=10, label="Top 10 classes")
+     examples = gr.Examples([["0.jpg"], ["1.jpg"], ["2.jpg"], ["3.jpg"]], input_image)
+     footer = gr.HTML(FOOTER)
+     submit_btn.click(fn=classify, inputs=[input_image, model_1], outputs=output_labels_clip).then(classify, inputs=[input_image, model_2], outputs=[output_labels_finetuned])
+
+
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", favicon_path='favicon.ico', allowed_paths=["logo_gradio.png", "0.jpg", "1.jpg", "2.jpg", "3.jpg"])
+
+
+
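For reference, a minimal sketch of the same zero-shot scoring path outside the Gradio UI, assuming the NemesisAlm/clip-fine-tuned-satellite checkpoint can be downloaded and that one of the bundled example images (here 0.jpg) is available locally; the label subset and prompt template mirror LIST_LABELS and CLIP_LABELS in app.py:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load the fine-tuned checkpoint (swap in "openai/clip-vit-base-patch32" to compare with the base model)
model = CLIPModel.from_pretrained("NemesisAlm/clip-fine-tuned-satellite")
processor = CLIPProcessor.from_pretrained("NemesisAlm/clip-fine-tuned-satellite")

labels = ["agricultural land", "beach", "forest", "harbor", "runway"]  # illustrative subset of LIST_LABELS
prompts = [f"A satellite image of {label}" for label in labels]

image = Image.open("0.jpg").convert("RGB")
inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    probs = model(**inputs).logits_per_image.softmax(dim=1)

# Per-label probabilities, analogous to the dictionary returned by classify()
print({label: round(p, 3) for label, p in zip(labels, probs[0].tolist())})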
favicon.ico ADDED
logo_gradio.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ Pillow
+ torch
+ gradio