Skip to content

Commit 3fbd40a

Browse files
committed
Performance optimization for contiguous tensors
ghstack-source-id: bc74b98 Pull Request resolved: #103
1 parent ba3bcab commit 3fbd40a

File tree

2 files changed

+72
-46
lines changed

2 files changed

+72
-46
lines changed

torchcsprng/csrc/block_cipher.h

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,101 +32,127 @@ namespace csprng {
3232

3333
template<typename input_index_calc_t>
3434
TORCH_CSPRNG_HOST_DEVICE static void copy_input_to_block(int64_t idx, uint8_t* block, int block_size,
35-
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc) {
36-
for (auto i = 0; i < block_size / input_type_size; ++i) {
37-
const auto linear_index = idx * (block_size / input_type_size) + i;
38-
if (linear_index < input_numel) {
39-
std::memcpy(
40-
block + i * input_type_size,
41-
&(reinterpret_cast<uint8_t*>(input_ptr)[input_index_calc(linear_index)]),
42-
input_type_size
43-
);
35+
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous) {
36+
if (input_is_contiguous) {
37+
for (auto i = 0; i < block_size / input_type_size; ++i) {
38+
const auto linear_index = idx * (block_size / input_type_size) + i;
39+
if (linear_index < input_numel) {
40+
std::memcpy(
41+
block + i * input_type_size,
42+
&(reinterpret_cast<uint8_t*>(input_ptr)[linear_index * input_type_size]),
43+
input_type_size
44+
);
45+
}
46+
}
47+
} else {
48+
for (auto i = 0; i < block_size / input_type_size; ++i) {
49+
const auto linear_index = idx * (block_size / input_type_size) + i;
50+
if (linear_index < input_numel) {
51+
std::memcpy(
52+
block + i * input_type_size,
53+
&(reinterpret_cast<uint8_t*>(input_ptr)[input_index_calc(linear_index)]),
54+
input_type_size
55+
);
56+
}
4457
}
4558
}
4659
}
4760

4861
template<typename output_index_calc_t>
4962
TORCH_CSPRNG_HOST_DEVICE static void copy_block_to_output(int64_t idx, uint8_t* block, int output_elem_per_block,
50-
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc) {
51-
for (auto i = 0; i < output_elem_per_block; ++i) {
52-
const auto linear_index = idx * output_elem_per_block + i;
53-
if (linear_index < output_numel) {
54-
std::memcpy(
55-
&(reinterpret_cast<uint8_t*>(output_ptr)[output_index_calc(linear_index)]),
56-
block + i * output_type_size,
57-
output_type_size
58-
);
63+
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous) {
64+
if (output_is_contiguous) {
65+
for (auto i = 0; i < output_elem_per_block; ++i) {
66+
const auto linear_index = idx * output_elem_per_block + i;
67+
if (linear_index < output_numel) {
68+
std::memcpy(
69+
&(reinterpret_cast<uint8_t*>(output_ptr)[linear_index * output_type_size]),
70+
block + i * output_type_size,
71+
output_type_size
72+
);
73+
}
74+
}
75+
} else {
76+
for (auto i = 0; i < output_elem_per_block; ++i) {
77+
const auto linear_index = idx * output_elem_per_block + i;
78+
if (linear_index < output_numel) {
79+
std::memcpy(
80+
&(reinterpret_cast<uint8_t*>(output_ptr)[output_index_calc(linear_index)]),
81+
block + i * output_type_size,
82+
output_type_size
83+
);
84+
}
5985
}
6086
}
6187
}
6288

6389
template<int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
6490
TORCH_CSPRNG_HOST_DEVICE static void block_cipher_kernel_helper(
6591
int64_t idx, cipher_t cipher, int output_elem_per_block,
66-
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
67-
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
92+
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
93+
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
6894
transform_t transform) {
6995
uint8_t block[block_size];
7096
std::memset(&block, 0, block_size); // is it ok to use zeros as padding?
7197
if (input_ptr != nullptr) {
72-
copy_input_to_block(idx, block, block_size, input_ptr, input_numel, input_type_size, input_index_calc);
98+
copy_input_to_block(idx, block, block_size, input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous);
7399
}
74100
cipher(idx, block);
75101
transform(block);
76-
copy_block_to_output(idx, block, output_elem_per_block, output_ptr, output_numel, output_type_size, output_index_calc);
102+
copy_block_to_output(idx, block, output_elem_per_block, output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous);
77103
}
78104

79105
#if defined(__CUDACC__) || defined(__HIPCC__)
80106
template<int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
81107
__global__ static void block_cipher_kernel_cuda(cipher_t cipher, int output_elem_per_block,
82-
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
83-
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
108+
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
109+
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
84110
transform_t transform) {
85111
const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
86112
block_cipher_kernel_helper<block_size>(idx, cipher, output_elem_per_block,
87-
input_ptr, input_numel, input_type_size, input_index_calc,
88-
output_ptr, output_numel, output_type_size, output_index_calc,
113+
input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
114+
output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
89115
transform);
90116
}
91117
#endif
92118

93119
template<int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
94120
static void block_cipher_kernel_cpu_serial(int64_t begin, int64_t end, cipher_t cipher, int output_elem_per_block,
95-
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
96-
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
121+
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
122+
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
97123
transform_t transform) {
98124
for (auto idx = begin; idx < end; ++idx) {
99125
block_cipher_kernel_helper<block_size>(idx, cipher, output_elem_per_block,
100-
input_ptr, input_numel, input_type_size, input_index_calc,
101-
output_ptr, output_numel, output_type_size, output_index_calc,
126+
input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
127+
output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
102128
transform);
103129
}
104130
}
105131

106132
template<int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
107133
static void block_cipher_kernel_cpu(int64_t total, cipher_t cipher, int output_elem_per_block,
108-
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
109-
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
134+
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
135+
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
110136
transform_t transform_func) {
111137
if (total < at::internal::GRAIN_SIZE || at::get_num_threads() == 1) {
112138
block_cipher_kernel_cpu_serial<block_size>(0, total, cipher, output_elem_per_block,
113-
input_ptr, input_numel, input_type_size, input_index_calc,
114-
output_ptr, output_numel, output_type_size, output_index_calc,
139+
input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
140+
output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
115141
transform_func);
116142
} else {
117143
at::parallel_for(0, total, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
118144
block_cipher_kernel_cpu_serial<block_size>(begin, end, cipher, output_elem_per_block,
119-
input_ptr, input_numel, input_type_size, input_index_calc,
120-
output_ptr, output_numel, output_type_size, output_index_calc,
145+
input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
146+
output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
121147
transform_func);
122148
});
123149
}
124150
}
125151

126152
template<int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
127153
void block_cipher(
128-
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
129-
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
154+
void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
155+
void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
130156
at::Device device, cipher_t cipher, int output_elem_per_block, transform_t transform_func) {
131157
if (output_ptr == nullptr || output_numel == 0) {
132158
return;
@@ -136,8 +162,8 @@ void block_cipher(
136162
const auto total = (output_numel + output_elem_per_block - 1) / output_elem_per_block;
137163
block_cipher_kernel_cpu<block_size>(total,
138164
cipher, output_elem_per_block,
139-
input_ptr, input_numel, input_type_size, input_index_calc,
140-
output_ptr, output_numel, output_type_size, output_index_calc,
165+
input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
166+
output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
141167
transform_func
142168
);
143169
} else if (device.type() == at::kCUDA) {
@@ -147,8 +173,8 @@ void block_cipher(
147173
auto stream = at::cuda::getCurrentCUDAStream();
148174
block_cipher_kernel_cuda<block_size><<<grid, threads, 0, stream>>>(
149175
cipher, output_elem_per_block,
150-
input_ptr, input_numel, input_type_size, input_index_calc,
151-
output_ptr, output_numel, output_type_size, output_index_calc,
176+
input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
177+
output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
152178
transform_func
153179
);
154180
AT_CUDA_CHECK(cudaGetLastError());
@@ -193,8 +219,8 @@ void block_cipher(at::Tensor input, at::Tensor output, cipher_t cipher) {
193219
const auto device = output.device();
194220

195221
torch::csprng::block_cipher<block_size>(
196-
input_ptr, input_numel, input_type_size, input_index_calc,
197-
output_ptr, output_numel, output_type_size, output_index_calc,
222+
input_ptr, input_numel, input_type_size, input_index_calc, input.is_contiguous(),
223+
output_ptr, output_numel, output_type_size, output_index_calc, output.is_contiguous(),
198224
device, cipher, block_size / output_type_size,
199225
[] TORCH_CSPRNG_HOST_DEVICE (uint8_t* x) {});
200226
}

torchcsprng/csrc/kernels_body.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ void aes_helper(at::TensorIterator& iter, const uint8_t* key_bytes, transform_t
6666
return output_offset_calc.get(li)[0];
6767
};
6868
torch::csprng::block_cipher<aes::block_t_size>(
69-
nullptr, 0, 0, output_index_calc,
70-
output.data_ptr(), output.numel(), output.element_size(), output_index_calc,
69+
nullptr, 0, 0, output_index_calc, false,
70+
output.data_ptr(), output.numel(), output.element_size(), output_index_calc, output.is_contiguous(),
7171
iter.device_type(),
7272
[key_bytes] TORCH_CSPRNG_HOST_DEVICE (int64_t idx, uint8_t* block) -> void {
7373
uint8_t idx_block[aes::block_t_size];

0 commit comments

Comments
 (0)