Skip to content

Commit 158ea92

Browse files
committed
to Onnx
1 parent ba681da commit 158ea92

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

onnx_transform.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import sys
2+
import onnx
3+
import os
4+
import argparse
5+
import numpy as np
6+
import cv2
7+
import onnxruntime
8+
import torch
9+
from utils.tools import *
10+
# from tool.utils import *
11+
from model.build_model import Build_Model
12+
from eval.evaluator import *
13+
import config.yolov4_config as cfg
14+
15+
16+
def convert_predbox(pred_bbox, test_input_size, org_img_shape, valid_scale):
    """
    Filter raw predictions and express the surviving boxes in the
    coordinate frame described by ``org_img_shape``.

    Boxes are dropped when they are degenerate after clipping, when their
    scale (square root of the box area) is outside ``valid_scale``, or when
    their score is not above ``cfg.VAL["CONF_THRESH"]``.

    :param pred_bbox: 2-D array of predictions. Columns 0-3 are handed to
        ``xywh2xyxy``, column 4 is the confidence and the remaining columns
        are per-class probabilities.
    :param test_input_size: side length of the network input.
    :param org_img_shape: ``(height, width)`` of the target image.
    :param valid_scale: ``(low, high)`` exclusive bounds on the box scale.
    :return: array whose rows are the four box coordinates followed by the
        score and the class index.
    """
    # presumably returns corner-format (xmin, ymin, xmax, ymax) boxes -- confirm
    # against utils.tools.xywh2xyxy
    box_xyxy = xywh2xyxy(pred_bbox[:, :4])
    objectness = pred_bbox[:, 4]
    class_probs = pred_bbox[:, 5:]

    # (1) Undo the test-time transform:
    #     (xmin_org, xmax_org) = ((xmin, xmax) - dw) / resize_ratio
    #     (ymin_org, ymax_org) = ((ymin, ymax) - dh) / resize_ratio
    # Whatever augmentation was used during training is irrelevant here; this
    # step only has to be the inverse of the transform applied to the test image.
    img_h, img_w = org_img_shape
    ratio = min(1.0 * test_input_size / img_w, 1.0 * test_input_size / img_h)
    offset_x = (test_input_size - ratio * img_w) / 2
    offset_y = (test_input_size - ratio * img_h) / 2
    box_xyxy[:, 0::2] = 1.0 * (box_xyxy[:, 0::2] - offset_x) / ratio
    box_xyxy[:, 1::2] = 1.0 * (box_xyxy[:, 1::2] - offset_y) / ratio

    # (2) Clip the parts of each box that fall outside the image.
    top_left = np.maximum(box_xyxy[:, :2], [0, 0])
    bottom_right = np.minimum(box_xyxy[:, 2:], [img_w - 1, img_h - 1])
    box_xyxy = np.concatenate([top_left, bottom_right], axis=-1)

    # (3) Zero out boxes whose min corner ended up beyond their max corner.
    flipped_x = box_xyxy[:, 0] > box_xyxy[:, 2]
    flipped_y = box_xyxy[:, 1] > box_xyxy[:, 3]
    box_xyxy[flipped_x | flipped_y] = 0

    # (4) Keep only boxes whose scale lies strictly inside valid_scale.
    extent = box_xyxy[:, 2:4] - box_xyxy[:, 0:2]
    box_scale = np.sqrt(np.multiply.reduce(extent, axis=-1))
    in_scale = (valid_scale[0] < box_scale) & (box_scale < valid_scale[1])

    # (5) Keep only boxes whose score exceeds the confidence threshold.
    classes = np.argmax(class_probs, axis=-1)
    row_ids = np.arange(len(box_xyxy))
    scores = objectness * class_probs[row_ids, classes]
    confident = scores > cfg.VAL["CONF_THRESH"]

    keep = in_scale & confident

    return np.concatenate(
        [box_xyxy[keep], scores[keep][:, np.newaxis], classes[keep][:, np.newaxis]],
        axis=-1,
    )
61+
62+
63+
def detect(session, image_src):
    """
    Run one image through an ONNX Runtime session and save the annotated result.

    The image is stretched to the network input size, converted to RGB,
    laid out as a 1 x C x H x W float32 array scaled to [0, 1] and fed to
    the session. The last session output is filtered with
    ``convert_predbox`` and ``nms``. When at least one box survives, the
    boxes are drawn on ``image_src`` (modified in place) and the image is
    written to ``save.jpg`` under ``cfg.PROJECT_PATH``.

    :param session: inference session whose first input has a static
        (N, C, H, W) shape.
    :param image_src: BGR image array as returned by ``cv2.imread``.
    """
    net_input = session.get_inputs()[0]
    IN_IMAGE_H = net_input.shape[2]
    IN_IMAGE_W = net_input.shape[3]

    # Input
    resized = cv2.resize(image_src, (IN_IMAGE_W, IN_IMAGE_H), interpolation=cv2.INTER_LINEAR)
    img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32)
    img_in = np.expand_dims(img_in, axis=0)
    img_in /= 255.0
    print("Shape of the network input: ", img_in.shape)

    # Compute
    input_name = net_input.name

    outputs = session.run(None, {input_name: img_in})
    # The image was stretched (not letterboxed) to the network size, so the
    # boxes are first kept in network-input pixel coordinates.
    bboxes = convert_predbox(outputs[-1], IN_IMAGE_H, (IN_IMAGE_H, IN_IMAGE_W), (0, np.inf))
    bboxes_prd = nms(bboxes, cfg.VAL["CONF_THRESH"], cfg.VAL["NMS_THRESH"])
    if bboxes_prd.shape[0] != 0:
        # Undo the stretch per axis so the boxes line up with image_src, which
        # still has its original size. Without this the drawn boxes are
        # misplaced whenever the source image is not exactly the input size.
        src_h, src_w = image_src.shape[:2]
        scale_x = src_w / IN_IMAGE_W
        scale_y = src_h / IN_IMAGE_H
        boxes = bboxes_prd[..., :4] * np.array([scale_x, scale_y, scale_x, scale_y])
        class_inds = bboxes_prd[..., 5].astype(np.int32)
        scores = bboxes_prd[..., 4]
        visualize_boxes(image=image_src, boxes=boxes, labels=class_inds, probs=scores, class_labels=cfg.VOC_DATA["CLASSES"])
        path = os.path.join(cfg.PROJECT_PATH, "save.jpg")
        cv2.imwrite(path, image_src)
        print("saved images : {}".format(path))
89+
90+
91+
def transform_to_onnx(weight_file, batch_size, IN_IMAGE_H, IN_IMAGE_W):
    """
    Load the trained weights into ``Build_Model`` and export it to ONNX.

    :param weight_file: path of the checkpoint passed to ``torch.load``.
    :param batch_size: batch size baked into the exported graph. A value
        ``<= 0`` exports a model whose batch dimension is dynamic.
    :param IN_IMAGE_H: input height used for the tracing tensor.
    :param IN_IMAGE_W: input width used for the tracing tensor.
    :return: file name of the written ONNX model (relative to the current
        working directory).
    """
    model = Build_Model()
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    pretrained_dict = torch.load(weight_file, map_location=torch.device('cpu'))
    model.load_state_dict(pretrained_dict)

    input_names = ["input"]
    output_names = ['boxes', 'confs']

    dynamic = batch_size <= 0

    if dynamic:
        # Trace with a single sample and mark axis 0 of every tensor as dynamic.
        trace_batch = 1
        onnx_file_name = "yolov4_-1_3_{}_{}_dynamic.onnx".format(IN_IMAGE_H, IN_IMAGE_W)
        dynamic_axes = {"input": {0: "batch_size"}, "boxes": {0: "batch_size"}, "confs": {0: "batch_size"}}
    else:
        trace_batch = batch_size
        onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, IN_IMAGE_H, IN_IMAGE_W)
        dynamic_axes = None

    x = torch.randn((trace_batch, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True)

    # Export the model. The opset is pinned for both variants so that the
    # static and dynamic exports are produced with the same operator set
    # instead of the static one depending on the installed PyTorch default.
    print('Export the onnx model ...')
    torch.onnx.export(model,
                      x,
                      onnx_file_name,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=input_names, output_names=output_names,
                      dynamic_axes=dynamic_axes)

    print('Onnx model exporting done')
    return onnx_file_name
136+
137+
138+
def main(weight_file=None, image_path=None, batch_size=1, IN_IMAGE_H=416, IN_IMAGE_W=416):
    """
    Export the model to ONNX and run a single-image detection demo on it.

    :param weight_file: checkpoint path forwarded to ``transform_to_onnx``.
    :param image_path: path of the image used for the demo.
    :param batch_size: batch size of the exported model; ``<= 0`` exports a
        model with a dynamic batch dimension, which is then also used for
        the demo.
    :param IN_IMAGE_H: network input height.
    :param IN_IMAGE_W: network input width.
    :raises FileNotFoundError: if the demo image cannot be read.
    """
    if batch_size <= 0:
        onnx_path_demo = transform_to_onnx(weight_file, batch_size, IN_IMAGE_H, IN_IMAGE_W)
    else:
        # Transform to onnx as specified batch size. For batch_size == 1 this
        # is the very same file as the demo model, so it is exported only once.
        if batch_size != 1:
            transform_to_onnx(weight_file, batch_size, IN_IMAGE_H, IN_IMAGE_W)
        # Transform to onnx for demo
        onnx_path_demo = transform_to_onnx(weight_file, 1, IN_IMAGE_H, IN_IMAGE_W)

    session = onnxruntime.InferenceSession(onnx_path_demo)
    print("The model expects input shape: ", session.get_inputs()[0].shape)
    image_src = cv2.imread(image_path)
    if image_src is None:
        # cv2.imread signals an unreadable file by returning None rather than
        # raising; fail here with a clear message instead of deep inside detect.
        raise FileNotFoundError("Could not read image: {}".format(image_path))
    detect(session, image_src)
151+
152+
153+
if __name__ == '__main__':
    # Script entry point: export the checkpoint that sits next to this file
    # and run the demo on the bundled sample image.
    import os.path as osp

    print("Converting to onnx and running demo ...")
    PROJECT_PATH = osp.dirname(osp.abspath(__file__))
    weight_file = osp.join(PROJECT_PATH, 'weight/best.pt')
    image_path = osp.join(PROJECT_PATH, '000001.jpg')
    main(image_path=image_path, weight_file=weight_file)

0 commit comments

Comments
 (0)