Skip to content

Commit 0172e9a

Browse files
committed
test: benchmark different prngs
1 parent c9be758 commit 0172e9a

File tree

3 files changed

+126
-3
lines changed

3 files changed

+126
-3
lines changed

ring/ring_benchmark_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/stretchr/testify/require"
88
"github.com/tuneinsight/lattigo/v6/utils/bignum"
9+
"github.com/tuneinsight/lattigo/v6/utils/sampling"
910
)
1011

1112
func BenchmarkRing(b *testing.B) {
@@ -88,6 +89,108 @@ func benchSampling(tc *testParams, b *testing.B) {
8889
sampler.Read(pol)
8990
}
9091
})
92+
b.Run(testString("Sampling/ThreadSafePRNGNaive/Gaussian", tc.ringQ), func(b *testing.B) {
93+
94+
prng := &sampling.ThreadSafePRNGNaive{}
95+
sampler, err := NewSampler(prng, tc.ringQ, DiscreteGaussian{Sigma: DefaultSigma, Bound: DefaultBound}, false)
96+
require.NoError(b, err)
97+
98+
b.RunParallel(func(pb *testing.PB) {
99+
for pb.Next() {
100+
sampler.Read(pol)
101+
}
102+
})
103+
})
104+
b.Run(testString("Sampling/ThreadSafeShallowCopy/Gaussian", tc.ringQ), func(b *testing.B) {
105+
106+
b.RunParallel(func(pb *testing.PB) {
107+
prng, _ := sampling.NewPRNG()
108+
sampler, _ := NewSampler(prng, tc.ringQ, DiscreteGaussian{Sigma: DefaultSigma, Bound: DefaultBound}, false)
109+
for pb.Next() {
110+
sampler.Read(pol)
111+
}
112+
})
113+
})
114+
b.Run(testString("Sampling/ThreadSafePRNG/Gaussian", tc.ringQ), func(b *testing.B) {
115+
116+
prng, err := sampling.NewThreadSafePRNG()
117+
sampler, err := NewSampler(prng, tc.ringQ, DiscreteGaussian{Sigma: DefaultSigma, Bound: DefaultBound}, false)
118+
require.NoError(b, err)
119+
120+
b.RunParallel(func(pb *testing.PB) {
121+
for pb.Next() {
122+
sampler.Read(pol)
123+
}
124+
})
125+
})
126+
b.Run(testString("Sampling/ThreadSafePRNGNaive/Uniform", tc.ringQ), func(b *testing.B) {
127+
128+
prng := &sampling.ThreadSafePRNGNaive{}
129+
sampler, err := NewSampler(prng, tc.ringQ, Uniform{}, true)
130+
require.NoError(b, err)
131+
132+
b.RunParallel(func(pb *testing.PB) {
133+
for pb.Next() {
134+
sampler.Read(pol)
135+
}
136+
})
137+
})
138+
b.Run(testString("Sampling/ThreadSafeShallowCopy/Uniform", tc.ringQ), func(b *testing.B) {
139+
140+
b.RunParallel(func(pb *testing.PB) {
141+
prng, _ := sampling.NewPRNG()
142+
sampler, _ := NewSampler(prng, tc.ringQ, Uniform{}, true)
143+
for pb.Next() {
144+
sampler.Read(pol)
145+
}
146+
})
147+
})
148+
b.Run(testString("Sampling/ThreadSafePRNG/Uniform", tc.ringQ), func(b *testing.B) {
149+
150+
prng, err := sampling.NewThreadSafePRNG()
151+
sampler, err := NewSampler(prng, tc.ringQ, Uniform{}, true)
152+
require.NoError(b, err)
153+
154+
b.RunParallel(func(pb *testing.PB) {
155+
for pb.Next() {
156+
sampler.Read(pol)
157+
}
158+
})
159+
})
160+
b.Run(testString("Sampling/ThreadSafePRNGNaive/Ternary/0.3", tc.ringQ), func(b *testing.B) {
161+
162+
prng := &sampling.ThreadSafePRNGNaive{}
163+
sampler, err := NewSampler(prng, tc.ringQ, Ternary{P: 1.0 / 3}, true)
164+
require.NoError(b, err)
165+
166+
b.RunParallel(func(pb *testing.PB) {
167+
for pb.Next() {
168+
sampler.Read(pol)
169+
}
170+
})
171+
})
172+
b.Run(testString("Sampling/ThreadSafeShallowCopy/Ternary/0.3", tc.ringQ), func(b *testing.B) {
173+
174+
b.RunParallel(func(pb *testing.PB) {
175+
prng, _ := sampling.NewPRNG()
176+
sampler, _ := NewSampler(prng, tc.ringQ, Ternary{P: 1.0 / 3}, true)
177+
for pb.Next() {
178+
sampler.Read(pol)
179+
}
180+
})
181+
})
182+
b.Run(testString("Sampling/ThreadSafePRNG/Ternary/0.3", tc.ringQ), func(b *testing.B) {
183+
184+
prng, err := sampling.NewThreadSafePRNG()
185+
sampler, err := NewSampler(prng, tc.ringQ, Ternary{P: 1.0 / 3}, true)
186+
require.NoError(b, err)
187+
188+
b.RunParallel(func(pb *testing.PB) {
189+
for pb.Next() {
190+
sampler.Read(pol)
191+
}
192+
})
193+
})
91194

92195
b.Run(testString("Sampling/Ternary/0.3", tc.ringQ), func(b *testing.B) {
93196

ring/sampler_uniform.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
4545
level := u.baseRing.Level()
4646

4747
var randomUint, mask, qi uint64
48-
var buffer [1024]byte
48+
buffer := make([]byte, 1024)
4949

5050
prng := u.prng
5151
N := u.baseRing.N()
5252
byteArrayLength := len(buffer)
5353

5454
var ptr int
55-
if _, err := prng.Read(buffer[:]); err != nil {
55+
if _, err := prng.Read(buffer); err != nil {
5656
// Sanity check, this error should not happen.
5757
panic(err)
5858
}
@@ -74,7 +74,7 @@ func (u *UniformSampler) read(pol Poly, f func(a, b, c uint64) uint64) {
7474

7575
// Refills the buff if it runs empty
7676
if ptr == byteArrayLength {
77-
if _, err := u.prng.Read(buffer[:]); err != nil {
77+
if _, err := u.prng.Read(buffer); err != nil {
7878
// Sanity check, this error should not happen.
7979
panic(err)
8080
}

utils/sampling/prng.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,26 @@ type KeyedPRNG struct {
2525
xof blake2b.XOF
2626
}
2727

28+
type ThreadSafePRNGNaive struct {
29+
}
30+
31+
func NewThreadSafePRNGNaive() (*ThreadSafePRNGNaive, error) {
32+
return &ThreadSafePRNGNaive{}, nil
33+
}
34+
35+
func (prng *ThreadSafePRNGNaive) Read(sum []byte) (n int, err error) {
36+
key := make([]byte, 64)
37+
if _, err := rand.Read(key); err != nil {
38+
return 0, fmt.Errorf("crypto rand error: %w", err)
39+
}
40+
tmpPRNG := sha3.NewShake256()
41+
_, err = tmpPRNG.Write(key)
42+
if err != nil {
43+
return 0, fmt.Errorf("crypto rand error: %w", err)
44+
}
45+
return tmpPRNG.Read(sum)
46+
}
47+
2848
type ThreadSafePRNG struct {
2949
xof sha3.ShakeHash
3050
atomicCnt atomic.Uint64

0 commit comments

Comments
 (0)