MNCJihun committed on
Commit 25322fb · 0 Parent(s)
Files changed (12)
  1. .gitattributes +34 -0
  2. .gitignore +5 -0
  3. Dockerfile +19 -0
  4. README.md +12 -0
  5. app.py +3 -0
  6. binary_classification.ipynb +558 -0
  7. classification.ipynb +632 -0
  8. dataset.py +96 -0
  9. main.py +119 -0
  10. requirements.txt +7 -0
  11. test.py +89 -0
  12. utils.py +12 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
+ data/
+ __pycache__/
+ __MACOSX/
+ dataset.zip
+ *.pth
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel
+ RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC
+ RUN apt-get update && \
+ apt-get upgrade -y && \
+ apt-get install -y git
+ RUN apt-get -y install libgl1-mesa-glx libglib2.0-0
+ RUN apt-get -y install vim byobu aria2
+
+ COPY . /usr/src/motorbike_cls
+ # RUN ls /usr/src/motorbike_cls
+ RUN cd /usr/src/motorbike_cls
+
+ WORKDIR /usr/src/motorbike_cls
+
+ RUN pip install --no-cache-dir --upgrade pip && \
+ pip install -r /usr/src/motorbike_cls/requirements.txt
+
+ CMD ["test.py"]
+ ENTRYPOINT ["python3"]
README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: HI Motorcycle Trunk Cls
+ emoji: 👀
+ colorFrom: gray
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.24.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,3 @@
+ import os
+
+ print(os.getcwd())
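Note: the Space metadata in README.md declares `sdk: gradio` with `app_file: app.py`, but the app.py added here only prints the working directory; a Gradio Space normally builds and launches an interface from that file. A minimal sketch of what app.py could grow into for this trunk/no-trunk classifier, assuming a ResNet-50 with a 2-way head like the one trained in the notebooks and a hypothetical checkpoint path `model.pth` (none of this is part of the commit):

```python
# Hypothetical Gradio entry point for the Space; model path and head layout are assumptions.
import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms

classes = ('no_trunk', 'trunk')

model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 2)
# model.load_state_dict(torch.load("model.pth", map_location="cpu"))  # hypothetical checkpoint
model.eval()

# Mirrors the notebooks' test-time transforms: shortest side to 350, center-crop 256, ImageNet stats.
preprocess = transforms.Compose([
    transforms.Resize(350),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

def predict(img: Image.Image):
    x = preprocess(img).unsqueeze(0)                      # 1 x 3 x 256 x 256
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
    return {c: float(p) for c, p in zip(classes, probs)}

demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(num_top_classes=2))

if __name__ == "__main__":
    demo.launch()
```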
binary_classification.ipynb ADDED
@@ -0,0 +1,558 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%matplotlib inline\n",
10
+ "%config InlineBackend.figure_format = 'retina'\n",
11
+ "\n",
12
+ "import os\n",
13
+ "import matplotlib.pyplot as plt\n",
14
+ "from pandas.core.common import flatten\n",
15
+ "import copy\n",
16
+ "import numpy as np\n",
17
+ "import random\n",
18
+ "\n",
19
+ "import torch\n",
20
+ "from torch import nn\n",
21
+ "from torch import optim\n",
22
+ "import torch.nn.functional as F\n",
23
+ "from torchvision import datasets, transforms, models\n",
24
+ "from torch.utils.data import Dataset, DataLoader\n",
25
+ "import torch.nn as nn\n",
26
+ "import albumentations as A\n",
27
+ "from albumentations.pytorch import ToTensorV2\n",
28
+ "import cv2\n",
29
+ "\n",
30
+ "import glob\n",
31
+ "from tqdm import tqdm\n",
32
+ "import random"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 2,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "#######################################################\n",
42
+ "# Define Transforms\n",
43
+ "#######################################################\n",
44
+ "\n",
45
+ "#To define an augmentation pipeline, you need to create an instance of the Compose class.\n",
46
+ "#As an argument to the Compose class, you need to pass a list of augmentations you want to apply. \n",
47
+ "#A call to Compose will return a transform function that will perform image augmentation.\n",
48
+ "#(https://albumentations.ai/docs/getting_started/image_augmentation/)\n",
49
+ "\n",
50
+ "train_transforms = A.Compose(\n",
51
+ " [\n",
52
+ " A.SmallestMaxSize(max_size=350),\n",
53
+ " A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=360, p=0.5),\n",
54
+ " A.RandomCrop(height=256, width=256),\n",
55
+ " A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),\n",
56
+ " A.RandomBrightnessContrast(p=0.5),\n",
57
+ " A.MultiplicativeNoise(multiplier=[0.5,2], per_channel=True, p=0.2),\n",
58
+ " A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
59
+ " A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),\n",
60
+ " A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),\n",
61
+ " ToTensorV2(),\n",
62
+ " ]\n",
63
+ ")\n",
64
+ "\n",
65
+ "test_transforms = A.Compose(\n",
66
+ " [\n",
67
+ " A.SmallestMaxSize(max_size=350),\n",
68
+ " A.CenterCrop(height=256, width=256),\n",
69
+ " A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
70
+ " ToTensorV2(),\n",
71
+ " ]\n",
72
+ ")"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 3,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "import os\n",
82
+ "import matplotlib.pyplot as plt\n",
83
+ "from pandas.core.common import flatten\n",
84
+ "import copy\n",
85
+ "import numpy as np\n",
86
+ "import random\n",
87
+ "\n",
88
+ "import torch\n",
89
+ "from torch import nn\n",
90
+ "from torch import optim\n",
91
+ "import torch.nn.functional as F\n",
92
+ "from torchvision import datasets, transforms, models\n",
93
+ "from torch.utils.data import Dataset, DataLoader\n",
94
+ "import torch.nn as nn\n",
95
+ "import albumentations as A\n",
96
+ "from albumentations.pytorch import ToTensorV2\n",
97
+ "import cv2\n",
98
+ "\n",
99
+ "import glob\n",
100
+ "from tqdm import tqdm\n",
101
+ "import random\n",
102
+ "\n",
103
+ "class MotorbikeDataset(torch.utils.data.Dataset):\n",
104
+ " def __init__(self, image_paths, transform=None):\n",
105
+ " self.root = image_paths\n",
106
+ " self.image_paths = os.listdir(image_paths)\n",
107
+ " self.transform = transform\n",
108
+ " \n",
109
+ " def __len__(self):\n",
110
+ " return len(self.image_paths)\n",
111
+ "\n",
112
+ " def __getitem__(self, idx):\n",
113
+ " image_filepath = self.image_paths[idx]\n",
114
+ " \n",
115
+ " image = cv2.imread(os.path.join(self.root,image_filepath))\n",
116
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
117
+ " \n",
118
+ " label = int('t' in image_filepath)\n",
119
+ " if self.transform is not None:\n",
120
+ " image = self.transform(image=image)[\"image\"]\n",
121
+ " \n",
122
+ " return image, label\n",
123
+ " \n",
124
+ "\n",
125
+ "class MotorbikeDataset_CV(torch.utils.data.Dataset):\n",
126
+ " def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8) -> None:\n",
127
+ " self.root = root\n",
128
+ " self.train_transforms = train_transforms\n",
129
+ " self.val_transforms = val_transforms\n",
130
+ " self.trainval_ratio = trainval_ratio\n",
131
+ " self.train_split, self.val_split = self.gen_split()\n",
132
+ " \n",
133
+ " def __len__(self):\n",
134
+ " return len(self.root)\n",
135
+ "\n",
136
+ " def gen_split(self):\n",
137
+ " img_list = os.listdir(self.root)\n",
138
+ " n_list = [img for img in img_list if img.startswith('n_')]\n",
139
+ " t_list = [img for img in img_list if img.startswith('t_')]\n",
140
+ " \n",
141
+ " n_train = random.choices(n_list, k=int(len(n_list)*self.trainval_ratio))\n",
142
+ " t_train = random.choices(t_list, k=int(len(t_list)*self.trainval_ratio))\n",
143
+ " n_val = [img for img in n_list if img not in n_train]\n",
144
+ " t_val = [img for img in t_list if img not in t_train]\n",
145
+ " \n",
146
+ " train_split = n_train + t_train\n",
147
+ " val_split = n_val + t_val\n",
148
+ " return train_split, val_split\n",
149
+ "\n",
150
+ " def get_split(self):\n",
151
+ " train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)\n",
152
+ " val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)\n",
153
+ " return train_dataset, val_dataset\n",
154
+ " \n",
155
+ "class Dataset_from_list(torch.utils.data.Dataset):\n",
156
+ " def __init__(self, root, img_list, transform) -> None:\n",
157
+ " self.root = root\n",
158
+ " self.img_list = img_list\n",
159
+ " self.transform = transform\n",
160
+ " \n",
161
+ " def __len__(self):\n",
162
+ " return len(self.img_list)\n",
163
+ " \n",
164
+ " def __getitem__(self, idx):\n",
165
+ " image = cv2.imread(os.path.join(self.root, self.img_list[idx]))\n",
166
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
167
+ " \n",
168
+ " label = int(self.img_list[idx].startswith('t_'))\n",
169
+ " \n",
170
+ " if self.transform is not None:\n",
171
+ " image = self.transform(image=image)[\"image\"]\n",
172
+ " \n",
173
+ " return image, label\n",
174
+ " \n",
175
+ " \n",
176
+ " \n"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 4,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "dataset_CV = MotorbikeDataset_CV(\n",
186
+ " root='/workspace/data/',\n",
187
+ " train_transforms=train_transforms,\n",
188
+ " val_transforms=test_transforms\n",
189
+ " )"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 5,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "train_dataset, val_dataset = dataset_CV.get_split()"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 6,
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stdout",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "277\n",
211
+ "150\n"
212
+ ]
213
+ }
214
+ ],
215
+ "source": [
216
+ "print(len(train_dataset))\n",
217
+ "print(len(val_dataset))"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 7,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)\n",
227
+ "val_loader = DataLoader(val_dataset,batch_size=64, shuffle=False)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 8,
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "classes = ('no_trunk', 'trunk')"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 9,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "device = torch.device(\"cuda:2\") if torch.cuda.is_available() else torch.device(\"cpu\")"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 16,
251
+ "metadata": {},
252
+ "outputs": [
253
+ {
254
+ "data": {
255
+ "text/plain": [
256
+ "ResNet(\n",
257
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
258
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
259
+ " (relu): ReLU(inplace=True)\n",
260
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
261
+ " (layer1): Sequential(\n",
262
+ " (0): Bottleneck(\n",
263
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
264
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
265
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
266
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
267
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
268
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
269
+ " (relu): ReLU(inplace=True)\n",
270
+ " (downsample): Sequential(\n",
271
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
272
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
273
+ " )\n",
274
+ " )\n",
275
+ " (1): Bottleneck(\n",
276
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
277
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
278
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
279
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
281
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
282
+ " (relu): ReLU(inplace=True)\n",
283
+ " )\n",
284
+ " (2): Bottleneck(\n",
285
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
286
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
287
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
288
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
289
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
290
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
291
+ " (relu): ReLU(inplace=True)\n",
292
+ " )\n",
293
+ " )\n",
294
+ " (layer2): Sequential(\n",
295
+ " (0): Bottleneck(\n",
296
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
297
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
298
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
299
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
300
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
301
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
302
+ " (relu): ReLU(inplace=True)\n",
303
+ " (downsample): Sequential(\n",
304
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
305
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
306
+ " )\n",
307
+ " )\n",
308
+ " (1): Bottleneck(\n",
309
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
310
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
311
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
312
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
313
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
314
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
315
+ " (relu): ReLU(inplace=True)\n",
316
+ " )\n",
317
+ " (2): Bottleneck(\n",
318
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
319
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
320
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
321
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
322
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
323
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
324
+ " (relu): ReLU(inplace=True)\n",
325
+ " )\n",
326
+ " (3): Bottleneck(\n",
327
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
328
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
329
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
330
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
331
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
332
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
333
+ " (relu): ReLU(inplace=True)\n",
334
+ " )\n",
335
+ " )\n",
336
+ " (layer3): Sequential(\n",
337
+ " (0): Bottleneck(\n",
338
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
339
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
340
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
341
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
342
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
343
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
344
+ " (relu): ReLU(inplace=True)\n",
345
+ " (downsample): Sequential(\n",
346
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
347
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
348
+ " )\n",
349
+ " )\n",
350
+ " (1): Bottleneck(\n",
351
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
352
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
353
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
354
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
355
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
356
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
357
+ " (relu): ReLU(inplace=True)\n",
358
+ " )\n",
359
+ " (2): Bottleneck(\n",
360
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
361
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
362
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
363
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
364
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
365
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
366
+ " (relu): ReLU(inplace=True)\n",
367
+ " )\n",
368
+ " (3): Bottleneck(\n",
369
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
370
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
371
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
372
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
373
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
374
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
375
+ " (relu): ReLU(inplace=True)\n",
376
+ " )\n",
377
+ " (4): Bottleneck(\n",
378
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
379
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
380
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
381
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
382
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
383
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
384
+ " (relu): ReLU(inplace=True)\n",
385
+ " )\n",
386
+ " (5): Bottleneck(\n",
387
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
388
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
389
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
390
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
391
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
392
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
393
+ " (relu): ReLU(inplace=True)\n",
394
+ " )\n",
395
+ " )\n",
396
+ " (layer4): Sequential(\n",
397
+ " (0): Bottleneck(\n",
398
+ " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
399
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
400
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
401
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
402
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
403
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
404
+ " (relu): ReLU(inplace=True)\n",
405
+ " (downsample): Sequential(\n",
406
+ " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
407
+ " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
408
+ " )\n",
409
+ " )\n",
410
+ " (1): Bottleneck(\n",
411
+ " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
412
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
413
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
414
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
415
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
416
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
417
+ " (relu): ReLU(inplace=True)\n",
418
+ " )\n",
419
+ " (2): Bottleneck(\n",
420
+ " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
421
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
422
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
423
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
424
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
425
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
426
+ " (relu): ReLU(inplace=True)\n",
427
+ " )\n",
428
+ " )\n",
429
+ " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
430
+ " (fc): Sequential(\n",
431
+ " (0): Linear(in_features=2048, out_features=2, bias=True)\n",
432
+ " )\n",
433
+ ")"
434
+ ]
435
+ },
436
+ "execution_count": 16,
437
+ "metadata": {},
438
+ "output_type": "execute_result"
439
+ }
440
+ ],
441
+ "source": [
442
+ "model = models.resnet50(pretrained=True)\n",
443
+ "model.fc = nn.Sequential(\n",
444
+ " # nn.Dropout(0.5),\n",
445
+ " nn.Linear(model.fc.in_features, 2),\n",
446
+ " # nn.Sigmoid()\n",
447
+ ")\n",
448
+ "\n",
449
+ "for n, p in model.named_parameters():\n",
450
+ " if 'fc' in n:\n",
451
+ " p.requires_grad = True\n",
452
+ " else:\n",
453
+ " p.requires_grad = False\n",
454
+ "\n",
455
+ "model.to(device)"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": 17,
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": [
464
+ "import torch.optim as optim\n",
465
+ "criterion = nn.BCEWithLogitsLoss().to(device)\n",
466
+ "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1, momentum=0.9)\n",
467
+ "# optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": 18,
473
+ "metadata": {},
474
+ "outputs": [
475
+ {
476
+ "ename": "ValueError",
477
+ "evalue": "Target size (torch.Size([64])) must be the same as input size (torch.Size([64, 2]))",
478
+ "output_type": "error",
479
+ "traceback": [
480
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
481
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
482
+ "\u001b[0;32m/tmp/ipykernel_107755/1844816491.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
483
+ "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
484
+ "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m 705\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 706\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 707\u001b[0;31m reduction=self.reduction)\n\u001b[0m\u001b[1;32m 708\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 709\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
485
+ "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbinary_cross_entropy_with_logits\u001b[0;34m(input, target, weight, size_average, reduce, reduction, pos_weight)\u001b[0m\n\u001b[1;32m 2978\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2979\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2980\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Target size ({}) must be the same as input size ({})\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2981\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2982\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary_cross_entropy_with_logits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction_enum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
486
+ "\u001b[0;31mValueError\u001b[0m: Target size (torch.Size([64])) must be the same as input size (torch.Size([64, 2]))"
487
+ ]
488
+ }
489
+ ],
490
+ "source": [
491
+ "for epoch in range(10):\n",
492
+ " model.train()\n",
493
+ " running_loss = 0.0\n",
494
+ " for i, data in enumerate(train_loader, 0):\n",
495
+ " inputs, labels = data[0].to(device), data[1].to(device)\n",
496
+ " \n",
497
+ " optimizer.zero_grad()\n",
498
+ " \n",
499
+ " outputs = model(inputs)\n",
500
+ " loss = criterion(outputs, labels)\n",
501
+ " loss.backward()\n",
502
+ " optimizer.step()\n",
503
+ " running_loss += loss.item()\n",
504
+ " \n",
505
+ " print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
506
+ " # print(\"TRAIN acc = {}\".format(acc))\n",
507
+ " # running_loss = 0.0\n",
508
+ " \n",
509
+ " with torch.no_grad():\n",
510
+ " model.eval()\n",
511
+ " running_loss = 0.0\n",
512
+ " correct =0\n",
513
+ " for i, data in enumerate(val_loader, 0):\n",
514
+ " inputs, labels = data[0].to(device), data[1].to(device)\n",
515
+ " outputs = model(inputs)\n",
516
+ " _, preds = outputs.max(1)\n",
517
+ " loss = criterion(outputs, labels)\n",
518
+ " running_loss += loss.item()\n",
519
+ " labels_one_hot = F.one_hot(labels, 2)\n",
520
+ " outputs_one_hot = F.one_hot(preds, 2)\n",
521
+ " correct = correct + (labels_one_hot + outputs_one_hot == 2).sum(dim=0).to(torch.float)\n",
522
+ " \n",
523
+ " acc = 100 * correct / len(val_dataset)\n",
524
+ " print(f'VAL: [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
525
+ " print(\"VAL acc = {}\".format(acc))"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": null,
531
+ "metadata": {},
532
+ "outputs": [],
533
+ "source": []
534
+ }
535
+ ],
536
+ "metadata": {
537
+ "kernelspec": {
538
+ "display_name": "base",
539
+ "language": "python",
540
+ "name": "python3"
541
+ },
542
+ "language_info": {
543
+ "codemirror_mode": {
544
+ "name": "ipython",
545
+ "version": 3
546
+ },
547
+ "file_extension": ".py",
548
+ "mimetype": "text/x-python",
549
+ "name": "python",
550
+ "nbconvert_exporter": "python",
551
+ "pygments_lexer": "ipython3",
552
+ "version": "3.7.11"
553
+ },
554
+ "orig_nbformat": 4
555
+ },
556
+ "nbformat": 4,
557
+ "nbformat_minor": 2
558
+ }
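Note: the training cell above ends in the recorded ValueError because BCEWithLogitsLoss expects targets with the same shape as its input, while the replaced fc head emits two logits per image and the labels are a 1-D batch of class indices. classification.ipynb below resolves this by switching to CrossEntropyLoss; the other common fix is a single-logit head with float targets. A minimal sketch of both options, assuming the same two-class, batch-of-64 setup:

```python
# Sketch (not part of the commit) of two ways to resolve the target/input shape mismatch.
import torch
import torch.nn as nn

logits_two = torch.randn(64, 2)        # output of a 2-unit fc head, as in the notebook
labels = torch.randint(0, 2, (64,))    # integer class indices, shape [64]

# Option A (what classification.ipynb does): treat it as 2-class classification.
loss_a = nn.CrossEntropyLoss()(logits_two, labels)

# Option B: keep BCEWithLogitsLoss, but with a single-logit head and float targets.
logits_one = torch.randn(64, 1)        # e.g. nn.Linear(model.fc.in_features, 1)
loss_b = nn.BCEWithLogitsLoss()(logits_one, labels.float().unsqueeze(1))

print(loss_a.item(), loss_b.item())
```

Separately, `gen_split` draws the training split with `random.choices`, which samples with replacement, so the train list can contain duplicates and the realized train/val ratio drifts from `trainval_ratio` (277 train entries against 150 and 166 validation images out of 349 files in the two notebook runs); `random.sample` would give the duplicate-free 80% split the parameter name suggests.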
classification.ipynb ADDED
@@ -0,0 +1,632 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 10,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "%matplotlib inline\n",
10
+ "%config InlineBackend.figure_format = 'retina'\n",
11
+ "\n",
12
+ "import os\n",
13
+ "import matplotlib.pyplot as plt\n",
14
+ "from pandas.core.common import flatten\n",
15
+ "import copy\n",
16
+ "import numpy as np\n",
17
+ "import random\n",
18
+ "\n",
19
+ "import torch\n",
20
+ "from torch import nn\n",
21
+ "from torch import optim\n",
22
+ "import torch.nn.functional as F\n",
23
+ "from torchvision import datasets, transforms, models\n",
24
+ "from torch.utils.data import Dataset, DataLoader\n",
25
+ "import torch.nn as nn\n",
26
+ "import albumentations as A\n",
27
+ "from albumentations.pytorch import ToTensorV2\n",
28
+ "import cv2\n",
29
+ "\n",
30
+ "import glob\n",
31
+ "from tqdm import tqdm\n",
32
+ "import random"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 11,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "#######################################################\n",
42
+ "# Define Transforms\n",
43
+ "#######################################################\n",
44
+ "\n",
45
+ "#To define an augmentation pipeline, you need to create an instance of the Compose class.\n",
46
+ "#As an argument to the Compose class, you need to pass a list of augmentations you want to apply. \n",
47
+ "#A call to Compose will return a transform function that will perform image augmentation.\n",
48
+ "#(https://albumentations.ai/docs/getting_started/image_augmentation/)\n",
49
+ "\n",
50
+ "train_transforms = A.Compose(\n",
51
+ " [\n",
52
+ " A.SmallestMaxSize(max_size=350),\n",
53
+ " A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=360, p=0.5),\n",
54
+ " A.RandomCrop(height=256, width=256),\n",
55
+ " A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),\n",
56
+ " A.RandomBrightnessContrast(p=0.5),\n",
57
+ " A.MultiplicativeNoise(multiplier=[0.5,2], per_channel=True, p=0.2),\n",
58
+ " A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
59
+ " A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),\n",
60
+ " A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),\n",
61
+ " ToTensorV2(),\n",
62
+ " ]\n",
63
+ ")\n",
64
+ "\n",
65
+ "test_transforms = A.Compose(\n",
66
+ " [\n",
67
+ " A.SmallestMaxSize(max_size=350),\n",
68
+ " A.CenterCrop(height=256, width=256),\n",
69
+ " A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
70
+ " ToTensorV2(),\n",
71
+ " ]\n",
72
+ ")"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 12,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "import os\n",
82
+ "import matplotlib.pyplot as plt\n",
83
+ "from pandas.core.common import flatten\n",
84
+ "import copy\n",
85
+ "import numpy as np\n",
86
+ "import random\n",
87
+ "\n",
88
+ "import torch\n",
89
+ "from torch import nn\n",
90
+ "from torch import optim\n",
91
+ "import torch.nn.functional as F\n",
92
+ "from torchvision import datasets, transforms, models\n",
93
+ "from torch.utils.data import Dataset, DataLoader\n",
94
+ "import torch.nn as nn\n",
95
+ "import albumentations as A\n",
96
+ "from albumentations.pytorch import ToTensorV2\n",
97
+ "import cv2\n",
98
+ "\n",
99
+ "import glob\n",
100
+ "from tqdm import tqdm\n",
101
+ "import random\n",
102
+ "\n",
103
+ "class MotorbikeDataset(torch.utils.data.Dataset):\n",
104
+ " def __init__(self, image_paths, transform=None):\n",
105
+ " self.root = image_paths\n",
106
+ " self.image_paths = os.listdir(image_paths)\n",
107
+ " self.transform = transform\n",
108
+ " \n",
109
+ " def __len__(self):\n",
110
+ " return len(self.image_paths)\n",
111
+ "\n",
112
+ " def __getitem__(self, idx):\n",
113
+ " image_filepath = self.image_paths[idx]\n",
114
+ " \n",
115
+ " image = cv2.imread(os.path.join(self.root,image_filepath))\n",
116
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
117
+ " \n",
118
+ " label = int('t' in image_filepath)\n",
119
+ " if self.transform is not None:\n",
120
+ " image = self.transform(image=image)[\"image\"]\n",
121
+ " \n",
122
+ " return image, label\n",
123
+ " \n",
124
+ "\n",
125
+ "class MotorbikeDataset_CV(torch.utils.data.Dataset):\n",
126
+ " def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8) -> None:\n",
127
+ " self.root = root\n",
128
+ " self.train_transforms = train_transforms\n",
129
+ " self.val_transforms = val_transforms\n",
130
+ " self.trainval_ratio = trainval_ratio\n",
131
+ " self.train_split, self.val_split = self.gen_split()\n",
132
+ " \n",
133
+ " def __len__(self):\n",
134
+ " return len(self.root)\n",
135
+ "\n",
136
+ " def gen_split(self):\n",
137
+ " img_list = os.listdir(self.root)\n",
138
+ " n_list = [img for img in img_list if img.startswith('n_')]\n",
139
+ " t_list = [img for img in img_list if img.startswith('t_')]\n",
140
+ " \n",
141
+ " n_train = random.choices(n_list, k=int(len(n_list)*self.trainval_ratio))\n",
142
+ " t_train = random.choices(t_list, k=int(len(t_list)*self.trainval_ratio))\n",
143
+ " n_val = [img for img in n_list if img not in n_train]\n",
144
+ " t_val = [img for img in t_list if img not in t_train]\n",
145
+ " \n",
146
+ " train_split = n_train + t_train\n",
147
+ " val_split = n_val + t_val\n",
148
+ " return train_split, val_split\n",
149
+ "\n",
150
+ " def get_split(self):\n",
151
+ " train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)\n",
152
+ " val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)\n",
153
+ " return train_dataset, val_dataset\n",
154
+ " \n",
155
+ "class Dataset_from_list(torch.utils.data.Dataset):\n",
156
+ " def __init__(self, root, img_list, transform) -> None:\n",
157
+ " self.root = root\n",
158
+ " self.img_list = img_list\n",
159
+ " self.transform = transform\n",
160
+ " \n",
161
+ " def __len__(self):\n",
162
+ " return len(self.img_list)\n",
163
+ " \n",
164
+ " def __getitem__(self, idx):\n",
165
+ " image = cv2.imread(os.path.join(self.root, self.img_list[idx]))\n",
166
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
167
+ " \n",
168
+ " label = int(self.img_list[idx].startswith('t_'))\n",
169
+ " \n",
170
+ " if self.transform is not None:\n",
171
+ " image = self.transform(image=image)[\"image\"]\n",
172
+ " \n",
173
+ " return image, label\n",
174
+ " \n",
175
+ " \n",
176
+ " \n"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 13,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "dataset_CV = MotorbikeDataset_CV(\n",
186
+ " root='/workspace/data/',\n",
187
+ " train_transforms=train_transforms,\n",
188
+ " val_transforms=test_transforms\n",
189
+ " )"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 14,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "train_dataset, val_dataset = dataset_CV.get_split()"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 15,
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stdout",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "277\n",
211
+ "166\n"
212
+ ]
213
+ }
214
+ ],
215
+ "source": [
216
+ "print(len(train_dataset))\n",
217
+ "print(len(val_dataset))"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 16,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)\n",
227
+ "val_loader = DataLoader(val_dataset,batch_size=64, shuffle=False)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 17,
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "classes = ('no_trunk', 'trunk')"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 18,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "device = torch.device(\"cuda:1\") if torch.cuda.is_available() else torch.device(\"cpu\")"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": 28,
251
+ "metadata": {},
252
+ "outputs": [
253
+ {
254
+ "data": {
255
+ "text/plain": [
256
+ "ResNet(\n",
257
+ " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
258
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
259
+ " (relu): ReLU(inplace=True)\n",
260
+ " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
261
+ " (layer1): Sequential(\n",
262
+ " (0): Bottleneck(\n",
263
+ " (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
264
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
265
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
266
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
267
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
268
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
269
+ " (relu): ReLU(inplace=True)\n",
270
+ " (downsample): Sequential(\n",
271
+ " (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
272
+ " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
273
+ " )\n",
274
+ " )\n",
275
+ " (1): Bottleneck(\n",
276
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
277
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
278
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
279
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
280
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
281
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
282
+ " (relu): ReLU(inplace=True)\n",
283
+ " )\n",
284
+ " (2): Bottleneck(\n",
285
+ " (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
286
+ " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
287
+ " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
288
+ " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
289
+ " (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
290
+ " (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
291
+ " (relu): ReLU(inplace=True)\n",
292
+ " )\n",
293
+ " )\n",
294
+ " (layer2): Sequential(\n",
295
+ " (0): Bottleneck(\n",
296
+ " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
297
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
298
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
299
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
300
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
301
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
302
+ " (relu): ReLU(inplace=True)\n",
303
+ " (downsample): Sequential(\n",
304
+ " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
305
+ " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
306
+ " )\n",
307
+ " )\n",
308
+ " (1): Bottleneck(\n",
309
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
310
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
311
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
312
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
313
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
314
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
315
+ " (relu): ReLU(inplace=True)\n",
316
+ " )\n",
317
+ " (2): Bottleneck(\n",
318
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
319
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
320
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
321
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
322
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
323
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
324
+ " (relu): ReLU(inplace=True)\n",
325
+ " )\n",
326
+ " (3): Bottleneck(\n",
327
+ " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
328
+ " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
329
+ " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
330
+ " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
331
+ " (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
332
+ " (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
333
+ " (relu): ReLU(inplace=True)\n",
334
+ " )\n",
335
+ " )\n",
336
+ " (layer3): Sequential(\n",
337
+ " (0): Bottleneck(\n",
338
+ " (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
339
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
340
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
341
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
342
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
343
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
344
+ " (relu): ReLU(inplace=True)\n",
345
+ " (downsample): Sequential(\n",
346
+ " (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
347
+ " (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
348
+ " )\n",
349
+ " )\n",
350
+ " (1): Bottleneck(\n",
351
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
352
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
353
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
354
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
355
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
356
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
357
+ " (relu): ReLU(inplace=True)\n",
358
+ " )\n",
359
+ " (2): Bottleneck(\n",
360
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
361
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
362
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
363
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
364
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
365
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
366
+ " (relu): ReLU(inplace=True)\n",
367
+ " )\n",
368
+ " (3): Bottleneck(\n",
369
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
370
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
371
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
372
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
373
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
374
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
375
+ " (relu): ReLU(inplace=True)\n",
376
+ " )\n",
377
+ " (4): Bottleneck(\n",
378
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
379
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
380
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
381
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
382
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
383
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
384
+ " (relu): ReLU(inplace=True)\n",
385
+ " )\n",
386
+ " (5): Bottleneck(\n",
387
+ " (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
388
+ " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
389
+ " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
390
+ " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
391
+ " (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
392
+ " (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
393
+ " (relu): ReLU(inplace=True)\n",
394
+ " )\n",
395
+ " )\n",
396
+ " (layer4): Sequential(\n",
397
+ " (0): Bottleneck(\n",
398
+ " (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
399
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
400
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
401
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
402
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
403
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
404
+ " (relu): ReLU(inplace=True)\n",
405
+ " (downsample): Sequential(\n",
406
+ " (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
407
+ " (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
408
+ " )\n",
409
+ " )\n",
410
+ " (1): Bottleneck(\n",
411
+ " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
412
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
413
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
414
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
415
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
416
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
417
+ " (relu): ReLU(inplace=True)\n",
418
+ " )\n",
419
+ " (2): Bottleneck(\n",
420
+ " (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
421
+ " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
422
+ " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
423
+ " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
424
+ " (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
425
+ " (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
426
+ " (relu): ReLU(inplace=True)\n",
427
+ " )\n",
428
+ " )\n",
429
+ " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
430
+ " (fc): Sequential(\n",
431
+ " (0): Linear(in_features=2048, out_features=2, bias=True)\n",
432
+ " )\n",
433
+ ")"
434
+ ]
435
+ },
436
+ "execution_count": 28,
437
+ "metadata": {},
438
+ "output_type": "execute_result"
439
+ }
440
+ ],
441
+ "source": [
442
+ "model = models.resnet50(pretrained=True)\n",
443
+ "model.fc = nn.Sequential(\n",
444
+ " # nn.Dropout(0.5),\n",
445
+ " nn.Linear(model.fc.in_features, 2)\n",
446
+ ")\n",
447
+ "\n",
448
+ "for n, p in model.named_parameters():\n",
449
+ " if 'fc' in n:\n",
450
+ " p.requires_grad = True\n",
451
+ " else:\n",
452
+ " p.requires_grad = False\n",
453
+ "\n",
454
+ "model.to(device)"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": 31,
460
+ "metadata": {},
461
+ "outputs": [],
462
+ "source": [
463
+ "import torch.optim as optim\n",
464
+ "criterion = nn.CrossEntropyLoss()\n",
465
+ "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1, momentum=0.9)\n",
466
+ "# optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "code",
471
+ "execution_count": 32,
472
+ "metadata": {},
473
+ "outputs": [
474
+ {
475
+ "name": "stdout",
476
+ "output_type": "stream",
477
+ "text": [
478
+ "[1, 4] loss: 0.009\n",
479
+ "VAL: [1, 3] loss: 0.028\n",
480
+ "VAL acc = tensor([51.8072, 0.0000], device='cuda:1')\n",
481
+ "[2, 4] loss: 0.026\n",
482
+ "VAL: [2, 3] loss: 0.018\n",
483
+ "VAL acc = tensor([51.8072, 0.0000], device='cuda:1')\n",
484
+ "[3, 4] loss: 0.014\n",
485
+ "VAL: [3, 3] loss: 0.003\n",
486
+ "VAL acc = tensor([50.6024, 33.7349], device='cuda:1')\n",
487
+ "[4, 4] loss: 0.006\n",
488
+ "VAL: [4, 3] loss: 0.006\n",
489
+ "VAL acc = tensor([21.0843, 46.9879], device='cuda:1')\n",
490
+ "[5, 4] loss: 0.007\n",
491
+ "VAL: [5, 3] loss: 0.006\n",
492
+ "VAL acc = tensor([50.6024, 30.1205], device='cuda:1')\n",
493
+ "[6, 4] loss: 0.005\n",
494
+ "VAL: [6, 3] loss: 0.003\n",
495
+ "VAL acc = tensor([50.0000, 38.5542], device='cuda:1')\n",
496
+ "[7, 4] loss: 0.005\n",
497
+ "VAL: [7, 3] loss: 0.002\n",
498
+ "VAL acc = tensor([49.3976, 39.7590], device='cuda:1')\n",
499
+ "[8, 4] loss: 0.003\n",
500
+ "VAL: [8, 3] loss: 0.004\n",
501
+ "VAL acc = tensor([50.6024, 33.1325], device='cuda:1')\n",
502
+ "[9, 4] loss: 0.005\n",
503
+ "VAL: [9, 3] loss: 0.002\n",
504
+ "VAL acc = tensor([48.1928, 41.5663], device='cuda:1')\n",
505
+ "[10, 4] loss: 0.004\n",
506
+ "VAL: [10, 3] loss: 0.002\n",
507
+ "VAL acc = tensor([49.3976, 38.5542], device='cuda:1')\n"
508
+ ]
509
+ }
510
+ ],
511
+ "source": [
512
+ "for epoch in range(10):\n",
513
+ " model.train()\n",
514
+ " running_loss = 0.0\n",
515
+ " for i, data in enumerate(train_loader, 0):\n",
516
+ " inputs, labels = data[0].to(device), data[1].to(device)\n",
517
+ " \n",
518
+ " optimizer.zero_grad()\n",
519
+ " \n",
520
+ " outputs = model(inputs)\n",
521
+ " loss = criterion(outputs, labels)\n",
522
+ " loss.backward()\n",
523
+ " optimizer.step()\n",
524
+ " running_loss += loss.item()\n",
525
+ " \n",
526
+ " print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
527
+ " # print(\"TRAIN acc = {}\".format(acc))\n",
528
+ " # running_loss = 0.0\n",
529
+ " \n",
530
+ " with torch.no_grad():\n",
531
+ " model.eval()\n",
532
+ " running_loss = 0.0\n",
533
+ " correct =0\n",
534
+ " for i, data in enumerate(val_loader, 0):\n",
535
+ " inputs, labels = data[0].to(device), data[1].to(device)\n",
536
+ " outputs = model(inputs)\n",
537
+ " _, preds = outputs.max(1)\n",
538
+ " loss = criterion(outputs, labels)\n",
539
+ " running_loss += loss.item()\n",
540
+ " labels_one_hot = F.one_hot(labels, 2)\n",
541
+ " outputs_one_hot = F.one_hot(preds, 2)\n",
542
+ " correct = correct + (labels_one_hot + outputs_one_hot == 2).sum(dim=0).to(torch.float)\n",
543
+ " \n",
544
+ " acc = 100 * correct / len(val_dataset)\n",
545
+ " print(f'VAL: [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
546
+ " print(\"VAL acc = {}\".format(acc))"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": 34,
552
+ "metadata": {},
553
+ "outputs": [
554
+ {
555
+ "data": {
556
+ "text/plain": [
557
+ "349"
558
+ ]
559
+ },
560
+ "execution_count": 34,
561
+ "metadata": {},
562
+ "output_type": "execute_result"
563
+ }
564
+ ],
565
+ "source": [
566
+ "len(os.listdir('/workspace//data'))"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": null,
572
+ "metadata": {},
573
+ "outputs": [],
574
+ "source": []
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "execution_count": 40,
579
+ "metadata": {},
580
+ "outputs": [
581
+ {
582
+ "name": "stdout",
583
+ "output_type": "stream",
584
+ "text": [
585
+ "n_0000000187.jpg\n",
586
+ "t_0000000182.jpg\n"
587
+ ]
588
+ }
589
+ ],
590
+ "source": [
591
+ "root = '/workspace/data'\n",
592
+ "for img in os.listdir(root):\n",
593
+ " try:\n",
594
+ " image = cv2.imread(os.path.join(root,img))\n",
595
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
596
+ " except:\n",
597
+ " print(img)\n",
598
+ " \n",
599
+ " "
600
+ ]
601
+ },
602
+ {
603
+ "cell_type": "code",
604
+ "execution_count": null,
605
+ "metadata": {},
606
+ "outputs": [],
607
+ "source": []
608
+ }
609
+ ],
610
+ "metadata": {
611
+ "kernelspec": {
612
+ "display_name": "base",
613
+ "language": "python",
614
+ "name": "python3"
615
+ },
616
+ "language_info": {
617
+ "codemirror_mode": {
618
+ "name": "ipython",
619
+ "version": 3
620
+ },
621
+ "file_extension": ".py",
622
+ "mimetype": "text/x-python",
623
+ "name": "python",
624
+ "nbconvert_exporter": "python",
625
+ "pygments_lexer": "ipython3",
626
+ "version": "3.7.11"
627
+ },
628
+ "orig_nbformat": 4
629
+ },
630
+ "nbformat": 4,
631
+ "nbformat_minor": 2
632
+ }
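The validation loop in the notebook above counts per-class hits by one-hot encoding both the labels and the predictions and summing the positions where the two encodings add up to 2, then divides by the size of the whole validation set; each class's "VAL acc" figure is therefore capped at that class's share of the data (roughly 50% for this near-balanced split), which is why the printed values never approach 100. A minimal sketch of the same bookkeeping on made-up tensors, with a per-class denominator added for comparison:

    import torch
    import torch.nn.functional as F

    labels = torch.tensor([0, 0, 1, 1])  # toy ground truth: 0 = no trunk, 1 = trunk
    preds = torch.tensor([0, 1, 1, 1])   # toy predictions

    # a position is a per-class hit when the label one-hot and the prediction one-hot are both 1
    correct = (F.one_hot(labels, 2) + F.one_hot(preds, 2) == 2).sum(dim=0).float()

    acc_as_in_notebook = 100 * correct / len(labels)                                # tensor([25., 50.])
    per_class_recall = 100 * correct / torch.bincount(labels, minlength=2).float()  # tensor([ 50., 100.])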
dataset.py ADDED
@@ -0,0 +1,96 @@
1
+ import os
2
+ import matplotlib.pyplot as plt
3
+ from pandas.core.common import flatten
4
+ import copy
5
+ import numpy as np
6
+ import random
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch import optim
11
+ import torch.nn.functional as F
12
+ from torchvision import datasets, transforms, models
13
+ from torch.utils.data import Dataset, DataLoader
14
+ import torch.nn as nn
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+ import cv2
18
+
19
+ import glob
20
+ from tqdm import tqdm
21
+ import random
22
+
23
+ class MotorbikeDataset(torch.utils.data.Dataset):
24
+ def __init__(self, image_paths, transform=None):
25
+ self.root = image_paths
26
+ self.image_paths = os.listdir(image_paths)
27
+ self.transform = transform
28
+
29
+ def __len__(self):
30
+ return len(self.image_paths)
31
+
32
+ def __getitem__(self, idx):
33
+ image_filepath = self.image_paths[idx]
34
+
35
+ image = cv2.imread(os.path.join(self.root,image_filepath))
36
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
37
+
38
+ label = int(image_filepath.startswith('t_'))  # 't_' prefix marks the trunk class, 'n_' the no-trunk class
39
+ if self.transform is not None:
40
+ image = self.transform(image=image)["image"]
41
+
42
+ return image, label
43
+
44
+
45
+ class MotorbikeDataset_CV(torch.utils.data.Dataset):
46
+ def __init__(self, root, train_transforms, val_transforms, trainval_ratio=0.8) -> None:
47
+ self.root = root
48
+ self.train_transforms = train_transforms
49
+ self.val_transforms = val_transforms
50
+ self.trainval_ratio = trainval_ratio
51
+ self.train_split, self.val_split = self.gen_split()
52
+
53
+ def __len__(self):
54
+ return len(self.train_split) + len(self.val_split)  # total number of images, not len() of the path string
55
+
56
+ def gen_split(self):
57
+ img_list = os.listdir(self.root)
58
+ n_list = [img for img in img_list if img.startswith('n_')]
59
+ t_list = [img for img in img_list if img.startswith('t_')]
60
+
61
+ n_train = random.sample(n_list, k=int(len(n_list)*self.trainval_ratio))  # sample without replacement so the train split has no duplicates
62
+ t_train = random.sample(t_list, k=int(len(t_list)*self.trainval_ratio))
63
+ n_val = [img for img in n_list if img not in n_train]
64
+ t_val = [img for img in t_list if img not in t_train]
65
+
66
+ train_split = n_train + t_train
67
+ val_split = n_val + t_val
68
+ return train_split, val_split
69
+
70
+ def get_split(self):
71
+ train_dataset = Dataset_from_list(self.root, self.train_split, self.train_transforms)
72
+ val_dataset = Dataset_from_list(self.root, self.val_split, self.val_transforms)
73
+ return train_dataset, val_dataset
74
+
75
+ class Dataset_from_list(torch.utils.data.Dataset):
76
+ def __init__(self, root, img_list, transform) -> None:
77
+ self.root = root
78
+ self.img_list = img_list
79
+ self.transform = transform
80
+
81
+ def __len__(self):
82
+ return len(self.img_list)
83
+
84
+ def __getitem__(self, idx):
85
+ image = cv2.imread(os.path.join(self.root, self.img_list[idx]))
86
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
87
+
88
+ label = int(self.img_list[idx].startswith('t_'))
89
+
90
+ if self.transform is not None:
91
+ image = self.transform(image=image)["image"]
92
+
93
+ return image, label
94
+
95
+
96
+
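Both dataset classes above derive the binary label from the filename prefix rather than from a folder structure: names starting with 't_' are the trunk class (label 1) and names starting with 'n_' are the no-trunk class (label 0). A quick sanity check of that convention, using the two filenames printed by the notebook earlier:

    for name in ['n_0000000187.jpg', 't_0000000182.jpg']:
        print(name, int(name.startswith('t_')))
    # n_0000000187.jpg 0
    # t_0000000182.jpg 1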
main.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import sys
3
+
4
+ import matplotlib.pyplot as plt
5
+ from pandas.core.common import flatten
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from torchvision import datasets, transforms, models
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+
18
+ from tqdm import tqdm
19
+ import random
20
+
21
+ sys.path.append('/workspace')
22
+ import dataset
23
+
24
+
25
+
26
+ train_transforms = A.Compose(
27
+ [
28
+ A.SmallestMaxSize(max_size=350),
29
+ A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=360, p=0.5),
30
+ A.RandomCrop(height=256, width=256),
31
+ A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
32
+ A.RandomBrightnessContrast(p=0.5),
33
+ A.MultiplicativeNoise(multiplier=[0.5,2], per_channel=True, p=0.2),
34
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
35
+ A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
36
+ A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
37
+ ToTensorV2(),
38
+ ]
39
+ )
40
+
41
+ test_transforms = A.Compose(
42
+ [
43
+ A.SmallestMaxSize(max_size=350),
44
+ A.CenterCrop(height=256, width=256),
45
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
46
+ ToTensorV2(),
47
+ ]
48
+ )
49
+
50
+ dataset_CV = dataset.MotorbikeDataset_CV(
51
+ root='/workspace/data/',
52
+ train_transforms=train_transforms,
53
+ val_transforms=test_transforms
54
+ )
55
+
56
+ train_dataset, val_dataset = dataset_CV.get_split()
57
+
58
+ train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
59
+ val_loader = DataLoader(val_dataset,batch_size=64, shuffle=False)
60
+
61
+ device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")
62
+
63
+ model = models.resnet50(pretrained=True)
64
+ model.fc = nn.Sequential(
65
+ nn.Dropout(0.5),
66
+ nn.Linear(model.fc.in_features, 2)
67
+ )
68
+
69
+ for n, p in model.named_parameters():
70
+ if 'fc' in n:
71
+ p.requires_grad = True
72
+ else:
73
+ p.requires_grad = False
74
+
75
+ model.to(device)
76
+
77
+ criterion = nn.CrossEntropyLoss()
78
+ optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
79
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
80
+ best_acc = 0.0
81
+
82
+ for epoch in range(10):
83
+ model.train()
84
+ running_loss = 0.0
85
+ for i, data in enumerate(train_loader, 0):
86
+ inputs, labels = data[0].to(device), data[1].to(device)
87
+
88
+ optimizer.zero_grad()
89
+
90
+ outputs = model(inputs)
91
+ loss = criterion(outputs, labels)
92
+ loss.backward()
93
+ optimizer.step()
94
+ running_loss += loss.item()
95
+ scheduler.step()
96
+
97
+ print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
98
+ # print("TRAIN acc = {}".format(acc))
99
+ running_loss = 0.0
100
+
101
+ with torch.no_grad():
102
+ model.eval()
103
+ running_loss = 0.0
104
+ correct =0
105
+ for i, data in enumerate(val_loader, 0):
106
+ inputs, labels = data[0].to(device), data[1].to(device)
107
+ outputs = model(inputs)
108
+ _, preds = outputs.max(1)
109
+ loss = criterion(outputs, labels)
110
+ running_loss += loss.item()
111
+ labels_one_hot = F.one_hot(labels, 2)
112
+ outputs_one_hot = F.one_hot(preds, 2)
113
+ correct = correct + (labels_one_hot + outputs_one_hot == 2).sum().to(torch.float)
114
+
115
+ acc = 100 * correct / len(val_dataset)
116
+ print(f'VAL: [{epoch + 1}, {i + 1:5d}] loss: {running_loss / len(val_loader):.3f}')
117
+ print("VAL acc = {:.2f}".format(acc))
118
+ if best_acc < acc:
119
+ best_acc = acc  # remember the best score so later epochs only overwrite a better checkpoint
+ torch.save(model.state_dict(), './result/best_model.pth')  # note: the ./result directory must already exist
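main.py trains only the replacement fc head and keeps the ResNet-50 backbone frozen. A small self-contained check, assuming torchvision's standard resnet50 layout, that the freezing loop leaves exactly the new head trainable:

    import torch.nn as nn
    from torchvision import models

    model = models.resnet50(pretrained=True)
    model.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(model.fc.in_features, 2))
    for n, p in model.named_parameters():
        p.requires_grad = 'fc' in n  # same effect as the if/else loop in main.py

    print([n for n, p in model.named_parameters() if p.requires_grad])
    # ['fc.1.weight', 'fc.1.bias']  (Dropout has no parameters)
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
    # 4098 = 2048*2 weights + 2 biases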
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ pandas
2
+ numpy
3
+ albumentations
4
+ opencv-python
5
+ tqdm
6
+ matplotlib
7
+ jupyter
test.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ import sys
3
+
4
+ import matplotlib.pyplot as plt
5
+ from pandas.core.common import flatten
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from torchvision import datasets, transforms, models
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+
18
+ from tqdm import tqdm
19
+ import random
20
+ import cv2
21
+
22
+ sys.path.append('/workspace')
23
+ import dataset
24
+ import argparse
25
+
26
+ def parse_args():
27
+ parser = argparse.ArgumentParser(description='Motorbike trunk classification - single-image inference')
28
+ parser.add_argument('--input',
29
+ help='test image path',
30
+ required=True,
31
+ type=str)
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+ classes = ('no_trunk', 'trunk')
36
+
37
+ test_transforms = A.Compose(
38
+ [
39
+ A.SmallestMaxSize(max_size=350),
40
+ A.CenterCrop(height=256, width=256),
41
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
42
+ ToTensorV2(),
43
+ ]
44
+ )
45
+
46
+ def main():
47
+ args = parse_args()
48
+ assert os.path.exists(args.input)
49
+ device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")
50
+
51
+ model = models.resnet50(pretrained=True)
52
+ model.fc = nn.Sequential(
53
+ nn.Dropout(0.5),
54
+ nn.Linear(model.fc.in_features, 2)
55
+ )
56
+
57
+ state_dict = torch.load('./result/best_model.pth', map_location='cpu')  # load on CPU first so the checkpoint also opens on GPU-less machines
58
+ model.load_state_dict(state_dict)
59
+
60
+ for _, p in model.named_parameters():
61
+ p.requires_grad = False
62
+
63
+ model.to(device)
64
+ model.eval()
65
+
66
+ test_transforms = A.Compose(
67
+ [
68
+ A.SmallestMaxSize(max_size=350),
69
+ A.CenterCrop(height=256, width=256),
70
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
71
+ ToTensorV2(),
72
+ ]
73
+ )
74
+
75
+ image = cv2.imread(args.input)
76
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
77
+ image = test_transforms(image=image)["image"]
78
+ image = torch.unsqueeze(image, 0).to(device)
79
+
80
+ output = model(image)
81
+ _, preds = output.max(1)
82
+
83
+ input_cls = 'trunk' if 't_' in args.input else 'no_trunk'
84
+
85
+ print("input: %s \n" %(input_cls))
86
+ print("output: %s" %(classes[preds.item()]))
87
+
88
+ if __name__ == '__main__':
89
+ main()
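test.py classifies a single image: it takes an --input path, loads the ./result/best_model.pth checkpoint produced by main.py, and prints the class implied by the filename prefix next to the model's prediction. A typical invocation (the image name is one of the files listed by the notebook; the expected output assumes the model gets it right) is python3 test.py --input /workspace/data/t_0000000182.jpg, which should print input: trunk followed by output: trunk.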
utils.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+
3
+
4
+ def mic_acc_cal(preds, labels):
5
+ if isinstance(labels, tuple):
6
+ assert len(labels) == 3
7
+ targets_a, targets_b, lam = labels
8
+ acc_mic_top1 = (lam * preds.eq(targets_a.data).cpu().sum().float() \
9
+ + (1 - lam) * preds.eq(targets_b.data).cpu().sum().float()) / len(preds)
10
+ else:
11
+ acc_mic_top1 = (preds == labels).sum().item() / len(labels)
12
+ return acc_mic_top1
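A usage sketch for mic_acc_cal with toy tensors (values are illustrative only): the plain branch is ordinary top-1 accuracy, while the tuple branch blends two target sets by lam, the pattern used with mixup-style augmentation.

    import torch
    from utils import mic_acc_cal

    preds = torch.tensor([1, 0, 1, 1])
    labels = torch.tensor([1, 0, 0, 1])
    print(mic_acc_cal(preds, labels))                       # 0.75

    targets_a = torch.tensor([1, 0, 0, 1])
    targets_b = torch.tensor([0, 1, 1, 1])
    print(mic_acc_cal(preds, (targets_a, targets_b, 0.7)))  # 0.675 = (0.7*3 + 0.3*2) / 4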