1212
1313
1414class Evaluator (object ):
15- def __init__ (self , model , showatt ):
15+ def __init__ (self , model = None , showatt = False ):
1616 if cfg .TRAIN ["DATA_TYPE" ] == "VOC" :
1717 self .classes = cfg .VOC_DATA ["CLASSES" ]
1818 elif cfg .TRAIN ["DATA_TYPE" ] == "COCO" :
@@ -88,32 +88,32 @@ def APs_voc(self, multi_test=False, flip_test=False):
8888 self .inference_time = 1.0 * self .inference_time / len (img_inds )
8989 return self .__calc_APs (), self .inference_time
9090
91- def get_bbox (self , img , multi_test = False , flip_test = False ):
91+ def get_bbox (self , img , multi_test = False , flip_test = False , mode = None ):
9292 if multi_test :
9393 test_input_sizes = range (320 , 640 , 96 )
9494 bboxes_list = []
9595 for test_input_size in test_input_sizes :
9696 valid_scale = (0 , np .inf )
9797 bboxes_list .append (
98- self .__predict (img , test_input_size , valid_scale )
98+ self .__predict (img , test_input_size , valid_scale , mode )
9999 )
100100 if flip_test :
101101 bboxes_flip = self .__predict (
102- img [:, ::- 1 ], test_input_size , valid_scale
102+ img [:, ::- 1 ], test_input_size , valid_scale , mode
103103 )
104104 bboxes_flip [:, [0 , 2 ]] = (
105105 img .shape [1 ] - bboxes_flip [:, [2 , 0 ]]
106106 )
107107 bboxes_list .append (bboxes_flip )
108108 bboxes = np .row_stack (bboxes_list )
109109 else :
110- bboxes = self .__predict (img , self .val_shape , (0 , np .inf ))
110+ bboxes = self .__predict (img , self .val_shape , (0 , np .inf ), mode )
111111
112112 bboxes = nms (bboxes , self .conf_thresh , self .nms_thresh )
113113
114114 return bboxes
115115
116- def __predict (self , img , test_shape , valid_scale ):
116+ def __predict (self , img , test_shape , valid_scale , mode ):
117117 org_img = np .copy (img )
118118 org_h , org_w , _ = org_img .shape
119119
@@ -130,8 +130,8 @@ def __predict(self, img, test_shape, valid_scale):
130130 bboxes = self .__convert_pred (
131131 pred_bbox , test_shape , (org_h , org_w ), valid_scale
132132 )
133- if self .showatt and len (img ):
134- self .__show_heatmap (beta [ 2 ] , org_img )
133+ if self .showatt and len (img ) and mode == 'det' :
134+ self .__show_heatmap (beta , org_img )
135135 return bboxes
136136
137137 def __show_heatmap (self , beta , img ):
0 commit comments