Skip to content
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions opacus/utils/batch_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from torch.utils.data import BatchSampler, DataLoader, Sampler


class BatchSplittingSampler(Sampler[List[int]]):
"""
Samples according to the underlying instance of ``Sampler``, but splits
Expand Down Expand Up @@ -70,14 +69,16 @@ def __iter__(self):

def __len__(self):
if isinstance(self.sampler, BatchSampler):
return int(
return math.ceil(
len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
)
elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
self.sampler, DistributedUniformWithReplacementSampler
):
expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
return math.ceil(
len(self.sampler) * (expected_batch_size / self.max_batch_size)
)

return len(self.sampler)

Expand Down