Skip to content

Commit 0437e0b

Browse files
committed
Test encrypt/decrypt inplace
ghstack-source-id: 8a53808 Pull Request resolved: #101
1 parent 96ec4fb commit 0437e0b

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/test_csprng.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,5 +409,56 @@ def create_aes(m, k):
409409

410410
self.assertTrue(np.array_equal(initial_np, decrypted_np))
411411

412+
def test_encrypt_decrypt_inplace(self):
413+
key_size_bytes = 16
414+
415+
def sizeof(dtype):
416+
if dtype == torch.bool:
417+
return 1
418+
elif dtype.is_floating_point:
419+
return torch.finfo(dtype).bits // 8
420+
else:
421+
return torch.iinfo(dtype).bits // 8
422+
423+
def create_aes(m, k):
424+
if m == "ecb":
425+
return AES.new(k.tobytes(), AES.MODE_ECB)
426+
elif m == "ctr":
427+
ctr = Counter.new(AES.block_size * 8, initial_value=0, little_endian=True)
428+
return AES.new(k.tobytes(), AES.MODE_CTR, counter=ctr)
429+
else:
430+
return None
431+
432+
for key_dtype in self.all_dtypes:
433+
key_size = key_size_bytes // sizeof(key_dtype)
434+
key = torch.empty(key_size, dtype=key_dtype).random_()
435+
key_np = key.numpy().view(np.int8)
436+
for initial_dtype in self.all_dtypes:
437+
for initial_size_bytes in [0, 16, 256, 1048576]:
438+
initial_size = initial_size_bytes // sizeof(initial_dtype)
439+
initial = torch.empty(initial_size, dtype=initial_dtype).random_()
440+
initial_np = initial.numpy().view(np.int8)
441+
initial_np_copy = np.copy(initial_np)
442+
for mode in ["ecb", "ctr"]:
443+
for device in self.all_devices:
444+
key = key.to(device)
445+
initial = initial.to(device)
446+
447+
csprng.encrypt(initial, initial, key, "aes128", mode)
448+
encrypted_np = initial.cpu().numpy().view(np.int8)
449+
aes = create_aes(mode, key_np)
450+
encrypted_expected = np.frombuffer(aes.encrypt(initial_np_copy.tobytes()), dtype=np.int8)
451+
self.assertTrue(np.array_equal(encrypted_np, encrypted_expected))
452+
453+
encrypted_np_copy = np.copy(encrypted_np)
454+
455+
csprng.decrypt(initial, initial, key, "aes128", mode)
456+
decrypted_np = initial.cpu().numpy().view(np.int8)
457+
aes = create_aes(mode, key_np)
458+
decrypted_expected = np.frombuffer(aes.decrypt(encrypted_np_copy.tobytes()), dtype=np.int8)
459+
self.assertTrue(np.array_equal(decrypted_np, decrypted_expected))
460+
461+
self.assertTrue(np.array_equal(initial_np_copy, decrypted_np))
462+
412463
if __name__ == '__main__':
413464
unittest.main()

0 commit comments

Comments
 (0)