@@ -32,101 +32,127 @@ namespace csprng {
 
 template <typename input_index_calc_t>
 TORCH_CSPRNG_HOST_DEVICE static void copy_input_to_block(int64_t idx, uint8_t* block, int block_size,
-    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc) {
-  for (auto i = 0; i < block_size / input_type_size; ++i) {
-    const auto linear_index = idx * (block_size / input_type_size) + i;
-    if (linear_index < input_numel) {
-      std::memcpy(
-        block + i * input_type_size,
-        &(reinterpret_cast<uint8_t*>(input_ptr)[input_index_calc(linear_index)]),
-        input_type_size
-      );
+    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous) {
+  if (input_is_contiguous) {
+    for (auto i = 0; i < block_size / input_type_size; ++i) {
+      const auto linear_index = idx * (block_size / input_type_size) + i;
+      if (linear_index < input_numel) {
+        std::memcpy(
+          block + i * input_type_size,
+          &(reinterpret_cast<uint8_t*>(input_ptr)[linear_index * input_type_size]),
+          input_type_size
+        );
+      }
+    }
+  } else {
+    for (auto i = 0; i < block_size / input_type_size; ++i) {
+      const auto linear_index = idx * (block_size / input_type_size) + i;
+      if (linear_index < input_numel) {
+        std::memcpy(
+          block + i * input_type_size,
+          &(reinterpret_cast<uint8_t*>(input_ptr)[input_index_calc(linear_index)]),
+          input_type_size
+        );
+      }
     }
   }
 }
 
 template <typename output_index_calc_t>
 TORCH_CSPRNG_HOST_DEVICE static void copy_block_to_output(int64_t idx, uint8_t* block, int output_elem_per_block,
-    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc) {
-  for (auto i = 0; i < output_elem_per_block; ++i) {
-    const auto linear_index = idx * output_elem_per_block + i;
-    if (linear_index < output_numel) {
-      std::memcpy(
-        &(reinterpret_cast<uint8_t*>(output_ptr)[output_index_calc(linear_index)]),
-        block + i * output_type_size,
-        output_type_size
-      );
+    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous) {
+  if (output_is_contiguous) {
+    for (auto i = 0; i < output_elem_per_block; ++i) {
+      const auto linear_index = idx * output_elem_per_block + i;
+      if (linear_index < output_numel) {
+        std::memcpy(
+          &(reinterpret_cast<uint8_t*>(output_ptr)[linear_index * output_type_size]),
+          block + i * output_type_size,
+          output_type_size
+        );
+      }
+    }
+  } else {
+    for (auto i = 0; i < output_elem_per_block; ++i) {
+      const auto linear_index = idx * output_elem_per_block + i;
+      if (linear_index < output_numel) {
+        std::memcpy(
+          &(reinterpret_cast<uint8_t*>(output_ptr)[output_index_calc(linear_index)]),
+          block + i * output_type_size,
+          output_type_size
+        );
+      }
     }
   }
 }
 
 template <int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
 TORCH_CSPRNG_HOST_DEVICE static void block_cipher_kernel_helper(
     int64_t idx, cipher_t cipher, int output_elem_per_block,
-    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
-    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
+    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
+    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
     transform_t transform) {
   uint8_t block[block_size];
   std::memset(&block, 0, block_size); // is it ok to use zeros as padding?
   if (input_ptr != nullptr) {
-    copy_input_to_block(idx, block, block_size, input_ptr, input_numel, input_type_size, input_index_calc);
+    copy_input_to_block(idx, block, block_size, input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous);
   }
   cipher(idx, block);
   transform(block);
-  copy_block_to_output(idx, block, output_elem_per_block, output_ptr, output_numel, output_type_size, output_index_calc);
+  copy_block_to_output(idx, block, output_elem_per_block, output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous);
 }
 
 #if defined(__CUDACC__) || defined(__HIPCC__)
 template <int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
 __global__ static void block_cipher_kernel_cuda(cipher_t cipher, int output_elem_per_block,
-    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
-    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
+    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
+    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
     transform_t transform) {
   const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
   block_cipher_kernel_helper<block_size>(idx, cipher, output_elem_per_block,
-      input_ptr, input_numel, input_type_size, input_index_calc,
-      output_ptr, output_numel, output_type_size, output_index_calc,
+      input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
+      output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
       transform);
 }
 #endif
 
 template <int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
 static void block_cipher_kernel_cpu_serial(int64_t begin, int64_t end, cipher_t cipher, int output_elem_per_block,
-    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
-    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
+    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
+    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
     transform_t transform) {
   for (auto idx = begin; idx < end; ++idx) {
     block_cipher_kernel_helper<block_size>(idx, cipher, output_elem_per_block,
-        input_ptr, input_numel, input_type_size, input_index_calc,
-        output_ptr, output_numel, output_type_size, output_index_calc,
+        input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
+        output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
         transform);
   }
 }
 
 template <int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
 static void block_cipher_kernel_cpu(int64_t total, cipher_t cipher, int output_elem_per_block,
-    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
-    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
+    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
+    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
     transform_t transform_func) {
   if (total < at::internal::GRAIN_SIZE || at::get_num_threads() == 1) {
     block_cipher_kernel_cpu_serial<block_size>(0, total, cipher, output_elem_per_block,
-        input_ptr, input_numel, input_type_size, input_index_calc,
-        output_ptr, output_numel, output_type_size, output_index_calc,
+        input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
+        output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
        transform_func);
   } else {
     at::parallel_for(0, total, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
       block_cipher_kernel_cpu_serial<block_size>(begin, end, cipher, output_elem_per_block,
-          input_ptr, input_numel, input_type_size, input_index_calc,
-          output_ptr, output_numel, output_type_size, output_index_calc,
+          input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
+          output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
           transform_func);
     });
   }
 }
 
 template <int block_size, typename cipher_t, typename input_index_calc_t, typename output_index_calc_t, typename transform_t>
 void block_cipher(
-    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc,
-    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc,
+    void* input_ptr, int64_t input_numel, int input_type_size, input_index_calc_t input_index_calc, bool input_is_contiguous,
+    void* output_ptr, int64_t output_numel, int output_type_size, output_index_calc_t output_index_calc, bool output_is_contiguous,
     at::Device device, cipher_t cipher, int output_elem_per_block, transform_t transform_func) {
   if (output_ptr == nullptr || output_numel == 0) {
     return;
@@ -136,8 +162,8 @@ void block_cipher(
     const auto total = (output_numel + output_elem_per_block - 1) / output_elem_per_block;
     block_cipher_kernel_cpu<block_size>(total,
         cipher, output_elem_per_block,
-        input_ptr, input_numel, input_type_size, input_index_calc,
-        output_ptr, output_numel, output_type_size, output_index_calc,
+        input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
+        output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
         transform_func
     );
   } else if (device.type() == at::kCUDA) {
@@ -147,8 +173,8 @@ void block_cipher(
     auto stream = at::cuda::getCurrentCUDAStream();
     block_cipher_kernel_cuda<block_size><<<grid, threads, 0, stream>>>(
         cipher, output_elem_per_block,
-        input_ptr, input_numel, input_type_size, input_index_calc,
-        output_ptr, output_numel, output_type_size, output_index_calc,
+        input_ptr, input_numel, input_type_size, input_index_calc, input_is_contiguous,
+        output_ptr, output_numel, output_type_size, output_index_calc, output_is_contiguous,
         transform_func
     );
     AT_CUDA_CHECK(cudaGetLastError());
@@ -193,8 +219,8 @@ void block_cipher(at::Tensor input, at::Tensor output, cipher_t cipher) {
   const auto device = output.device();
 
   torch::csprng::block_cipher<block_size>(
-      input_ptr, input_numel, input_type_size, input_index_calc,
-      output_ptr, output_numel, output_type_size, output_index_calc,
+      input_ptr, input_numel, input_type_size, input_index_calc, input.is_contiguous(),
+      output_ptr, output_numel, output_type_size, output_index_calc, output.is_contiguous(),
       device, cipher, block_size / output_type_size,
       [] TORCH_CSPRNG_HOST_DEVICE (uint8_t* x) {});
 }
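
For context on what the new input_is_contiguous/output_is_contiguous flags buy, here is a minimal standalone sketch, not part of this commit (StridedIndexCalc and element_offset_bytes are hypothetical names): when a tensor is contiguous, the byte offset of element i is simply i * type_size, so the contiguous branch can skip the per-element index calculator, which otherwise has to walk the tensor's strides for every element of every block.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stride-walking calculator, standing in for the
// input_index_calc/output_index_calc lambdas the caller builds for
// non-contiguous tensors.
struct StridedIndexCalc {
  std::vector<int64_t> sizes;          // per-dimension sizes
  std::vector<int64_t> strides_bytes;  // per-dimension strides in bytes
  int64_t operator()(int64_t linear_index) const {
    int64_t offset = 0;
    for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; --d) {
      offset += (linear_index % sizes[d]) * strides_bytes[d];
      linear_index /= sizes[d];
    }
    return offset;
  }
};

// Mirrors the split introduced in the diff above: one multiply on the
// contiguous fast path, a full stride walk otherwise.
int64_t element_offset_bytes(int64_t linear_index, int type_size,
                             bool is_contiguous, const StridedIndexCalc& calc) {
  return is_contiguous ? linear_index * type_size : calc(linear_index);
}

int main() {
  // Contiguous 2x3 float tensor: both paths must agree on every offset.
  const StridedIndexCalc calc{{2, 3}, {3 * 4, 4}};
  for (int64_t i = 0; i < 6; ++i) {
    assert(element_offset_bytes(i, 4, /*is_contiguous=*/true, calc) == calc(i));
  }
  return 0;
}

The commit does not change the index calculators themselves; the flags only let copy_input_to_block/copy_block_to_output fall back to plain linear addressing when the caller already knows (via Tensor::is_contiguous()) that the memory layout is dense.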