[SymmMem] Enable NVSHMEM for Triton #155506
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155506.
Note: links to docs will display an error until the doc builds have completed. ✅ You can merge normally (1 unrelated failure). As of commit 14e167b with merge base 4d9d884. BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
// operations.
void nvshmemx_cumodule_init(uintptr_t module) {
  auto cumodule = reinterpret_cast<CUmodule>(module);
  TORCH_CHECK(
```
I think it's time to implement NVSHMEM_CHECK similar to AT_CUDA_CHECK
Yes!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Can you kindly make the linter happy? And is the unit test failure real?
```python
"""
from triton.runtime.jit import JITFunction
```

```python
from torch._C._distributed_c10d import (
```
I think you need to update the .pyi file
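A minimal sketch of what the matching stub entry might look like, based only on the binding shown in the diff above (`nvshmemx_cumodule_init` taking a `uintptr_t`); the exact set of names and signatures should mirror the C++ bindings this PR actually adds.

```python
# Hypothetical stub entry for torch/_C/_distributed_c10d.pyi.
# The uintptr_t module handle maps to a plain int on the Python side.
def nvshmemx_cumodule_init(module: int) -> None: ...
```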
Pretty cool! Some nits.
```python
torch.testing.assert_close(received_chunk, chunk)
```

```python
@skipIfRocm
@requires_triton()
```
Do we have a decorator to skip the test if NVSHMEM is not available?
```python
# Detect NVSHMEM device library path from python library path
if lib_dir is None:
    py_lib_path = sysconfig.get_path("purelib")
    lib_dir = py_lib_path + "/nvidia/nvshmem/lib"
```
nit: should we use os.path.join?
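The suggested change is small; a sketch of the same path detection using `os.path.join` instead of string concatenation:

```python
import os
import sysconfig

# Detect the NVSHMEM device library path from the Python library path,
# joining components portably rather than concatenating "/" by hand.
py_lib_path = sysconfig.get_path("purelib")
lib_dir = os.path.join(py_lib_path, "nvidia", "nvshmem", "lib")
```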
```python
jit_function = kwargs["fn"].jit_function
kernel_cache, _, _, _ = jit_function.device_caches[device]
kernel = kernel_cache.get(key, None)
kernel.run
```
Is this missing parens?
Thanks for the careful look.
It is interesting that without this line, things don't work.
And I don't actually want the run to execute.
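A generic Python illustration of why a bare attribute access like `kernel.run` can have an effect without executing anything: if `run` is a property (or a `__getattr__` hook) that lazily finalizes state on first access, touching the attribute is enough. This is a sketch of the mechanism, not Triton's actual implementation.

```python
class LazyKernel:
    """Toy class where accessing .run performs one-time setup."""

    def __init__(self):
        self.finalized = False

    @property
    def run(self):
        if not self.finalized:
            self.finalized = True  # lazy one-time setup on first access
        return self._launch        # returned, but never called here

    def _launch(self, *args):
        return "launched"


k = LazyKernel()
k.run  # attribute access alone triggers the setup; nothing is executed
```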
```python
@core.extern
def putmem_block(dst, src, nelems, pe, _builder=None):  # type: ignore[no-untyped-def]
```
Can we also add a generic "torch.putmem_block" abstraction that can do dispatching? How hard is dynamic dispatch for cuda/triton kernel?
cc @wconstab
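A host-side sketch of what such a dispatching abstraction could look like. The registry, backend names, and wrapper below are all illustrative (not part of this PR); the hard part the question alludes to is that inside a Triton/CUDA kernel the dispatch would have to be resolved at compile time rather than dynamically.

```python
# Hypothetical backend registry for a generic "putmem_block" wrapper.
_PUTMEM_IMPLS = {}


def register_putmem_block(backend):
    def deco(fn):
        _PUTMEM_IMPLS[backend] = fn
        return fn
    return deco


@register_putmem_block("nvshmem")
def _putmem_nvshmem(dst, src, nelems, pe):
    # Stand-in for the NVSHMEM-backed implementation.
    return ("nvshmem", nelems, pe)


def putmem_block(dst, src, nelems, pe, backend="nvshmem"):
    # Host-side dynamic dispatch on the backend name.
    return _PUTMEM_IMPLS[backend](dst, src, nelems, pe)
```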
Also, since #155573 is merged, you might need to rebase your PR on top of it :)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This is a requirement of most SHMEM backends; otherwise, allocations may misalign across ranks. In this PR, we make the (total) input size and output size a constant number, even though the split sizes are created randomly. (Previously we summed the splits up as the input size, which creates misalignment in the SHMEM heap across ranks.) Pull Request resolved: #155835 Approved by: https://github.com/fduwjj, https://github.com/fegin, https://github.com/Skylion007 ghstack dependencies: #155506
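The sizing scheme described above can be sketched as follows: split sizes are random but always sum to a fixed total, so every rank allocates the same number of elements and the symmetric heaps stay aligned. The helper below is illustrative, not the test's actual code.

```python
import random


def random_splits(total: int, n: int, seed: int):
    """Draw n random split sizes that always sum exactly to `total`."""
    rng = random.Random(seed)
    # Pick n-1 random cut points in [0, total], then take the gaps
    # between consecutive bounds as the split sizes.
    cuts = sorted(rng.randint(0, total) for _ in range(n - 1))
    bounds = [0] + cuts + [total]
    return [b - a for a, b in zip(bounds, bounds[1:])]


splits = random_splits(1024, 4, seed=0)  # random splits, constant total
```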
No code enqueues entries into `ptr_to_symm_mem_`, so it is always empty. This PR removes it and supports the functionalities that relied on it via the `allocations_` map. Pull Request resolved: #155968 Approved by: https://github.com/Skylion007 ghstack dependencies: #155506, #155835
`NVSHMEMSymmetricMemory.cu` and `nvshmem_extension.cu` are now under the same compilation condition (i.e. only when `USE_NVSHMEM=True`), see https://github.com/pytorch/pytorch/blob/main/caffe2/CMakeLists.txt#L1013-L1018. Therefore there is no need to build an extra layer to hide the dependency. Pull Request resolved: #155971 Approved by: https://github.com/Skylion007 ghstack dependencies: #155506, #155835, #155968
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): Call `nvshmem_free` when an `NVSHMEMAllocation` is destructed. Use `is_finalizing()` as a guard, as done in `CUDASymmetricMemory.cu`, to avoid the "driver shutting down" error (static destruction order fiasco). Pull Request resolved: #155975 Approved by: https://github.com/ngimel ghstack dependencies: #155506, #155835, #155968, #155971
The rank-to-global-rank exchange is a major overhead in `NVSHMEMSymmetricMemory` creation. We should cache its result on a per-group basis. Before this change: ``` TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py exchanged_n_times: 18 ``` After this change: ``` TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py exchanged_n_times: 1 ``` Pull Request resolved: #156116 Approved by: https://github.com/fegin, https://github.com/ngimel ghstack dependencies: #155506, #155835, #155968, #155971, #155975
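The per-group caching idea reduces to a simple memoization pattern. The `exchange_global_ranks` helper below is a stand-in for the expensive exchange, not the actual implementation:

```python
exchanged_n_times = 0
_exchange_cache = {}


def exchange_global_ranks(group_name):
    # Stand-in for the expensive rank-to-global-rank exchange.
    global exchanged_n_times
    exchanged_n_times += 1
    return [0, 1]


def cached_global_ranks(group_name):
    # Cache the exchange result per group, as this change does.
    if group_name not in _exchange_cache:
        _exchange_cache[group_name] = exchange_global_ranks(group_name)
    return _exchange_cache[group_name]


cached_global_ranks("group0")
cached_global_ranks("group0")  # served from cache; no second exchange
```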
This lets us pick the default backend for SymmetricMemory without fully relying on the env var `TORCH_SYMMMEM=CUDA | NVSHMEM`. On the Python side, the following API is added: `torch.distributed._symmetric_memory.is_nvshmem_available()` Pull Request resolved: #156291 Approved by: https://github.com/Skylion007 ghstack dependencies: #155506, #155835, #155968, #155971, #155975, #156116, #156117
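A hedged sketch of how a default-backend pick could combine the env var with the new availability probe. The selection logic and the local `is_nvshmem_available` stub are illustrative; the real function lives in `torch.distributed._symmetric_memory` and the actual selection code may differ.

```python
import os


def is_nvshmem_available() -> bool:
    # Stand-in for torch.distributed._symmetric_memory.is_nvshmem_available();
    # hard-coded False for this sketch.
    return False


def pick_backend() -> str:
    # Honor an explicit TORCH_SYMMMEM override, otherwise probe availability.
    env = os.environ.get("TORCH_SYMMMEM")
    if env in ("CUDA", "NVSHMEM"):
        return env
    return "NVSHMEM" if is_nvshmem_available() else "CUDA"
```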
Stack from ghstack (oldest at bottom):
(This is an Experimental feature)
Allow Triton kernels to invoke NVSHMEM device functions.
Example Triton program
Key parts:
- `nvshmem.enable_triton()` to initialize;
- `nvshmem.putmem_block` in the Triton kernel;
- `extern_libs` kwarg at kernel invocation.

Test output:
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k