@@ -409,5 +409,56 @@ def create_aes(m, k):
409409
410410 self .assertTrue (np .array_equal (initial_np , decrypted_np ))
411411
412+ def test_encrypt_decrypt_inplace (self ):
413+ key_size_bytes = 16
414+
415+ def sizeof (dtype ):
416+ if dtype == torch .bool :
417+ return 1
418+ elif dtype .is_floating_point :
419+ return torch .finfo (dtype ).bits // 8
420+ else :
421+ return torch .iinfo (dtype ).bits // 8
422+
423+ def create_aes (m , k ):
424+ if m == "ecb" :
425+ return AES .new (k .tobytes (), AES .MODE_ECB )
426+ elif m == "ctr" :
427+ ctr = Counter .new (AES .block_size * 8 , initial_value = 0 , little_endian = True )
428+ return AES .new (k .tobytes (), AES .MODE_CTR , counter = ctr )
429+ else :
430+ return None
431+
432+ for key_dtype in self .all_dtypes :
433+ key_size = key_size_bytes // sizeof (key_dtype )
434+ key = torch .empty (key_size , dtype = key_dtype ).random_ ()
435+ key_np = key .numpy ().view (np .int8 )
436+ for initial_dtype in self .all_dtypes :
437+ for initial_size_bytes in [0 , 16 , 256 , 1048576 ]:
438+ initial_size = initial_size_bytes // sizeof (initial_dtype )
439+ initial = torch .empty (initial_size , dtype = initial_dtype ).random_ ()
440+ initial_np = initial .numpy ().view (np .int8 )
441+ initial_np_copy = np .copy (initial_np )
442+ for mode in ["ecb" , "ctr" ]:
443+ for device in self .all_devices :
444+ key = key .to (device )
445+ initial = initial .to (device )
446+
447+ csprng .encrypt (initial , initial , key , "aes128" , mode )
448+ encrypted_np = initial .cpu ().numpy ().view (np .int8 )
449+ aes = create_aes (mode , key_np )
450+ encrypted_expected = np .frombuffer (aes .encrypt (initial_np_copy .tobytes ()), dtype = np .int8 )
451+ self .assertTrue (np .array_equal (encrypted_np , encrypted_expected ))
452+
453+ encrypted_np_copy = np .copy (encrypted_np )
454+
455+ csprng .decrypt (initial , initial , key , "aes128" , mode )
456+ decrypted_np = initial .cpu ().numpy ().view (np .int8 )
457+ aes = create_aes (mode , key_np )
458+ decrypted_expected = np .frombuffer (aes .decrypt (encrypted_np_copy .tobytes ()), dtype = np .int8 )
459+ self .assertTrue (np .array_equal (decrypted_np , decrypted_expected ))
460+
461+ self .assertTrue (np .array_equal (initial_np_copy , decrypted_np ))
462+
412463if __name__ == '__main__' :
413464 unittest .main ()
0 commit comments