
Commit 1d398a2

fix heatmap bug
1 parent c630cf9 commit 1d398a2
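
This change wires an optional "showatt" flag from the command line through Trainer, Build_Model and YOLOv4, so the ContextBlock2d attention (and the beta weights used for the heatmap) is only built and returned when requested. Evaluator.get_bbox() gains a mode argument, the heatmap is shown only in detection mode, and __show_heatmap() now receives the full beta instead of beta[2]. The hard-coded E:\YOLOV4 paths in the config are replaced with paths derived from the repository location, and the annotation files are now read from DATA_PATH.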

9 files changed (+69, -40 lines)


config/yolov4_config.py

Lines changed: 12 additions & 8 deletions
@@ -1,8 +1,12 @@
 # coding=utf-8
 # project
-DATA_PATH = "E:\YOLOV4/data"
-PROJECT_PATH = "E:\YOLOV4/data"
-DETECTION_PATH = "E:\YOLOV4/"
+import os.path as osp
+PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))
+
+DATA_PATH = osp.join(PROJECT_PATH, 'data')
+# PROJECT_PATH = "E:\YOLOV4/data"
+# PROJECT_PATH = "E:\YOLOV4/"
+
 
 MODEL_TYPE = {
     "TYPE": "Mobilenetv3-YOLOv4"
@@ -14,10 +18,10 @@
 
 # train
 TRAIN = {
-    "DATA_TYPE": "VOC",  # DATA_TYPE: VOC ,COCO or Customer
+    "DATA_TYPE": "Customer",  # DATA_TYPE: VOC ,COCO or Customer
     "TRAIN_IMG_SIZE": 416,
     "AUGMENT": True,
-    "BATCH_SIZE": 2,
+    "BATCH_SIZE": 1,
     "MULTI_SCALE_TRAIN": False,
     "IOU_THRESHOLD_LOSS": 0.5,
     "YOLO_EPOCHS": 50,
@@ -34,7 +38,7 @@
 # val
 VAL = {
     "TEST_IMG_SIZE": 416,
-    "BATCH_SIZE": 2,
+    "BATCH_SIZE": 1,
     "NUMBER_WORKERS": 0,
     "CONF_THRESH": 0.005,
     "NMS_THRESH": 0.45,
@@ -44,8 +48,8 @@
 }
 
 Customer_DATA = {
-    "NUM": 1,  # your dataset number
-    "CLASSES": ["aeroplane"],  # your dataset class
+    "NUM": 3,  # your dataset number
+    "CLASSES": ["unknown", "person", "car"],  # your dataset class
 }
 
 VOC_DATA = {
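
The effect of the path change: PROJECT_PATH is now derived from the location of the config file itself rather than a hard-coded Windows drive, so data is looked up under <repo>/data wherever the repository is checked out. A minimal sketch of how the new values resolve (the checkout path is illustrative):

import os.path as osp

# assume this file is <repo>/config/yolov4_config.py
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))  # e.g. /home/user/YOLOv4-pytorch
DATA_PATH = osp.join(PROJECT_PATH, 'data')                         # e.g. /home/user/YOLOv4-pytorch/data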

eval/evaluator.py

Lines changed: 8 additions & 8 deletions
@@ -12,7 +12,7 @@
 
 
 class Evaluator(object):
-    def __init__(self, model, showatt):
+    def __init__(self, model=None, showatt=False):
         if cfg.TRAIN["DATA_TYPE"] == "VOC":
             self.classes = cfg.VOC_DATA["CLASSES"]
         elif cfg.TRAIN["DATA_TYPE"] == "COCO":
@@ -88,32 +88,32 @@ def APs_voc(self, multi_test=False, flip_test=False):
         self.inference_time = 1.0 * self.inference_time / len(img_inds)
         return self.__calc_APs(), self.inference_time
 
-    def get_bbox(self, img, multi_test=False, flip_test=False):
+    def get_bbox(self, img, multi_test=False, flip_test=False, mode=None):
         if multi_test:
             test_input_sizes = range(320, 640, 96)
             bboxes_list = []
             for test_input_size in test_input_sizes:
                 valid_scale = (0, np.inf)
                 bboxes_list.append(
-                    self.__predict(img, test_input_size, valid_scale)
+                    self.__predict(img, test_input_size, valid_scale, mode)
                 )
                 if flip_test:
                     bboxes_flip = self.__predict(
-                        img[:, ::-1], test_input_size, valid_scale
+                        img[:, ::-1], test_input_size, valid_scale, mode
                     )
                     bboxes_flip[:, [0, 2]] = (
                         img.shape[1] - bboxes_flip[:, [2, 0]]
                     )
                     bboxes_list.append(bboxes_flip)
             bboxes = np.row_stack(bboxes_list)
         else:
-            bboxes = self.__predict(img, self.val_shape, (0, np.inf))
+            bboxes = self.__predict(img, self.val_shape, (0, np.inf), mode)
 
         bboxes = nms(bboxes, self.conf_thresh, self.nms_thresh)
 
         return bboxes
 
-    def __predict(self, img, test_shape, valid_scale):
+    def __predict(self, img, test_shape, valid_scale, mode):
         org_img = np.copy(img)
         org_h, org_w, _ = org_img.shape
 
@@ -130,8 +130,8 @@ def __predict(self, img, test_shape, valid_scale):
         bboxes = self.__convert_pred(
             pred_bbox, test_shape, (org_h, org_w), valid_scale
         )
-        if self.showatt and len(img):
-            self.__show_heatmap(beta[2], org_img)
+        if self.showatt and len(img) and mode == 'det':
+            self.__show_heatmap(beta, org_img)
         return bboxes
 
     def __show_heatmap(self, beta, img):
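
Taken together, the heatmap is now drawn only when the model was built with showatt and get_bbox() is called with mode='det'; during validation (mode left as None) no window is opened, and __show_heatmap() receives the full beta tensor rather than beta[2]. A hedged usage sketch (the image path, device and import paths are illustrative, not part of this commit):

import cv2
from model.build_model import Build_Model
from eval.evaluator import Evaluator

model = Build_Model(showatt=True).to("cuda").eval()   # assumes a CUDA device and loaded weights
evaluator = Evaluator(model, showatt=True)

img = cv2.imread("data/demo.jpg")                     # illustrative path
boxes = evaluator.get_bbox(img, mode="det")           # attention heatmap is displayed
boxes = evaluator.get_bbox(img, mode="val")           # no heatmap, detection only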

eval_voc.py

Lines changed: 15 additions & 5 deletions
@@ -18,23 +18,26 @@ def __init__(
         weight_path=None,
         visiual=None,
         eval=False,
+        showatt=False,
+        mode=None
     ):
         self.__num_class = cfg.VOC_DATA["NUM"]
         self.__conf_threshold = cfg.VAL["CONF_THRESH"]
         self.__nms_threshold = cfg.VAL["NMS_THRESH"]
         self.__device = gpu.select_device(gpu_id)
         self.__multi_scale_val = cfg.VAL["MULTI_SCALE_VAL"]
         self.__flip_val = cfg.VAL["FLIP_VAL"]
-
+        self.__showatt = showatt
         self.__visiual = visiual
         self.__eval = eval
+        self.__mode = mode
         self.__classes = cfg.VOC_DATA["CLASSES"]
 
-        self.__model = Build_Model().to(self.__device)
+        self.__model = Build_Model(showatt=self.__showatt).to(self.__device)
 
         self.__load_model_weights(weight_path)
 
-        self.__evalter = Evaluator(self.__model, showatt=False)
+        self.__evalter = Evaluator(self.__model, showatt=self.showatt)
 
     def __load_model_weights(self, weight_path):
         print("loading weight file from : {}".format(weight_path))
@@ -76,7 +79,7 @@ def detection(self):
             img = cv2.imread(path)
             assert img is not None
 
-            bboxes_prd = self.__evalter.get_bbox(img, v)
+            bboxes_prd = self.__evalter.get_bbox(img, v, mode=self.__mode)
             if bboxes_prd.shape[0] != 0:
                 boxes = bboxes_prd[..., :4]
                 class_inds = bboxes_prd[..., 5].astype(np.int32)
@@ -107,7 +110,7 @@ def detection(self):
        help="weight file path",
    )
    parser.add_argument(
-        "--log_val_path", type=str, default="log_val", help="weight file path"
+        "--log_val_path", type=str, default="log_val", help="val log file path"
    )
    parser.add_argument(
        "--gpu_id",
@@ -125,7 +128,10 @@ def detection(self):
        "--eval", action="store_true", default=True, help="eval the mAP or not"
    )
    parser.add_argument("--mode", type=str, default="val", help="val or det")
+    parser.add_argument("--showatt", type=bool, default=True, help="whether to show attention map")
    opt = parser.parse_args()
+    if not os.path.exists(opt.log_val_path):
+        os.mkdir(opt.log_val_path)
    logger = Logger(
        log_file_name=opt.log_val_path + "/log_voc_val.txt",
        log_level=logging.DEBUG,
@@ -138,11 +144,15 @@ def detection(self):
            weight_path=opt.weight_path,
            eval=opt.eval,
            visiual=opt.visiual,
+            showatt=opt.showatt,
+            mode=opt.mode
        ).val()
    else:
        Evaluation(
            gpu_id=opt.gpu_id,
            weight_path=opt.weight_path,
            eval=opt.eval,
            visiual=opt.visiual,
+            showatt=opt.showatt,
+            mode=opt.mode
        ).detection()
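
With the two new flags wired through, running eval_voc.py in detection mode now also controls the attention display; programmatically the bottom of the script is roughly equivalent to this sketch (the weight file and image directory are placeholders):

Evaluation(
    gpu_id=0,
    weight_path="weight/best.pt",   # placeholder
    visiual="data/test_imgs",       # placeholder: directory of images to detect
    eval=False,
    showatt=True,
    mode="det",
).detection()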

model/YOLOv4.py

Lines changed: 9 additions & 6 deletions
@@ -5,7 +5,7 @@
 from .backbones.CSPDarknet53 import _BuildCSPDarknet53
 from .backbones.mobilenetv2 import _BuildMobilenetV2
 from .backbones.mobilenetv3 import _BuildMobilenetV3
-
+from .layers.global_context_block import ContextBlock2d
 
 class Conv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, stride=1):
@@ -245,10 +245,9 @@ def __initialize_weights(self):
 
 
 class YOLOv4(nn.Module):
-    def __init__(self, weight_path=None, out_channels=255, resume=False):
+    def __init__(self, weight_path=None, out_channels=255, resume=False, showatt=False, feature_channels=0):
         super(YOLOv4, self).__init__()
-
-        a = cfg.MODEL_TYPE["TYPE"]
+        self.showatt = showatt
         if cfg.MODEL_TYPE["TYPE"] == "YOLOv4":
             # CSPDarknet53 backbone
             self.backbone, feature_channels = _BuildCSPDarknet53(
@@ -267,6 +266,8 @@ def __init__(self, weight_path=None, out_channels=255, resume=False):
         else:
             assert print("model type must be YOLOv4 or Mobilenet-YOLOv4")
 
+        if self.showatt:
+            self.attention = ContextBlock2d(feature_channels[-1], feature_channels[-1])
         # Spatial Pyramid Pooling
         self.spp = SpatialPyramidPooling(feature_channels)
 
@@ -277,12 +278,14 @@ def __init__(self, weight_path=None, out_channels=255, resume=False):
         self.predict_net = PredictNet(feature_channels, out_channels)
 
     def forward(self, x):
+        beta = None
         features = self.backbone(x)
+        if self.showatt:
+            features[-1], beta = self.attention(features[-1])
         features[-1] = self.spp(features[-1])
         features = self.panet(features)
         predicts = self.predict_net(features)
-
-        return predicts
+        return predicts, beta
 
 
 if __name__ == "__main__":
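
forward() now always returns a (predicts, beta) pair: beta stays None unless showatt is set, in which case the ContextBlock2d attention is applied to the last backbone feature map before SPP and its weights are returned for the heatmap. A small sketch of how a caller unpacks the new return value (the input size is illustrative):

import torch

net = YOLOv4(out_channels=255, showatt=True)       # backbone picked from cfg.MODEL_TYPE
predicts, beta = net(torch.randn(1, 3, 416, 416))  # beta is None when showatt=False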

model/backbones/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
@@ -249,7 +249,7 @@ def __init__(
         weight_path=None,
         resume=False,
         width_mult=1.0,
-        feature_channels=[24, 48, 1024],
+        feature_channels=[24, 48, 1024]
     ):
         super(MobilenetV3, self).__init__()
         self.feature_channels = feature_channels

model/build_model.py

Lines changed: 6 additions & 4 deletions
@@ -14,9 +14,9 @@ class Build_Model(nn.Module):
     Note : int the __init__(), to define the modules should be in order, because of the weight file is order
     """
 
-    def __init__(self, weight_path=None, resume=False):
+    def __init__(self, weight_path=None, resume=False, showatt=False):
         super(Build_Model, self).__init__()
-
+        self.__showatt = showatt
         self.__anchors = torch.FloatTensor(cfg.MODEL["ANCHORS"])
         self.__strides = torch.FloatTensor(cfg.MODEL["STRIDES"])
         if cfg.TRAIN["DATA_TYPE"] == "VOC":
@@ -31,6 +31,7 @@ def __init__(self, weight_path=None, resume=False):
             weight_path=weight_path,
             out_channels=self.__out_channel,
             resume=resume,
+            showatt=showatt
         )
         # small
         self.__head_s = Yolo_head(
@@ -47,8 +48,7 @@
 
     def forward(self, x):
         out = []
-
-        x_s, x_m, x_l = self.__yolov4(x)
+        [x_s, x_m, x_l], beta = self.__yolov4(x)
 
         out.append(self.__head_s(x_s))
         out.append(self.__head_m(x_m))
@@ -59,6 +59,8 @@
             return p, p_d  # smalll, medium, large
         else:
             p, p_d = list(zip(*out))
+            if self.__showatt:
+                return p, torch.cat(p_d, 0), beta
             return p, torch.cat(p_d, 0)
 
 
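Because YOLOv4.forward() now returns the extra beta, Build_Model unpacks it and, in eval mode with showatt enabled, returns a third element. Callers that turn showatt on must accept the 3-tuple; a sketch (the input tensor is illustrative):

import torch

model = Build_Model(showatt=True).eval()
x = torch.randn(1, 3, 416, 416)
with torch.no_grad():
    p, p_d, beta = model(x)      # showatt=True in eval mode -> three values
# with the default showatt=False the call stays: p, p_d = model(x)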

train.py

Lines changed: 16 additions & 3 deletions
@@ -30,9 +30,15 @@ def detection_collate(batch):
 
 
 class Trainer(object):
-    def __init__(self, weight_path, resume, gpu_id, accumulate, fp_16):
+    def __init__(self, weight_path=None,
+                 resume=False,
+                 gpu_id=0,
+                 accumulate=1,
+                 fp_16=False,
+                 showatt=False):
         init_seeds(0)
         self.fp_16 = fp_16
+        self.showatt = showatt
         self.device = gpu.select_device(gpu_id)
         self.start_epoch = 0
         self.best_mAP = 0.0
@@ -59,7 +65,7 @@ def __init__(self, weight_path, resume, gpu_id, accumulate, fp_16):
             pin_memory=True,
         )
 
-        self.yolov4 = Build_Model(weight_path=weight_path, resume=resume).to(
+        self.yolov4 = Build_Model(weight_path=weight_path, resume=resume, showatt=self.showatt).to(
             self.device
         )
 
@@ -269,7 +275,7 @@ def train(self):
             logger.info("val img size is {}".format(cfg.VAL["TEST_IMG_SIZE"]))
             with torch.no_grad():
                 APs, inference_time = Evaluator(
-                    self.yolov4, showatt=False
+                    self.yolov4, showatt=self.showatt
                 ).APs_voc()
                 for i in APs:
                     logger.info("{} --> mAP : {}".format(i, APs[i]))
@@ -340,6 +346,12 @@ def train(self):
        default=False,
        help="whither to use fp16 precision",
    )
+    parser.add_argument(
+        "--showatt",
+        type=bool,
+        default=True,
+        help="whether to show attention map"
+    )
    opt = parser.parse_args()
    writer = SummaryWriter(logdir=opt.log_path + "/event")
    logger = Logger(
@@ -354,4 +366,5 @@ def train(self):
        gpu_id=opt.gpu_id,
        accumulate=opt.accumulate,
        fp_16=opt.fp_16,
+        showatt = opt.showatt
    ).train()
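
One caveat with the new flag: argparse's type=bool converts any non-empty string to True, so passing --showatt False on the command line still enables the attention branch; only the default value really matters here. If a proper on/off switch were wanted, the usual pattern is a store_false toggle (a sketch, not part of this commit):

parser.add_argument("--no-showatt", dest="showatt", action="store_false",
                    help="disable the attention heatmap")
parser.set_defaults(showatt=True)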

utils/datasets.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def __load_annotations(self, anno_type):
             "test",
         ], "You must choice one of the 'train' or 'test' for anno_type parameter"
         anno_path = os.path.join(
-            cfg.PROJECT_PATH, anno_type + "_annotation.txt"
+            cfg.DATA_PATH, anno_type + "_annotation.txt"
         )
         with open(anno_path, "r") as f:
             annotations = list(filter(lambda x: len(x) > 0, f.readlines()))

utils/heatmap.py

Lines changed: 1 addition & 4 deletions
@@ -8,10 +8,7 @@
 def imshowAtt(beta, img=None):
     cv2.namedWindow("img")
     cv2.namedWindow("img1")
-    if img is None:
-        img = cv2.imread(
-            os.path.join("VOCdevkit\VOC2007\JPEGImages/000001.jpg"), 1
-        )  # the same input image
+    assert img is not None
 
     h, w, c = img.shape
     img1 = img.copy()
