This repository was archived by the owner on May 13, 2025. It is now read-only.
Fixes issue 403 for select functions: add, sub, mul #413
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Types of changes
Motivation and Context / Related issue
At present, it is not possible to do things like
torch_tensor.add(cryptensor)ortorch_tensor + cryptensor. The problem is that functions like__radd__never get called becausetorch.Tensor.addfails with aTypeErrorrather than aNotImplementedError(which would trigger the reverse function to get called). This limitation leads to issues such as #403This PR fixes this issue for the
add,sub, andmulfunctions. The general approach is as follows:torch.Tensor.{add,sub,mul}in the__torch_function__handler via an@implementsdecorator.__init_subclass__function inCrypTensorthat ensures these decorators are inherited by subclasses ofCrypTensor.MPCTensordynamically adds functions likeadd,sub, andmulafter the subclass is created, the registration is also done manually for those functions inMPCTensor.MPCTensor.binary_wrapper_functionassumes specific structure ofMPCTensorthattorch.Tensordoes not have, we switch the order of the arguments if needed and alter the function name to be__radd__,__rsub__, etc.Note that it is not immediately clear how to make the same work for other functions like
matmulthat do not have an__rmatmul__or for functions that do not exist in PyTorch likeconv1d. It can be done but things will get pretty messy. So the question with this PR is if this is a path we want to continue on.How Has This Been Tested
This PR is currently an RFC so I have not deeply tested all the changes yet. I would first like to get feedback on whether we want to make this change at all.
That said, these simple examples pass:
Similarly, the example from #403 passes:
If we want to proceed in this direction, I will add full unit tests.
Checklist