Skip to content

Commit 1b295d9

Browse files
Shobha Venkataramanfacebook-github-bot
authored andcommitted
Small modifications to crypten.load (facebookresearch#203)
Summary: Pull Request resolved: fairinternal/CrypTen#203 - Modified crypten.load to allow the input `f` to be None. - Added assert to ensure exactly one of `f` or `preloaded` is None - Updated unit tests to ensure preloaded loads correctly. Reviewed By: knottb Differential Revision: D21212000 fbshipit-source-id: 82042e208a3a95328ec5472dba9ce964d8033075
1 parent 1aa8ced commit 1b295d9

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

crypten/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _setup_przs():
202202

203203

204204
def load_from_party(
205-
f,
205+
f=None,
206206
preloaded=None,
207207
encrypted=False,
208208
dummy_model=None,
@@ -245,9 +245,16 @@ def load_from_party(
245245
src >= 0 and src < comm.get().get_world_size()
246246
), "Load failed: src must be in [0, world_size)"
247247

248+
assert (f is None and (preloaded is not None)) or (
249+
(f is not None) and preloaded is None
250+
), "Exactly one of f and preloaded must not be None"
251+
248252
# source party
249253
if comm.get().get_rank() == src:
250-
result = preloaded if preloaded else load_closure(f, **kwargs)
254+
if f is None:
255+
result = preloaded
256+
if preloaded is None:
257+
result = load_closure(f, **kwargs)
251258

252259
# Zero out the tensors / modules to hide loaded data from broadcast
253260
if torch.is_tensor(result):

test/test_crypten.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,16 @@ def custom_save_function(obj, f):
192192
complete_file, src=src, load_closure=(lambda f: None)
193193
)
194194

195+
# test pre-loaded
196+
encrypted_preloaded = crypten.load_from_party(
197+
src=src, preloaded=tensor
198+
)
199+
self._check(
200+
encrypted_preloaded,
201+
reference,
202+
"crypten.load() failed using preloaded",
203+
)
204+
195205
def test_save_load_module(self):
196206
"""Test that crypten.save and crypten.load properly save and load modules"""
197207
import tempfile

0 commit comments

Comments
 (0)