sitammeur commited on
Commit
4117ebf
·
verified ·
1 Parent(s): 66f3482

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/baklava.png filter=lfs diff=lfs merge=lfs -text
37
+ images/beignets.png filter=lfs diff=lfs merge=lfs -text
38
+ images/cat.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import gradio as gr
3
+ from src.siglip.classifier import ZeroShotImageClassification
4
+
5
+
6
+ # Examples to display in the interface
7
+ examples = [
8
+ [
9
+ "images/baklava.png",
10
+ "dessert on a plate, a serving of baklava, a plate and spoon",
11
+ ],
12
+ [
13
+ "images/beignets.png",
14
+ "a dog, a cat, a donut, a beignet",
15
+ ],
16
+ [
17
+ "images/cat.png",
18
+ "two sleeping cats, two cats playing, three cats laying down",
19
+ ],
20
+ ]
21
+
22
+ # Title and description and article for the interface
23
+ title = "Zero Shot Image Classification"
24
+ description = "Classify image using zero-shot classification with SigLIP 2 zeroshot model! Provide an image input and a list of candidate labels separated by commas. Read more at the links below."
25
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2502.14786' target='_blank'>SigLIP 2: Multilingual Vision-Language Encoders with Improved Semantic Understanding, Localization, and Dense Features</a> | <a href='https://huggingface.co/google/siglip2-so400m-patch16-naflex' target='_blank'>Model Page</a></p>"
26
+
27
+
28
+ # Launch the interface
29
+ demo = gr.Interface(
30
+ fn=ZeroShotImageClassification,
31
+ inputs=[
32
+ gr.Image(type="pil", label="Input", placeholder="Enter image to classify"),
33
+ gr.Textbox(
34
+ label="Candidate Labels",
35
+ placeholder="Enter candidate labels separated by commas",
36
+ ),
37
+ ],
38
+ outputs=gr.Label(label="Classification"),
39
+ title=title,
40
+ description=description,
41
+ article=article,
42
+ examples=examples,
43
+ cache_examples=True,
44
+ cache_mode="lazy",
45
+ theme="Soft",
46
+ flagging_mode="never",
47
+ )
48
+ demo.launch(debug=False)
images/baklava.png ADDED

Git LFS Details

  • SHA256: 14f498fd5108dca3138206fc69e7381f0c12168b536a2e6262ebb2105f55c42e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.76 MB
images/beignets.png ADDED

Git LFS Details

  • SHA256: 0ef0c3f7caa3ad4ec92dc8968f9ac0e81ede5b4805aae89648a4c7c4236768df
  • Pointer size: 131 Bytes
  • Size of remote file: 357 kB
images/cat.png ADDED

Git LFS Details

  • SHA256: 66d102316c941636ccbabbe892ee8905f62345ab3951288089386f901fea1a83
  • Pointer size: 131 Bytes
  • Size of remote file: 879 kB
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ git+https://github.com/huggingface/transformers@main
4
+ sentencepiece
5
+ pillow
6
+ protobuf
7
+ accelerate
src/__init__.py ADDED
File without changes
src/exception.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines a custom exception handling class and a function to get error message with details of the error.
3
+ """
4
+
5
+ # Standard Library
6
+ import sys
7
+
8
+ # Local imports
9
+ from src.logger import logging
10
+
11
+
12
+ # Function Definition to get error message with details of the error (file name and line number) when an error occurs in the program
13
+ def get_error_message(error, error_detail: sys):
14
+ """
15
+ Get error message with details of the error.
16
+
17
+ Args:
18
+ - error (Exception): The error that occurred.
19
+ - error_detail (sys): The details of the error.
20
+
21
+ Returns:
22
+ str: A string containing the error message along with the file name and line number where the error occurred.
23
+ """
24
+ _, _, exc_tb = error_detail.exc_info()
25
+
26
+ # Get error details
27
+ file_name = exc_tb.tb_frame.f_code.co_filename
28
+ return "Error occured in python script name [{0}] line number [{1}] error message[{2}]".format(
29
+ file_name, exc_tb.tb_lineno, str(error)
30
+ )
31
+
32
+
33
+ # Custom Exception Handling Class Definition
34
+ class CustomExceptionHandling(Exception):
35
+ """
36
+ Custom Exception Handling:
37
+ This class defines a custom exception that can be raised when an error occurs in the program.
38
+ It takes an error message and an error detail as input and returns a formatted error message when the exception is raised.
39
+ """
40
+
41
+ # Constructor
42
+ def __init__(self, error_message, error_detail: sys):
43
+ """Initialize the exception"""
44
+ super().__init__(error_message)
45
+
46
+ self.error_message = get_error_message(error_message, error_detail=error_detail)
47
+
48
+ def __str__(self):
49
+ """String representation of the exception"""
50
+ return self.error_message
src/logger.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing the required modules
2
+ import os
3
+ import logging
4
+ from datetime import datetime
5
+
6
+ # Creating a log file with the current date and time as the name of the file
7
+ LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
8
+
9
+ # Creating a logs folder if it does not exist
10
+ logs_path = os.path.join(os.getcwd(), "logs", LOG_FILE)
11
+ os.makedirs(logs_path, exist_ok=True)
12
+
13
+ # Setting the log file path and the log level
14
+ LOG_FILE_PATH = os.path.join(logs_path, LOG_FILE)
15
+
16
+ # Configuring the logger
17
+ logging.basicConfig(
18
+ filename=LOG_FILE_PATH,
19
+ format="[ %(asctime)s ] %(lineno)d %(name)s - %(levelname)s - %(message)s",
20
+ level=logging.INFO,
21
+ )
src/siglip/__init__.py ADDED
File without changes
src/siglip/classifier.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import sys
3
+ from typing import Dict
4
+ import torch
5
+ from transformers import AutoModel, AutoProcessor
6
+ import gradio as gr
7
+
8
+ # Local imports
9
+ from src.logger import logging
10
+ from src.exception import CustomExceptionHandling
11
+
12
+
13
+ # Load the zero-shot image classification model
14
+ model_id = "google/siglip2-so400m-patch16-naflex"
15
+ model = AutoModel.from_pretrained(model_id, device_map="cpu").eval()
16
+ processor = AutoProcessor.from_pretrained(model_id)
17
+
18
+
19
+ def ZeroShotImageClassification(
20
+ image_input: str, candidate_labels: str
21
+ ) -> Dict[str, float]:
22
+ """
23
+ Performs zero-shot classification on the given image input and candidate labels.
24
+
25
+ Args:
26
+ - image_input: The input image to classify.
27
+ - candidate_labels: A comma-separated string of candidate labels.
28
+
29
+ Returns:
30
+ Dictionary containing label-score pairs.
31
+ """
32
+ try:
33
+ # Check if the input and candidate labels are valid
34
+ if not image_input or not candidate_labels:
35
+ gr.Warning("Please provide valid input and candidate labels")
36
+
37
+ # Split and clean the candidate labels
38
+ labels = [label.strip() for label in candidate_labels.split(",")]
39
+
40
+ # Log the classification attempt
41
+ logging.info(f"Attempting classification with {len(labels)} labels")
42
+
43
+ # Perform zero-shot image classification
44
+ inputs = processor(
45
+ text=labels,
46
+ images=image_input,
47
+ return_tensors="pt",
48
+ padding="max_length",
49
+ max_length=64,
50
+ ).to("cpu")
51
+ with torch.no_grad():
52
+ outputs = model(**inputs)
53
+ logits_per_image = outputs.logits_per_image
54
+ probs = torch.sigmoid(logits_per_image)
55
+
56
+ # Return the classification results
57
+ logging.info("Classification completed successfully")
58
+ return {labels[i]: probs[0][i] for i in range(len(labels))}
59
+
60
+ # Handle exceptions that may occur during the process
61
+ except Exception as e:
62
+ # Custom exception handling
63
+ raise CustomExceptionHandling(e, sys) from e