Skip to content

Commit a32644e

Browse files
authored
Merge pull request #32 from izolot/libtorch_1.3
some fixes for work with libtorch 1.3
2 parents d967991 + eedbaf6 commit a32644e

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

Darknet.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,17 @@ string Darknet::get_string_from_cfg(map<string, string> block, string key, strin
133133
torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kerner_size,
134134
int64_t stride, int64_t padding, int64_t groups, bool with_bias=false){
135135
torch::nn::Conv2dOptions conv_options = torch::nn::Conv2dOptions(in_planes, out_planes, kerner_size);
136-
conv_options.stride_ = stride;
137-
conv_options.padding_ = padding;
138-
conv_options.groups_ = groups;
139-
conv_options.with_bias_ = with_bias;
136+
conv_options.stride(stride);
137+
conv_options.padding(padding);
138+
conv_options.groups(groups);
139+
conv_options.with_bias(with_bias);
140140
return conv_options;
141141
}
142142

143143
torch::nn::BatchNormOptions bn_options(int64_t features){
144144
torch::nn::BatchNormOptions bn_options = torch::nn::BatchNormOptions(features);
145-
bn_options.affine_ = true;
146-
bn_options.stateful_ = true;
145+
bn_options.affine(true);
146+
bn_options.stateful(true);
147147
return bn_options;
148148
}
149149

@@ -524,7 +524,7 @@ void Darknet::load_weights(const char *weight_file)
524524
at::TensorOptions options= torch::TensorOptions()
525525
.dtype(torch::kFloat32)
526526
.is_variable(true);
527-
at::Tensor weights = torch::CPU(torch::kFloat32).tensorFromBlob(weights_src, {length/4});
527+
at::Tensor weights = torch::from_blob(weights_src, {length/4});
528528

529529
for (int i = 0; i < module_list.size(); i++)
530530
{
@@ -566,12 +566,12 @@ void Darknet::load_weights(const char *weight_file)
566566
bn_bias = bn_bias.view_as(bn_imp->bias);
567567
bn_weights = bn_weights.view_as(bn_imp->weight);
568568
bn_running_mean = bn_running_mean.view_as(bn_imp->running_mean);
569-
bn_running_var = bn_running_var.view_as(bn_imp->running_variance);
569+
bn_running_var = bn_running_var.view_as(bn_imp->running_var);
570570

571571
bn_imp->bias.set_data(bn_bias);
572572
bn_imp->weight.set_data(bn_weights);
573573
bn_imp->running_mean.set_data(bn_running_mean);
574-
bn_imp->running_variance.set_data(bn_running_var);
574+
bn_imp->running_var.set_data(bn_running_var);
575575
}
576576
else
577577
{

main.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,12 @@ int main(int argc, const char* argv[])
5757
cv::Mat img_float;
5858
resized_image.convertTo(img_float, CV_32F, 1.0/255);
5959

60-
auto img_tensor = torch::CPU(torch::kFloat32).tensorFromBlob(img_float.data, {1, input_image_size, input_image_size, 3});
60+
auto img_tensor = torch::from_blob(img_float.data, {1, input_image_size, input_image_size, 3}).to(device);
6161
img_tensor = img_tensor.permute({0,3,1,2});
62-
auto img_var = torch::autograd::make_variable(img_tensor, false).to(device);
6362

6463
auto start = std::chrono::high_resolution_clock::now();
6564

66-
auto output = net.forward(img_var);
65+
auto output = net.forward(img_tensor);
6766

6867
// filter result by NMS
6968
// class_num = 80
@@ -108,4 +107,4 @@ int main(int argc, const char* argv[])
108107
std::cout << "Done" << endl;
109108

110109
return 0;
111-
}
110+
}

0 commit comments

Comments
 (0)