[Distributed] Improve efficiency of NaN checker by kwen2501 · Pull Request #135414 · pytorch/pytorch · GitHub

Conversation

@kwen2501 (Contributor) commented Sep 7, 2024

Some customers would like to run the NaN check on the fly, so we are improving its efficiency.

Benchmarking

Allreduce 2G floats. TORCH_NCCL_NAN_CHECK=1
Red kernel: ncclAllreduce
Blue kernel: NaN check

[Screenshot (profiler trace): ncclAllreduce kernel in red, NaN-check kernel in blue. https://github.com/user-attachments/assets/5501bc31-024f-4115-adb2-dd66eb4025d3]

Comparison with torch ops:

Let's say a user manually checks for NaNs with the following torch ops before all-reduce:

torch.any(torch.isnan(x))
[Screenshot (profiler trace): torch.any(torch.isnan(x)) running ahead of the all-reduce. https://github.com/user-attachments/assets/1f8b5f63-c955-4612-bb96-241b6c69959b]

So our perf is on par with torch ops.

Changes

  • Load from vidmem using "big packs" of 16 bytes
  • Bump blockDim.x from 256 to 512
  • Separate loads and checks into two loops, each of 8 iterations
  • Unroll the loops
  • Templated functions for checking NaN in a "big pack" based on dtype (see the sketch below)
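
A minimal sketch of what these bullets describe. The names (BytePack, packHasNan, checkChunk) echo identifiers mentioned later in this thread, but the bodies below are illustrative guesses, not the PR's actual kernel:

```
#include <cuda_fp16.h>  // only needed if the template is instantiated for __half
#include <cmath>

#define UNROLL 8  // 8 packs loaded, then 8 packs checked, both loops unrolled

// A "big pack": 16 bytes read with a single vectorized load.
union BytePack {
  ulonglong2 u128;  // one 16-byte load/store
  double     f64[2];
  float      f32[4];
};

// Dtype-templated check of one pack; EltPerPack = 16 / sizeof(T).
template <typename T, int EltPerPack>
__device__ __forceinline__ bool packHasNan(const BytePack& pack) {
  const T* elts = reinterpret_cast<const T*>(&pack);
  bool hasNan = false;
#pragma unroll
  for (int i = 0; i < EltPerPack; ++i) {
    hasNan |= isnan(static_cast<float>(elts[i]));  // NaN survives the float cast
  }
  return hasNan;
}

// One chunk = UNROLL packs per thread: issue all loads first, then all checks,
// so later loads are in flight while earlier packs are being inspected.
template <typename T>
__device__ __forceinline__ void checkChunk(const BytePack* ptr) {
  BytePack regs[UNROLL];
#pragma unroll
  for (int i = 0; i < UNROLL; ++i) {
    regs[i] = ptr[i * blockDim.x];  // block-strided so neighboring threads coalesce
  }
#pragma unroll
  for (int i = 0; i < UNROLL; ++i) {
    if (packHasNan<T, 16 / sizeof(T)>(regs[i])) {
      __trap();  // abort on the device before the collective can run
    }
  }
}
```

The launch side would then use 512 threads per block (the blockDim.x bump above), with each block walking the buffer one chunk of blockDim.x * 8 packs at a time.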

cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
Special thanks to @jbachan from NCCL!

@pytorch-bot (bot) commented Sep 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135414

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ea492a5 with merge base 5f57be7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot (bot) added the oncall: distributed and release notes: distributed (c10d) labels on Sep 7, 2024
@awgu (Collaborator) commented Sep 7, 2024

Base tensor is guaranteed to have 16-byte alignment, but a view into it does not have to be 🤔

@kwen2501 (Contributor, Author) commented Sep 8, 2024

@awgu Good catch, thanks! Care taken now: 6e7340a
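
For context, a hedged sketch of one common way to handle a tensor view whose data pointer is not 16-byte aligned (illustrative only, not necessarily what commit 6e7340a does): check the few leading elements one by one, then hand the 16-byte-aligned remainder to the vectorized path.

```
#include <cstdint>
#include <cmath>

// Returns how many leading elements must be checked individually before the
// pointer reaches a 16-byte boundary; those elements are checked here.
template <typename T>
__device__ size_t checkUnalignedHead(const T* data, size_t numel) {
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  size_t misaligned = addr % 16;                       // bytes past the last 16B boundary
  size_t head = misaligned ? (16 - misaligned) / sizeof(T) : 0;
  if (head > numel) head = numel;
  for (size_t i = 0; i < head; ++i) {
    if (isnan(static_cast<float>(data[i]))) __trap();  // NaN in the unaligned head
  }
  return head;                                         // vectorized checks start at data + head
}
```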

@wconstab (Contributor) left a comment

i think, given the increased complexity of the kernel, it would be good to add a test that more carefully checks for cases where the NaN detector misses a NaN.

given we can't realistically afford to do a test where we loop through the indices of a large tensor and set each value to NaN exhaustively, do you think it makes sense to do some combination of (a) exhaustive testing on a small-medium tensor that still is large enough to exercise both the unrolled and suffix loops, (b) a test that sets a random index to NaN, so at least throughout many repetitions of CI we could expect a 'flaky' signal if we are missing certain values?

@kwen2501 (Contributor, Author) commented Sep 9, 2024

Re (a): yes, I can add a test that shmoos through small-to-medium sizes (and data types).

Re (b): Yep, I think the existing tests can be modified to support the randomness.

@kwen2501 (Contributor, Author) commented Sep 9, 2024

@wconstab Test modified to cover wider size range.

@wconstab (Contributor) commented Sep 9, 2024

one more thing - this could be a separate PR, but we are still missing fp8, iiuc. We should definitely cover this. If it's convenient to add in this PR, it might make sense as just one more case of the kernel template.

@kwen2501 (Contributor, Author) commented Sep 9, 2024

@wconstab can it be in a separate PR, for cleanliness reasons?
I also need to study CUDA's FP8 support.

Contributor:

nit: this appears to only put NaN values on the I diagonal. What about something like this?

index = tuple([randint(...) for _ in range(len(size))])

Contributor Author:

Thanks, adopted.

Contributor:

iiuc this generalized kernel would only be used for float8? i guess in a later PR, you would possibly replace this by a specialized one too?

Contributor Author:

Yes

Contributor:

hm, does checkChunk get called with a different ptr offset for each thread?

Contributor Author:

Below at line 134, checkChunk is called like this:
checkChunk<T>(ptr + offset);
offset accounts for different offsets for different threads.
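
An illustrative guess at that indexing (the helper name is hypothetical): if each block walks the buffer in chunks of blockDim.x * 8 packs, the per-thread offset can simply be the chunk start plus threadIdx.x, so thread t reads packs t, t + blockDim.x, ..., t + 7 * blockDim.x within the chunk and neighboring threads load neighboring 16-byte packs.

```
// Hypothetical helper: per-thread starting offset (in packs) inside one chunk.
__device__ size_t chunkOffsetForThread(size_t chunkStartInPacks) {
  return chunkStartInPacks + threadIdx.x;
}
```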

Contributor:

confused, if we are sure we have enough data left for one call to CheckBytePack<T, B/T> doesn't that also mean we have enough data for a faster call to CheckBytePack?

Contributor Author (@kwen2501), Sep 9, 2024:

Do you mean "why not a call to checkChunk"?
The reason is that checkChunk checks 8 BytePacks in one call, while CheckBytePack checks 1 BytePack.
The slow loop here handles the case where we don't have 8 BytePacks left.

Contributor:

oh- yes, i got confused between the two. i think this makes sense.

Contributor:

So in summary, i think the algorithm is

Pre: always < 1 BytePack, since = 1 would imply it's already 16B aligned, so don't even use 'CheckByte', do a local check

Body: process chunks of 8 BytePacks (e.g. 8*16 = 128B chunks) per call

Tail: since alignment is now guaranteed, just process (N < 8) 16B BytePacks individually (sketched below)
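
A hedged, self-contained sketch of these three phases, specialized to float and with simplified chunk bookkeeping; the kernel name and constants are invented, and float4 stands in for the PR's BytePack:

```
#include <cstdint>
#include <cmath>

__global__ void checkForNanThreePhase(const float* data, size_t numel) {
  constexpr int kEltPerPack   = 4;   // 16 bytes / sizeof(float)
  constexpr int kPacksPerLoop = 8;
  const size_t tid      = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  const size_t nThreads = (size_t)gridDim.x * blockDim.x;

  // Pre: fewer than one pack of leading elements if the pointer is not
  // 16-byte aligned; check them individually.
  uintptr_t addr = reinterpret_cast<uintptr_t>(data);
  size_t head = (addr % 16) ? (16 - addr % 16) / sizeof(float) : 0;
  if (head > numel) head = numel;
  for (size_t i = tid; i < head; i += nThreads) {
    if (isnan(data[i])) __trap();
  }

  // View the aligned region as 16-byte packs.
  const float4* packs = reinterpret_cast<const float4*>(data + head);
  size_t numPacks  = (numel - head) / kEltPerPack;
  size_t chunkSize = (size_t)blockDim.x * kPacksPerLoop;  // packs per block-sized chunk
  size_t numChunks = numPacks / chunkSize;

  // Body: each block consumes whole chunks; per chunk, each thread loads
  // 8 packs (block-strided for coalescing) and then checks them.
  for (size_t c = blockIdx.x; c < numChunks; c += gridDim.x) {
    const float4* base = packs + c * chunkSize + threadIdx.x;
    float4 regs[kPacksPerLoop];
#pragma unroll
    for (int i = 0; i < kPacksPerLoop; ++i) regs[i] = base[i * blockDim.x];
#pragma unroll
    for (int i = 0; i < kPacksPerLoop; ++i) {
      if (isnan(regs[i].x) || isnan(regs[i].y) ||
          isnan(regs[i].z) || isnan(regs[i].w)) __trap();
    }
  }

  // Tail: leftover packs (fewer than a full chunk) one at a time, then any
  // trailing elements that do not fill a whole pack.
  for (size_t p = numChunks * chunkSize + tid; p < numPacks; p += nThreads) {
    float4 v = packs[p];
    if (isnan(v.x) || isnan(v.y) || isnan(v.z) || isnan(v.w)) __trap();
  }
  for (size_t i = head + numPacks * kEltPerPack + tid; i < numel; i += nThreads) {
    if (isnan(data[i])) __trap();
  }
}
```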

@shuqiangzhang (Contributor):

I might understand this wrong, but if "our perf is on par with torch ops", why not just use torch.any(torch.isnan(x))?

@kwen2501 (Contributor, Author):

Good question. I had the same question too.

So the reasoning goes like this (working backward):

  • we need to stop communication from spreading NaNs;
  • we need to stop NCCL kernel from launching;
  • torch.any(torch.isnan(x)) does not stop NCCL kernel from launching.

Re why "we need to stop communication from spreading NaNs", here is a view from @wconstab:
"technically if we can be sure which rank (or, even which host) detected the first nan, then its OK to let the nan spread to some other hosts. but in practice i dont know if we have good enough way to align our logs on different hosts, so if we let the nan spread to a few other hosts we may lose track of which one was first"

@wconstab (Contributor):

why not just use torch.any(torch.isnan(x))?

another flavor on this is, we could use it if we could easily modify it to trap() on nan, instead of asynchronously producing a bool tensor that someone (who?) has to check (when?). We definitely don't want to do a cuda synchronize after each nan check and check it on the cpu side.
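
To make the trap-vs-bool-tensor point concrete, a toy sketch (not the PR's kernel): the check runs as a kernel on the same stream as the collective and aborts on the device, so nothing has to be copied back or synchronized on the CPU. PyTorch has a CUDA_KERNEL_ASSERT macro for this kind of device-side check; a plain device assert is shown here only for brevity, and it is compiled out when NDEBUG is defined.

```
#include <cassert>
#include <cmath>

// Toy device-side NaN guard: traps inside the kernel instead of producing a
// bool tensor that the host would have to synchronize on and inspect.
__global__ void assertNoNan(const float* data, size_t numel) {
  for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < numel;
       i += (size_t)gridDim.x * blockDim.x) {
    assert(!isnan(data[i]) && "NaN detected before collective");
  }
}
```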

@kwen2501 (Contributor, Author):

@pytorchbot merge

@pytorch-bot (bot) added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Sep 11, 2024
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Pull Request resolved: pytorch#135414
Approved by: https://github.com/wconstab
@github-actions (bot) deleted the nan_perf branch on October 12, 2024 at 02:06
github-merge-queue bot pushed a commit to intel/torch-xpu-ops that referenced this pull request Jul 23, 2025
Refer from pytorch/pytorch#125726 and pytorch/pytorch#135414. Add NaN check for xccl; the rationale for stopping communication from spreading NaNs is the one quoted from this PR's discussion above.

Co-authored-by: mengfei25 <mengfei.li@Intel.com>