TheDemond commited on
Commit
9c77129
·
verified ·
1 Parent(s): 982ffbb

Add app files

Browse files
Files changed (4) hide show
  1. app.py +91 -0
  2. myCNN.bin +3 -0
  3. myCNN_states.pt +3 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import gradio as gr
4
+ from torchvision.transforms import Resize, ToTensor, Compose
5
+ from torch.nn.functional import softmax
6
+
7
+ class myCNN(nn.Module):
8
+ def __init__(self, input_channels, classes) -> None:
9
+ super().__init__()
10
+ self.layer1 = nn.Sequential(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=(3,3), padding='valid', bias=False),
11
+ nn.BatchNorm2d(num_features=64),
12
+ nn.ReLU())
13
+ self.layer2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding='valid', bias=False),
14
+ nn.BatchNorm2d(num_features=64),
15
+ nn.ReLU())
16
+ self.layer3 = nn.Sequential(nn.MaxPool2d((2,2)),
17
+ nn.Dropout2d(0.4))
18
+
19
+ self.layer4 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), padding='valid', bias=False),
20
+ nn.BatchNorm2d(num_features=128),
21
+ nn.ReLU())
22
+ self.layer5 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3,3), padding='valid', bias=False),
23
+ nn.BatchNorm2d(num_features=128),
24
+ nn.ReLU())
25
+ self.layer6 = nn.Sequential(nn.MaxPool2d((2,2)),
26
+ nn.Dropout2d(0.4))
27
+ self.flat = nn.Flatten()
28
+
29
+ self.fc1 = nn.Sequential(nn.Linear(3200, 512),
30
+ nn.ReLU(),
31
+ nn.Dropout1d(0.5))
32
+
33
+ self.fc2 = nn.Sequential(nn.Linear(512, 256),
34
+ nn.ReLU())
35
+
36
+ self.fc3 = nn.Linear(256, classes)
37
+
38
+ def forward(self, x):
39
+ layer1 = self.layer1(x)
40
+ layer2 = self.layer2(layer1)
41
+ layer3 = self.layer3(layer2)
42
+ layer4 = self.layer4(layer3)
43
+ layer5 = self.layer5(layer4)
44
+ layer6 = self.layer6(layer5)
45
+ flat = self.flat(layer6)
46
+ fc1 = self.fc1(flat)
47
+ fc2 = self.fc2(fc1)
48
+ fc3 = self.fc3(fc2)
49
+ return fc3
50
+
51
+
52
+ device = 'gpu' if torch.cuda.is_available() else 'cpu'
53
+
54
+ model_state = torch.load("myCNN_states.pt", map_location=device, weights_only=False)
55
+ input_shape = model_state['input_shape']
56
+ cls_to_idx = model_state['labels_encoder']
57
+ idx_to_cls = {value:key for key,value in cls_to_idx.items()}
58
+
59
+ pre_processor = Compose([Resize(input_shape[1:]),
60
+ ToTensor()])
61
+
62
+ model = torch.load("myCNN.bin",
63
+ map_location=device,
64
+ weights_only=False)
65
+
66
+ def post_processor(raw_output):
67
+ softmax_output = softmax(raw_output, -1)
68
+ values, indices = torch.max(softmax_output, -1)
69
+ return idx_to_cls[indices.item()].capitalize(), round(values.item(), 2)
70
+
71
+
72
+ @torch.no_grad
73
+ def lunch(raw_input):
74
+ input = pre_processor(raw_input)
75
+ output = model(input.unsqueeze(0))
76
+ return post_processor(output)
77
+
78
+ custom_css ='.gr-button {background-color: #bf4b04; color: white;}'
79
+
80
+ with gr.Blocks(css=custom_css) as demo:
81
+ with gr.Row():
82
+ with gr.Column():
83
+ input_image = gr.Image(type="pil", label='Input Image')
84
+ gr.Text("Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck", label="Supported Classes:")
85
+ with gr.Column():
86
+ class_name = gr.Textbox(label="This is (a\\an)")
87
+ confidence = gr.Textbox(label='Confidence')
88
+ start_btn = gr.Button(value='Submit', elem_classes=["gr-button"])
89
+ start_btn.click(fn=lunch, inputs=input_image, outputs=[class_name, confidence])
90
+
91
+ demo.launch()
myCNN.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2510e5ef54bc911f9daf0d8efabe79b2ba1aca62a8e988ba5aca292826e59b01
3
+ size 8152632
myCNN_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c54e6595788d3bb34e76db6c064e200a28e7fe1e7d0594ab724363944adfd7e
3
+ size 24429714
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio