Skip to content

Commit 4e708d9

Browse files
HuanyuZhangfacebook-github-bot
authored andcommitted
Reduce module size or the number of steps to avoid over-time tests (#739)
Summary: As titled. Facebook The buck test has the limit of 10 mins. To avoid overtime failure, we slightly reduce the number of parameters or the number of repetitions. This won't reduce the credibility of the tests. Differential Revision: D70707205
1 parent 7264cd7 commit 4e708d9

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def setUp_data_sequantial(self, size, length, dim):
120120

121121
@given(
122122
size=st.sampled_from([10]),
123-
length=st.sampled_from([1, 10]),
123+
length=st.sampled_from([5]),
124124
dim=st.sampled_from([2]),
125125
)
126126
@settings(deadline=1000000)
@@ -192,12 +192,12 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
192192
diff = flat_norms_normal - flat_norms_gc
193193

194194
logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
195-
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
195+
msg = "Fail: Per-sample gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
196196
assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
197197

198198
@given(
199199
size=st.sampled_from([10]),
200-
length=st.sampled_from([1, 10]),
200+
length=st.sampled_from([5]),
201201
dim=st.sampled_from([2]),
202202
)
203203
@settings(deadline=1000000)

opacus/tests/privacy_engine_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def _compare_to_vanilla(
268268
do_clip=st.booleans(),
269269
do_noise=st.booleans(),
270270
use_closure=st.booleans(),
271-
max_steps=st.sampled_from([1, 4]),
271+
max_steps=st.sampled_from([1, 3]),
272272
)
273273
@settings(suppress_health_check=list(HealthCheck), deadline=None)
274274
def test_compare_to_vanilla(
@@ -660,7 +660,7 @@ def test_checkpoints(
660660

661661
@given(
662662
noise_multiplier=st.floats(0.5, 5.0),
663-
max_steps=st.integers(8, 10),
663+
max_steps=st.integers(3, 5),
664664
secure_mode=st.just(False), # TODO: enable after fixing torchcsprng build
665665
)
666666
@settings(suppress_health_check=list(HealthCheck), deadline=None)

0 commit comments

Comments
 (0)