Description
Hi! Recently, while benchmarking torchcomms by iteratively launching NCCL kernels, I found that HBM usage is very high, and after several iterations an OOM error occurred.
This looks very similar to the behavior of older versions of torch, where torch ensures a tensor is not released until the previous NCCL kernels have finished. That feature considerably increases HBM usage, which is why it can now be disabled by setting TORCH_NCCL_AVOID_RECORD_STREAM=0.
Therefore, I think the reason torchcomms uses so much HBM is that TORCH_NCCL_AVOID_RECORD_STREAM hasn't been exposed to users yet.
Being able to disable this feature is quite important, since torch now makes that the default precisely to keep HBM usage reasonable.
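For context, here is a minimal sketch of the mechanism I mean (the record_stream call that the flag name refers to), written against plain PyTorch rather than torchcomms; the tensor size and the extra stream are just placeholders:

import torch

comm_stream = torch.cuda.Stream()
x = torch.empty(1024 ** 2, device="cuda")  # allocated on the default stream
with torch.cuda.stream(comm_stream):
    y = x * 2  # stand-in for an NCCL kernel running on the communication stream
x.record_stream(comm_stream)  # allocator must not reuse x's block until comm_stream catches up
del x  # the Python reference is gone, but the block stays unavailable to the caching
       # allocator, so peak HBM grows if this happens on every iteration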
The following is my test snippet; the OOM happened on the last iteration:
for i in range(10):
    output = torch.empty_like(input)
    main_comm.all_to_all_single(output, input, async_op=False)
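To make the growth visible, the loop can be instrumented with the standard allocator queries (a sketch; input and main_comm are the same as above, only the prints are new):

for i in range(10):
    output = torch.empty_like(input)
    main_comm.all_to_all_single(output, input, async_op=False)
    # memory_allocated/memory_reserved report the caching allocator's current usage
    print(f"iter {i}: allocated={torch.cuda.memory_allocated() / 2**30:.2f} GiB, "
          f"reserved={torch.cuda.memory_reserved() / 2**30:.2f} GiB")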
When I add torch.cuda.synchronize() at the end of each iteration, to ensure all previous NCCL kernels have finished, the problem seems to go away:
for i in range(10):
    output = torch.empty_like(input)
    main_comm.all_to_all_single(output, input, async_op=False)
    torch.cuda.synchronize()
That's why I think it is probably related to TORCH_NCCL_AVOID_RECORD_STREAM.
Thanks!