1+ import sys
2+ import onnx
3+ import os
4+ import argparse
5+ import numpy as np
6+ import cv2
7+ import onnxruntime
8+ import torch
9+ from utils .tools import *
10+ # from tool.utils import *
11+ from model .build_model import Build_Model
12+ from eval .evaluator import *
13+ import config .yolov4_config as cfg
14+
15+
16+ def convert_predbox (pred_bbox , test_input_size , org_img_shape , valid_scale ):
17+ """
18+ 预测框进行过滤,去除尺度不合理的框
19+ """
20+ pred_coor = xywh2xyxy (pred_bbox [:, :4 ])
21+ pred_conf = pred_bbox [:, 4 ]
22+ pred_prob = pred_bbox [:, 5 :]
23+
24+ # (1)
25+ # (xmin_org, xmax_org) = ((xmin, xmax) - dw) / resize_ratio
26+ # (ymin_org, ymax_org) = ((ymin, ymax) - dh) / resize_ratio
27+ # 需要注意的是,无论我们在训练的时候使用什么数据增强方式,都不影响此处的转换方式
28+ # 假设我们对输入测试图片使用了转换方式A,那么此处对bbox的转换方式就是方式A的逆向过程
29+ org_h , org_w = org_img_shape
30+ resize_ratio = min (1.0 * test_input_size / org_w , 1.0 * test_input_size / org_h )
31+ dw = (test_input_size - resize_ratio * org_w ) / 2
32+ dh = (test_input_size - resize_ratio * org_h ) / 2
33+ pred_coor [:, 0 ::2 ] = 1.0 * (pred_coor [:, 0 ::2 ] - dw ) / resize_ratio
34+ pred_coor [:, 1 ::2 ] = 1.0 * (pred_coor [:, 1 ::2 ] - dh ) / resize_ratio
35+
36+ # (2)将预测的bbox中超出原图的部分裁掉
37+ pred_coor = np .concatenate ([np .maximum (pred_coor [:, :2 ], [0 , 0 ]),
38+ np .minimum (pred_coor [:, 2 :], [org_w - 1 , org_h - 1 ])], axis = - 1 )
39+ # (3)将无效bbox的coor置为0
40+ invalid_mask = np .logical_or ((pred_coor [:, 0 ] > pred_coor [:, 2 ]), (pred_coor [:, 1 ] > pred_coor [:, 3 ]))
41+ pred_coor [invalid_mask ] = 0
42+
43+ # (4)去掉不在有效范围内的bbox
44+ bboxes_scale = np .sqrt (np .multiply .reduce (pred_coor [:, 2 :4 ] - pred_coor [:, 0 :2 ], axis = - 1 ))
45+ scale_mask = np .logical_and ((valid_scale [0 ] < bboxes_scale ), (bboxes_scale < valid_scale [1 ]))
46+
47+ # (5)将score低于score_threshold的bbox去掉
48+ classes = np .argmax (pred_prob , axis = - 1 )
49+ scores = pred_conf * pred_prob [np .arange (len (pred_coor )), classes ]
50+ score_mask = scores > cfg .VAL ["CONF_THRESH" ]
51+
52+ mask = np .logical_and (scale_mask , score_mask )
53+
54+ coors = pred_coor [mask ]
55+ scores = scores [mask ]
56+ classes = classes [mask ]
57+
58+ bboxes = np .concatenate ([coors , scores [:, np .newaxis ], classes [:, np .newaxis ]], axis = - 1 )
59+
60+ return bboxes
61+
62+
63+ def detect (session , image_src ):
64+ IN_IMAGE_H = session .get_inputs ()[0 ].shape [2 ]
65+ IN_IMAGE_W = session .get_inputs ()[0 ].shape [3 ]
66+
67+ # Input
68+ resized = cv2 .resize (image_src , (IN_IMAGE_W , IN_IMAGE_H ), interpolation = cv2 .INTER_LINEAR )
69+ img_in = cv2 .cvtColor (resized , cv2 .COLOR_BGR2RGB )
70+ img_in = np .transpose (img_in , (2 , 0 , 1 )).astype (np .float32 )
71+ img_in = np .expand_dims (img_in , axis = 0 )
72+ img_in /= 255.0
73+ print ("Shape of the network input: " , img_in .shape )
74+
75+ # Compute
76+ input_name = session .get_inputs ()[0 ].name
77+
78+ outputs = session .run (None , {input_name : img_in })
79+ bboxes = convert_predbox (outputs [- 1 ], IN_IMAGE_H , (IN_IMAGE_H , IN_IMAGE_W ), (0 , np .inf ))
80+ bboxes_prd = nms (bboxes , cfg .VAL ["CONF_THRESH" ], cfg .VAL ["NMS_THRESH" ])
81+ if bboxes_prd .shape [0 ] != 0 :
82+ boxes = bboxes_prd [..., :4 ]
83+ class_inds = bboxes_prd [..., 5 ].astype (np .int32 )
84+ scores = bboxes_prd [..., 4 ]
85+ visualize_boxes (image = image_src , boxes = boxes , labels = class_inds , probs = scores , class_labels = cfg .VOC_DATA ["CLASSES" ])
86+ path = os .path .join (cfg .PROJECT_PATH , "save.jpg" )
87+ cv2 .imwrite (path , image_src )
88+ print ("saved images : {}" .format (path ))
89+
90+
91+ def transform_to_onnx (weight_file , batch_size , IN_IMAGE_H , IN_IMAGE_W ):
92+ model = Build_Model ()
93+ pretrained_dict = torch .load (weight_file , map_location = torch .device ('cpu' ))
94+ model .load_state_dict (pretrained_dict )
95+
96+ input_names = ["input" ]
97+ output_names = ['boxes' , 'confs' ]
98+
99+ dynamic = False
100+ if batch_size <= 0 :
101+ dynamic = True
102+
103+ if dynamic :
104+ x = torch .randn ((1 , 3 , IN_IMAGE_H , IN_IMAGE_W ), requires_grad = True )
105+ onnx_file_name = "yolov4_-1_3_{}_{}_dynamic.onnx" .format (IN_IMAGE_H , IN_IMAGE_W )
106+ dynamic_axes = {"input" : {0 : "batch_size" }, "boxes" : {0 : "batch_size" }, "confs" : {0 : "batch_size" }}
107+ # Export the model
108+ print ('Export the onnx model ...' )
109+ torch .onnx .export (model ,
110+ x ,
111+ onnx_file_name ,
112+ export_params = True ,
113+ opset_version = 11 ,
114+ do_constant_folding = True ,
115+ input_names = input_names , output_names = output_names ,
116+ dynamic_axes = dynamic_axes )
117+
118+ print ('Onnx model exporting done' )
119+ return onnx_file_name
120+
121+ else :
122+ x = torch .randn ((batch_size , 3 , IN_IMAGE_H , IN_IMAGE_W ), requires_grad = True )
123+ onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx" .format (batch_size , IN_IMAGE_H , IN_IMAGE_W )
124+ # Export the model
125+ print ('Export the onnx model ...' )
126+ torch .onnx .export (model ,
127+ x ,
128+ onnx_file_name ,
129+ export_params = True ,
130+ do_constant_folding = True ,
131+ input_names = input_names , output_names = output_names ,
132+ )
133+
134+ print ('Onnx model exporting done' )
135+ return onnx_file_name
136+
137+
138+ def main (weight_file = None , image_path = None , batch_size = 1 , IN_IMAGE_H = 416 , IN_IMAGE_W = 416 ):
139+ if batch_size <= 0 :
140+ onnx_path_demo = transform_to_onnx (weight_file , batch_size , IN_IMAGE_H , IN_IMAGE_W )
141+ else :
142+ # Transform to onnx as specified batch size
143+ transform_to_onnx (weight_file , batch_size , IN_IMAGE_H , IN_IMAGE_W )
144+ # Transform to onnx for demo
145+ onnx_path_demo = transform_to_onnx (weight_file , 1 , IN_IMAGE_H , IN_IMAGE_W )
146+
147+ session = onnxruntime .InferenceSession (onnx_path_demo )
148+ print ("The model expects input shape: " , session .get_inputs ()[0 ].shape )
149+ image_src = cv2 .imread (image_path )
150+ detect (session , image_src )
151+
152+
153+ if __name__ == '__main__' :
154+ import os .path as osp
155+ print ("Converting to onnx and running demo ..." )
156+ PROJECT_PATH = osp .abspath (osp .dirname (__file__ ))
157+ weight_file = osp .join (PROJECT_PATH , 'weight/best.pt' )
158+ image_path = osp .join (PROJECT_PATH , '000001.jpg' )
159+ main (weight_file = weight_file ,image_path = image_path )
0 commit comments