@@ -115,18 +115,6 @@ def get_name(self):
115115 """Returns the party name of the current process."""
116116 raise NotImplementedError ("get_name is not implemented" )
117117
118- def reset_communication_stats (self ):
119- """Resets communication statistics."""
120- raise NotImplementedError ("reset_communication_stats is not implemented" )
121-
122- def print_communication_stats (self ):
123- """Prints communication statistics."""
124- raise NotImplementedError ("print_communication_stats is not implemented" )
125-
126- def _log_communication (self , nelement ):
127- """Updates log of communication statistics."""
128- raise NotImplementedError ("_log_communication is not implemented" )
129-
130118 def reset_communication_stats (self ):
131119 """Resets communication statistics."""
132120 self .comm_rounds = 0
@@ -135,10 +123,12 @@ def reset_communication_stats(self):
135123
136124 def print_communication_stats (self ):
137125 """Prints communication statistics."""
138- logging .info ("====Communication Stats====" )
139- logging .info ("Rounds: {}" .format (self .comm_rounds ))
140- logging .info ("Bytes : {}" .format (self .comm_bytes ))
141- logging .info ("Comm time: {}" .format (self .comm_time ))
126+ import crypten
127+
128+ crypten .log ("====Communication Stats====" )
129+ crypten .log ("Rounds: {}" .format (self .comm_rounds ))
130+ crypten .log ("Bytes : {}" .format (self .comm_bytes ))
131+ crypten .log ("Comm time: {}" .format (self .comm_time ))
142132
143133 def _log_communication (self , nelement ):
144134 """Updates log of communication statistics."""
@@ -201,9 +191,9 @@ def logging_wrapper(self, *args, **kwargs):
201191 else : # one tensor communicated
202192 self ._log_communication (args [0 ].nelement ())
203193
204- tic = timeit .timeit ()
194+ tic = timeit .default_timer ()
205195 result = func (self , * args , ** kwargs )
206- toc = timeit .timeit ()
196+ toc = timeit .default_timer ()
207197
208198 self ._log_communication_time (toc - tic )
209199 return result
0 commit comments