From d7ac4fa6831d7c67c2a3704e22724cd2880ed83d Mon Sep 17 00:00:00 2001 From: lehugueni Date: Mon, 27 Jan 2025 09:53:45 +0100 Subject: [PATCH 1/3] refactor(sampling): use of counter and avoid resampling key on every call --- core/rlwe/encryptor.go | 5 ++++- utils/sampling/prng.go | 30 +++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/core/rlwe/encryptor.go b/core/rlwe/encryptor.go index 85986f1da..f5ca3e758 100644 --- a/core/rlwe/encryptor.go +++ b/core/rlwe/encryptor.go @@ -62,7 +62,10 @@ func (enc Encryptor) GetRLWEParameters() *Parameters { func newEncryptor(params Parameters) *Encryptor { - prng := &sampling.ThreadSafePRNG{} + prng, err := sampling.NewThreadSafePRNG() + if err != nil { + panic(fmt.Errorf("newEncryptor: %w", err)) + } var bc *ring.BasisExtender if params.PCount() != 0 { diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index 435077f7c..c46a0ec14 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -2,10 +2,13 @@ package sampling import ( "crypto/rand" + "encoding/binary" "fmt" "io" + "sync/atomic" "golang.org/x/crypto/blake2b" + "golang.org/x/crypto/sha3" ) // PRNG is an interface for secure (keyed) deterministic generation of random bytes @@ -23,11 +26,36 @@ type KeyedPRNG struct { } type ThreadSafePRNG struct { + key []byte + atomicCnt atomic.Uint64 +} + +func NewThreadSafePRNG() (*ThreadSafePRNG, error) { + key := make([]byte, 64) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("crypto rand error: %w", err) + } + return &ThreadSafePRNG{ + atomicCnt: atomic.Uint64{}, + key: key, + }, nil +} + +func uint64ToByte(n uint64) []byte { + arr := make([]byte, 8) + binary.LittleEndian.PutUint64(arr, n) + return arr } // Read reads bytes from the KeyedPRNG on sum. func (prng *ThreadSafePRNG) Read(sum []byte) (n int, err error) { - tmpPRNG, err := NewPRNG() + tmpPRNG := sha3.NewShake256() + _, err = tmpPRNG.Write(prng.key) + if err != nil { + return 0, fmt.Errorf("crypto rand error: %w", err) + } + cnt := prng.atomicCnt.Add(1) + _, err = tmpPRNG.Write(uint64ToByte(cnt)) if err != nil { return 0, fmt.Errorf("crypto rand error: %w", err) } From c9be75820ceeae912ea3c37e3288c9a51fd2cf62 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Mon, 27 Jan 2025 14:14:15 +0100 Subject: [PATCH 2/3] perf(sampling): store state of shake instead of key --- utils/sampling/prng.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index c46a0ec14..2d13a7e1c 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -26,7 +26,7 @@ type KeyedPRNG struct { } type ThreadSafePRNG struct { - key []byte + xof sha3.ShakeHash atomicCnt atomic.Uint64 } @@ -35,9 +35,14 @@ func NewThreadSafePRNG() (*ThreadSafePRNG, error) { if _, err := rand.Read(key); err != nil { return nil, fmt.Errorf("crypto rand error: %w", err) } + tmpPRNG := sha3.NewShake256() + _, err := tmpPRNG.Write(key) + if err != nil { + return nil, fmt.Errorf("crypto rand error: %w", err) + } return &ThreadSafePRNG{ atomicCnt: atomic.Uint64{}, - key: key, + xof: tmpPRNG, }, nil } @@ -49,11 +54,7 @@ func uint64ToByte(n uint64) []byte { // Read reads bytes from the KeyedPRNG on sum. func (prng *ThreadSafePRNG) Read(sum []byte) (n int, err error) { - tmpPRNG := sha3.NewShake256() - _, err = tmpPRNG.Write(prng.key) - if err != nil { - return 0, fmt.Errorf("crypto rand error: %w", err) - } + tmpPRNG := prng.xof.Clone() cnt := prng.atomicCnt.Add(1) _, err = tmpPRNG.Write(uint64ToByte(cnt)) if err != nil { From 0172e9a7512c76aa7b95635762593d97202de326 Mon Sep 17 00:00:00 2001 From: lehugueni Date: Tue, 28 Jan 2025 07:25:03 +0100 Subject: [PATCH 3/3] test: benchmark different prngs --- ring/ring_benchmark_test.go | 103 ++++++++++++++++++++++++++++++++++++ ring/sampler_uniform.go | 6 +-- utils/sampling/prng.go | 20 +++++++ 3 files changed, 126 insertions(+), 3 deletions(-) diff --git a/ring/ring_benchmark_test.go b/ring/ring_benchmark_test.go index 58e3590e7..aa0239325 100644 --- a/ring/ring_benchmark_test.go +++ b/ring/ring_benchmark_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tuneinsight/lattigo/v6/utils/bignum" + "github.com/tuneinsight/lattigo/v6/utils/sampling" ) func BenchmarkRing(b *testing.B) { @@ -88,6 +89,108 @@ func benchSampling(tc *testParams, b *testing.B) { sampler.Read(pol) } }) + b.Run(testString("Sampling/ThreadSafePRNGNaive/Gaussian", tc.ringQ), func(b *testing.B) { + + prng := &sampling.ThreadSafePRNGNaive{} + sampler, err := NewSampler(prng, tc.ringQ, DiscreteGaussian{Sigma: DefaultSigma, Bound: DefaultBound}, false) + require.NoError(b, err) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafeShallowCopy/Gaussian", tc.ringQ), func(b *testing.B) { + + b.RunParallel(func(pb *testing.PB) { + prng, _ := sampling.NewPRNG() + sampler, _ := NewSampler(prng, tc.ringQ, DiscreteGaussian{Sigma: DefaultSigma, Bound: DefaultBound}, false) + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafePRNG/Gaussian", tc.ringQ), func(b *testing.B) { + + prng, err := sampling.NewThreadSafePRNG() + sampler, err := NewSampler(prng, tc.ringQ, DiscreteGaussian{Sigma: DefaultSigma, Bound: DefaultBound}, false) + require.NoError(b, err) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafePRNGNaive/Uniform", tc.ringQ), func(b *testing.B) { + + prng := &sampling.ThreadSafePRNGNaive{} + sampler, err := NewSampler(prng, tc.ringQ, Uniform{}, true) + require.NoError(b, err) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafeShallowCopy/Uniform", tc.ringQ), func(b *testing.B) { + + b.RunParallel(func(pb *testing.PB) { + prng, _ := sampling.NewPRNG() + sampler, _ := NewSampler(prng, tc.ringQ, Uniform{}, true) + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafePRNG/Uniform", tc.ringQ), func(b *testing.B) { + + prng, err := sampling.NewThreadSafePRNG() + sampler, err := NewSampler(prng, tc.ringQ, Uniform{}, true) + require.NoError(b, err) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafePRNGNaive/Ternary/0.3", tc.ringQ), func(b *testing.B) { + + prng := &sampling.ThreadSafePRNGNaive{} + sampler, err := NewSampler(prng, tc.ringQ, Ternary{P: 1.0 / 3}, true) + require.NoError(b, err) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafeShallowCopy/Ternary/0.3", tc.ringQ), func(b *testing.B) { + + b.RunParallel(func(pb *testing.PB) { + prng, _ := sampling.NewPRNG() + sampler, _ := NewSampler(prng, tc.ringQ, Ternary{P: 1.0 / 3}, true) + for pb.Next() { + sampler.Read(pol) + } + }) + }) + b.Run(testString("Sampling/ThreadSafePRNG/Ternary/0.3", tc.ringQ), func(b *testing.B) { + + prng, err := sampling.NewThreadSafePRNG() + sampler, err := NewSampler(prng, tc.ringQ, Ternary{P: 1.0 / 3}, true) + require.NoError(b, err) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + sampler.Read(pol) + } + }) + }) b.Run(testString("Sampling/Ternary/0.3", tc.ringQ), func(b *testing.B) { diff --git a/ring/sampler_uniform.go b/ring/sampler_uniform.go index b4df1523b..55ba89bf8 100644 --- a/ring/sampler_uniform.go +++ b/ring/sampler_uniform.go @@ -45,14 +45,14 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) { level := u.baseRing.Level() var randomUint, mask, qi uint64 - var buffer [1024]byte + buffer := make([]byte, 1024) prng := u.prng N := u.baseRing.N() byteArrayLength := len(buffer) var ptr int - if _, err := prng.Read(buffer[:]); err != nil { + if _, err := prng.Read(buffer); err != nil { // Sanity check, this error should not happen. panic(err) } @@ -74,7 +74,7 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) { // Refills the buff if it runs empty if ptr == byteArrayLength { - if _, err := u.prng.Read(buffer[:]); err != nil { + if _, err := u.prng.Read(buffer); err != nil { // Sanity check, this error should not happen. panic(err) } diff --git a/utils/sampling/prng.go b/utils/sampling/prng.go index 2d13a7e1c..b975b1714 100644 --- a/utils/sampling/prng.go +++ b/utils/sampling/prng.go @@ -25,6 +25,26 @@ type KeyedPRNG struct { xof blake2b.XOF } +type ThreadSafePRNGNaive struct { +} + +func NewThreadSafePRNGNaive() (*ThreadSafePRNGNaive, error) { + return &ThreadSafePRNGNaive{}, nil +} + +func (prng *ThreadSafePRNGNaive) Read(sum []byte) (n int, err error) { + key := make([]byte, 64) + if _, err := rand.Read(key); err != nil { + return 0, fmt.Errorf("crypto rand error: %w", err) + } + tmpPRNG := sha3.NewShake256() + _, err = tmpPRNG.Write(key) + if err != nil { + return 0, fmt.Errorf("crypto rand error: %w", err) + } + return tmpPRNG.Read(sum) +} + type ThreadSafePRNG struct { xof sha3.ShakeHash atomicCnt atomic.Uint64