1919
2020using namespace at ;
2121using namespace at ::native::templates;
22- using namespace torch ::custom_prng ;
22+ using namespace torch ::csprng ;
2323
2424inline uint64_t make64BitsFrom32Bits (uint32_t hi, uint32_t lo) {
2525 return (static_cast <uint64_t >(hi) << 32 ) | lo;
2626}
2727
2828// CUDA CSPRNG is actually CPU generator which is used only to generate a random key on CPU for AES running in a block mode on CUDA
29- struct CustomGeneratorImpl : public c10 ::GeneratorImpl {
30- CustomGeneratorImpl (bool use_rd) : c10::GeneratorImpl{Device (DeviceType::CPU), DispatchKeySet (DispatchKey::CustomRNGKeyId)}, use_rd_{use_rd} {}
31- CustomGeneratorImpl (const std::string& token) : c10::GeneratorImpl{Device (DeviceType::CPU), DispatchKeySet (DispatchKey::CustomRNGKeyId)}, use_rd_{true }, rd_{token} {}
32- CustomGeneratorImpl (uint64_t seed) : c10::GeneratorImpl{Device (DeviceType::CPU), DispatchKeySet (DispatchKey::CustomRNGKeyId)}, use_rd_{false }, mt_{static_cast <unsigned int >(seed)} { }
33- ~CustomGeneratorImpl () = default ;
29+ struct CSPRNGGeneratorImpl : public c10 ::GeneratorImpl {
30+ CSPRNGGeneratorImpl (bool use_rd) : c10::GeneratorImpl{Device (DeviceType::CPU), DispatchKeySet (DispatchKey::CustomRNGKeyId)}, use_rd_{use_rd} {}
31+ CSPRNGGeneratorImpl (const std::string& token) : c10::GeneratorImpl{Device (DeviceType::CPU), DispatchKeySet (DispatchKey::CustomRNGKeyId)}, use_rd_{true }, rd_{token} {}
32+ CSPRNGGeneratorImpl (uint64_t seed) : c10::GeneratorImpl{Device (DeviceType::CPU), DispatchKeySet (DispatchKey::CustomRNGKeyId)}, use_rd_{false }, mt_{static_cast <unsigned int >(seed)} { }
33+ ~CSPRNGGeneratorImpl () = default ;
3434 uint32_t random () { return use_rd_ ? rd_ () : mt_ (); }
3535 uint64_t random64 () { return use_rd_ ? make64BitsFrom32Bits (rd_ (), rd_ ()) : make64BitsFrom32Bits (mt_ (), mt_ ()); }
3636
3737 void set_current_seed (uint64_t seed) override { throw std::runtime_error (" not implemented" ); }
3838 uint64_t current_seed () const override { throw std::runtime_error (" not implemented" ); }
3939 uint64_t seed () override { throw std::runtime_error (" not implemented" ); }
40- CustomGeneratorImpl * clone_impl () const override { throw std::runtime_error (" not implemented" ); }
40+ CSPRNGGeneratorImpl * clone_impl () const override { throw std::runtime_error (" not implemented" ); }
4141
4242 static DeviceType device_type () { return DeviceType::CPU; }
4343
@@ -161,11 +161,11 @@ struct RandomFromToKernel {
161161};
162162
163163Tensor& random_ (Tensor& self, c10::optional<Generator> generator) {
164- return random_impl<RandomKernel, CustomGeneratorImpl >(self, generator);
164+ return random_impl<RandomKernel, CSPRNGGeneratorImpl >(self, generator);
165165}
166166
167167Tensor& random_from_to (Tensor& self, int64_t from, optional<int64_t > to, c10::optional<Generator> generator) {
168- return random_from_to_impl<RandomFromToKernel, CustomGeneratorImpl >(self, from, to, generator);
168+ return random_from_to_impl<RandomFromToKernel, CSPRNGGeneratorImpl >(self, from, to, generator);
169169}
170170
171171Tensor& random_to (Tensor& self, int64_t to, c10::optional<Generator> generator) {
@@ -191,7 +191,7 @@ struct UniformKernel {
191191};
192192
193193Tensor& uniform_ (Tensor& self, double from, double to, c10::optional<Generator> generator) {
194- return uniform_impl_<UniformKernel, CustomGeneratorImpl >(self, from, to, generator);
194+ return uniform_impl_<UniformKernel, CSPRNGGeneratorImpl >(self, from, to, generator);
195195}
196196
197197// ==================================================== Normal ========================================================
@@ -214,31 +214,31 @@ struct NormalKernel {
214214};
215215
216216Tensor& normal_ (Tensor& self, double mean, double std, c10::optional<Generator> generator) {
217- return normal_impl_<NormalKernel, CustomGeneratorImpl >(self, mean, std, generator);
217+ return normal_impl_<NormalKernel, CSPRNGGeneratorImpl >(self, mean, std, generator);
218218}
219219
220220Tensor& normal_Tensor_float_out (Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
221- return normal_out_impl<NormalKernel, CustomGeneratorImpl >(output, mean, std, gen);
221+ return normal_out_impl<NormalKernel, CSPRNGGeneratorImpl >(output, mean, std, gen);
222222}
223223
224224Tensor& normal_float_Tensor_out (Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
225- return normal_out_impl<NormalKernel, CustomGeneratorImpl >(output, mean, std, gen);
225+ return normal_out_impl<NormalKernel, CSPRNGGeneratorImpl >(output, mean, std, gen);
226226}
227227
228228Tensor& normal_Tensor_Tensor_out (Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
229- return normal_out_impl<NormalKernel, CustomGeneratorImpl >(output, mean, std, gen);
229+ return normal_out_impl<NormalKernel, CSPRNGGeneratorImpl >(output, mean, std, gen);
230230}
231231
232232Tensor normal_Tensor_float (const Tensor& mean, double std, c10::optional<Generator> gen) {
233- return normal_impl<NormalKernel, CustomGeneratorImpl >(mean, std, gen);
233+ return normal_impl<NormalKernel, CSPRNGGeneratorImpl >(mean, std, gen);
234234}
235235
236236Tensor normal_float_Tensor (double mean, const Tensor& std, c10::optional<Generator> gen) {
237- return normal_impl<NormalKernel, CustomGeneratorImpl >(mean, std, gen);
237+ return normal_impl<NormalKernel, CSPRNGGeneratorImpl >(mean, std, gen);
238238}
239239
240240Tensor normal_Tensor_Tensor (const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
241- return normal_impl<NormalKernel, CustomGeneratorImpl >(mean, std, gen);
241+ return normal_impl<NormalKernel, CSPRNGGeneratorImpl >(mean, std, gen);
242242}
243243
244244// ==================================================== Cauchy ========================================================
@@ -260,7 +260,7 @@ struct CauchyKernel {
260260};
261261
262262Tensor& cauchy_ (Tensor& self, double median, double sigma, c10::optional<Generator> generator) {
263- return cauchy_impl_<CauchyKernel, CustomGeneratorImpl >(self, median, sigma, generator);
263+ return cauchy_impl_<CauchyKernel, CSPRNGGeneratorImpl >(self, median, sigma, generator);
264264}
265265
266266// ================================================== LogNormal =======================================================
@@ -282,7 +282,7 @@ struct LogNormalKernel {
282282};
283283
284284Tensor& log_normal_ (Tensor& self, double mean, double std, c10::optional<Generator> gen) {
285- return log_normal_impl_<LogNormalKernel, CustomGeneratorImpl >(self, mean, std, gen);
285+ return log_normal_impl_<LogNormalKernel, CSPRNGGeneratorImpl >(self, mean, std, gen);
286286}
287287
288288// ================================================== Geometric =======================================================
@@ -304,7 +304,7 @@ struct GeometricKernel {
304304};
305305
306306Tensor& geometric_ (Tensor& self, double p, c10::optional<Generator> gen) {
307- return geometric_impl_<GeometricKernel, CustomGeneratorImpl >(self, p, gen);
307+ return geometric_impl_<GeometricKernel, CSPRNGGeneratorImpl >(self, p, gen);
308308}
309309
310310// ================================================== Exponential =====================================================
@@ -326,24 +326,24 @@ struct ExponentialKernel {
326326};
327327
328328Tensor& exponential_ (Tensor& self, double lambda, c10::optional<Generator> gen) {
329- return exponential_impl_<ExponentialKernel, CustomGeneratorImpl >(self, lambda, gen);
329+ return exponential_impl_<ExponentialKernel, CSPRNGGeneratorImpl >(self, lambda, gen);
330330}
331331
332332// ====================================================================================================================
333333
334334Generator create_random_device_generator (c10::optional<std::string> token = c10::nullopt ) {
335335 if (token.has_value ()) {
336- return make_generator<CustomGeneratorImpl >(*token);
336+ return make_generator<CSPRNGGeneratorImpl >(*token);
337337 } else {
338- return make_generator<CustomGeneratorImpl >(true );
338+ return make_generator<CSPRNGGeneratorImpl >(true );
339339 }
340340}
341341
342342Generator create_mt19937_generator (c10::optional<uint64_t > seed = c10::nullopt ) {
343343 if (seed.has_value ()) {
344- return make_generator<CustomGeneratorImpl >(*seed);
344+ return make_generator<CSPRNGGeneratorImpl >(*seed);
345345 } else {
346- return make_generator<CustomGeneratorImpl >(false );
346+ return make_generator<CSPRNGGeneratorImpl >(false );
347347 }
348348}
349349
0 commit comments