Skip to content

Commit ff9c39f

Browse files
committed
Rename custom_prng to csprng and CustomGeneratorImpl to CSPRNGGeneratorImpl
ghstack-source-id: 862ac1e Pull Request resolved: #52
1 parent 9b32198 commit ff9c39f

File tree

3 files changed

+27
-27
lines changed

3 files changed

+27
-27
lines changed

torchcsprng/csrc/aes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cstdint>
55

66
namespace torch {
7-
namespace custom_prng {
7+
namespace csprng {
88
namespace aes {
99

1010
// This AES implementation is based on

torchcsprng/csrc/block_cipher.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#endif
2121

2222
namespace torch {
23-
namespace custom_prng {
23+
namespace csprng {
2424

2525
// Generates `block_t_size`-bytes random key Tensor on CPU
2626
// using `generator`, which must be an instance of `at::CPUGeneratorImpl`

torchcsprng/csrc/csprng.h

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,25 @@
1919

2020
using namespace at;
2121
using namespace at::native::templates;
22-
using namespace torch::custom_prng;
22+
using namespace torch::csprng;
2323

2424
inline uint64_t make64BitsFrom32Bits(uint32_t hi, uint32_t lo) {
2525
return (static_cast<uint64_t>(hi) << 32) | lo;
2626
}
2727

2828
// CUDA CSPRNG is actually CPU generator which is used only to generate a random key on CPU for AES running in a block mode on CUDA
29-
struct CustomGeneratorImpl : public c10::GeneratorImpl {
30-
CustomGeneratorImpl(bool use_rd) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, use_rd_{use_rd} {}
31-
CustomGeneratorImpl(const std::string& token) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, use_rd_{true}, rd_{token} {}
32-
CustomGeneratorImpl(uint64_t seed) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, use_rd_{false}, mt_{static_cast<unsigned int>(seed)} { }
33-
~CustomGeneratorImpl() = default;
29+
struct CSPRNGGeneratorImpl : public c10::GeneratorImpl {
30+
CSPRNGGeneratorImpl(bool use_rd) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, use_rd_{use_rd} {}
31+
CSPRNGGeneratorImpl(const std::string& token) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, use_rd_{true}, rd_{token} {}
32+
CSPRNGGeneratorImpl(uint64_t seed) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, use_rd_{false}, mt_{static_cast<unsigned int>(seed)} { }
33+
~CSPRNGGeneratorImpl() = default;
3434
uint32_t random() { return use_rd_ ? rd_() : mt_(); }
3535
uint64_t random64() { return use_rd_ ? make64BitsFrom32Bits(rd_(), rd_()) : make64BitsFrom32Bits(mt_(), mt_()); }
3636

3737
void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
3838
uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
3939
uint64_t seed() override { throw std::runtime_error("not implemented"); }
40-
CustomGeneratorImpl* clone_impl() const override { throw std::runtime_error("not implemented"); }
40+
CSPRNGGeneratorImpl* clone_impl() const override { throw std::runtime_error("not implemented"); }
4141

4242
static DeviceType device_type() { return DeviceType::CPU; }
4343

@@ -161,11 +161,11 @@ struct RandomFromToKernel {
161161
};
162162

163163
Tensor& random_(Tensor& self, c10::optional<Generator> generator) {
164-
return random_impl<RandomKernel, CustomGeneratorImpl>(self, generator);
164+
return random_impl<RandomKernel, CSPRNGGeneratorImpl>(self, generator);
165165
}
166166

167167
Tensor& random_from_to(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> generator) {
168-
return random_from_to_impl<RandomFromToKernel, CustomGeneratorImpl>(self, from, to, generator);
168+
return random_from_to_impl<RandomFromToKernel, CSPRNGGeneratorImpl>(self, from, to, generator);
169169
}
170170

171171
Tensor& random_to(Tensor& self, int64_t to, c10::optional<Generator> generator) {
@@ -191,7 +191,7 @@ struct UniformKernel {
191191
};
192192

193193
Tensor& uniform_(Tensor& self, double from, double to, c10::optional<Generator> generator) {
194-
return uniform_impl_<UniformKernel, CustomGeneratorImpl>(self, from, to, generator);
194+
return uniform_impl_<UniformKernel, CSPRNGGeneratorImpl>(self, from, to, generator);
195195
}
196196

197197
// ==================================================== Normal ========================================================
@@ -214,31 +214,31 @@ struct NormalKernel {
214214
};
215215

216216
Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator> generator) {
217-
return normal_impl_<NormalKernel, CustomGeneratorImpl>(self, mean, std, generator);
217+
return normal_impl_<NormalKernel, CSPRNGGeneratorImpl>(self, mean, std, generator);
218218
}
219219

220220
Tensor& normal_Tensor_float_out(Tensor& output, const Tensor& mean, double std, c10::optional<Generator> gen) {
221-
return normal_out_impl<NormalKernel, CustomGeneratorImpl>(output, mean, std, gen);
221+
return normal_out_impl<NormalKernel, CSPRNGGeneratorImpl>(output, mean, std, gen);
222222
}
223223

224224
Tensor& normal_float_Tensor_out(Tensor& output, double mean, const Tensor& std, c10::optional<Generator> gen) {
225-
return normal_out_impl<NormalKernel, CustomGeneratorImpl>(output, mean, std, gen);
225+
return normal_out_impl<NormalKernel, CSPRNGGeneratorImpl>(output, mean, std, gen);
226226
}
227227

228228
Tensor& normal_Tensor_Tensor_out(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
229-
return normal_out_impl<NormalKernel, CustomGeneratorImpl>(output, mean, std, gen);
229+
return normal_out_impl<NormalKernel, CSPRNGGeneratorImpl>(output, mean, std, gen);
230230
}
231231

232232
Tensor normal_Tensor_float(const Tensor& mean, double std, c10::optional<Generator> gen) {
233-
return normal_impl<NormalKernel, CustomGeneratorImpl>(mean, std, gen);
233+
return normal_impl<NormalKernel, CSPRNGGeneratorImpl>(mean, std, gen);
234234
}
235235

236236
Tensor normal_float_Tensor(double mean, const Tensor& std, c10::optional<Generator> gen) {
237-
return normal_impl<NormalKernel, CustomGeneratorImpl>(mean, std, gen);
237+
return normal_impl<NormalKernel, CSPRNGGeneratorImpl>(mean, std, gen);
238238
}
239239

240240
Tensor normal_Tensor_Tensor(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
241-
return normal_impl<NormalKernel, CustomGeneratorImpl>(mean, std, gen);
241+
return normal_impl<NormalKernel, CSPRNGGeneratorImpl>(mean, std, gen);
242242
}
243243

244244
// ==================================================== Cauchy ========================================================
@@ -260,7 +260,7 @@ struct CauchyKernel {
260260
};
261261

262262
Tensor& cauchy_(Tensor& self, double median, double sigma, c10::optional<Generator> generator) {
263-
return cauchy_impl_<CauchyKernel, CustomGeneratorImpl>(self, median, sigma, generator);
263+
return cauchy_impl_<CauchyKernel, CSPRNGGeneratorImpl>(self, median, sigma, generator);
264264
}
265265

266266
// ================================================== LogNormal =======================================================
@@ -282,7 +282,7 @@ struct LogNormalKernel {
282282
};
283283

284284
Tensor& log_normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
285-
return log_normal_impl_<LogNormalKernel, CustomGeneratorImpl>(self, mean, std, gen);
285+
return log_normal_impl_<LogNormalKernel, CSPRNGGeneratorImpl>(self, mean, std, gen);
286286
}
287287

288288
// ================================================== Geometric =======================================================
@@ -304,7 +304,7 @@ struct GeometricKernel {
304304
};
305305

306306
Tensor& geometric_(Tensor& self, double p, c10::optional<Generator> gen) {
307-
return geometric_impl_<GeometricKernel, CustomGeneratorImpl>(self, p, gen);
307+
return geometric_impl_<GeometricKernel, CSPRNGGeneratorImpl>(self, p, gen);
308308
}
309309

310310
// ================================================== Exponential =====================================================
@@ -326,24 +326,24 @@ struct ExponentialKernel {
326326
};
327327

328328
Tensor& exponential_(Tensor& self, double lambda, c10::optional<Generator> gen) {
329-
return exponential_impl_<ExponentialKernel, CustomGeneratorImpl>(self, lambda, gen);
329+
return exponential_impl_<ExponentialKernel, CSPRNGGeneratorImpl>(self, lambda, gen);
330330
}
331331

332332
// ====================================================================================================================
333333

334334
Generator create_random_device_generator(c10::optional<std::string> token = c10::nullopt) {
335335
if (token.has_value()) {
336-
return make_generator<CustomGeneratorImpl>(*token);
336+
return make_generator<CSPRNGGeneratorImpl>(*token);
337337
} else {
338-
return make_generator<CustomGeneratorImpl>(true);
338+
return make_generator<CSPRNGGeneratorImpl>(true);
339339
}
340340
}
341341

342342
Generator create_mt19937_generator(c10::optional<uint64_t> seed = c10::nullopt) {
343343
if (seed.has_value()) {
344-
return make_generator<CustomGeneratorImpl>(*seed);
344+
return make_generator<CSPRNGGeneratorImpl>(*seed);
345345
} else {
346-
return make_generator<CustomGeneratorImpl>(false);
346+
return make_generator<CSPRNGGeneratorImpl>(false);
347347
}
348348
}
349349

0 commit comments

Comments
 (0)