[SymmMem] Enable NVSHMEM for Triton by kwen2501 · Pull Request #155506 · pytorch/pytorch · GitHub

Conversation

@kwen2501
Contributor

@kwen2501 kwen2501 commented Jun 10, 2025

Stack from ghstack (oldest at bottom):

(This is an experimental feature.)
Allow Triton kernels to invoke NVSHMEM device functions.

Example Triton program

Key parts:

  • Call nvshmem.enable_triton() to initialize;
  • Call nvshmem.putmem_block in Triton kernel;
  • Add extern_libs kwarg at kernel invocation.
import torch.distributed as dist
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
import triton
import triton.language as tl


@triton.jit
def put_kernel(
    dst_ptr,
    src_ptr,
    numel: tl.constexpr,
    peer: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # NVSHMEM device-side put: push `numel` elements from the local
    # symmetric buffer to the corresponding buffer on rank `peer`.
    nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)


if __name__ == "__main__":
    # Enable NVSHMEM for Triton; returns the extern-lib mapping to pass
    # at kernel invocation.
    nvshmem_lib = nvshmem.enable_triton()

    # Use torch Symmetric Memory to allocate symmetric tensors
    # (the elided setup defines rank, dst_ptr, src_ptr, numel, BLOCK_SIZE, out)
    ...

    peer = 1 - rank
    if rank == 0:
        kernel = put_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            BLOCK_SIZE=BLOCK_SIZE,
            extern_libs=nvshmem_lib,
        )

    dist.barrier()
    if rank == 1:
        print(f"Rank {rank}: received {out=}")

Test output:

$ TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py -k test_triton_put
Rank 0: writing value 5 to Peer 1
Rank 1: received out=tensor([5, 5, 5, 5, 5, 5, 5, 5], device='cuda:1', dtype=torch.int8)

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

pytorch-bot bot commented Jun 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155506

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 14e167b with merge base 4d9d884:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Jun 10, 2025
@kwen2501 kwen2501 requested review from fduwjj, fegin and ngimel June 10, 2025 00:19
// operations.
void nvshmemx_cumodule_init(uintptr_t module) {
  auto cumodule = reinterpret_cast<CUmodule>(module);
  TORCH_CHECK(
Collaborator

I think it's time to implement NVSHMEM_CHECK similar to AT_CUDA_CHECK

Contributor Author

Yes!

kwen2501 added a commit that referenced this pull request Jun 10, 2025
ghstack-source-id: 37a8e96
Pull-Request-resolved: #155506
@kwen2501 kwen2501 added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 10, 2025
@kwen2501
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@fduwjj
Contributor

fduwjj commented Jun 10, 2025

Can you kindly make the linter happy? And is the unit test failure real?

"""
from triton.runtime.jit import JITFunction

from torch._C._distributed_c10d import (
Contributor

I think you need to update the .pyi file
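A minimal sketch of what the stub update might look like in `torch/_C/_distributed_c10d.pyi`; the binding name below is assumed from the `nvshmemx_cumodule_init` function in this PR, and the actual symbols depend on what the truncated import above pulls in:

```
# torch/_C/_distributed_c10d.pyi (sketch; hypothetical binding name)
def _nvshmemx_cumodule_init(module: int) -> None: ...
```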

Contributor

@fegin fegin left a comment

Pretty cool! Some nits.

torch.testing.assert_close(received_chunk, chunk)

@skipIfRocm
@requires_triton()
Contributor

Do we have a decorator to skip the test if NVSHMEM is not available?
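A minimal sketch of such a decorator, built on the `is_nvshmem_available()` API that #156291 later adds in this stack (the decorator name is hypothetical):

```
import unittest

import torch.distributed._symmetric_memory as symm_mem


# Hypothetical decorator: skip a test when NVSHMEM is unavailable.
def requires_nvshmem(func):
    return unittest.skipUnless(
        symm_mem.is_nvshmem_available(), "NVSHMEM not available"
    )(func)
```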

# Detect NVSHMEM device library path from python library path
if lib_dir is None:
    py_lib_path = sysconfig.get_path("purelib")
    lib_dir = py_lib_path + "/nvidia/nvshmem/lib"
Contributor

nit: should we use os.path.join?
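For reference, the suggested form would look like:

```
import os
import sysconfig

# Join path components portably instead of concatenating strings.
py_lib_path = sysconfig.get_path("purelib")
lib_dir = os.path.join(py_lib_path, "nvidia", "nvshmem", "lib")
```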

jit_function = kwargs["fn"].jit_function
kernel_cache, _, _, _ = jit_function.device_caches[device]
kernel = kernel_cache.get(key, None)
kernel.run  # bare attribute access; intentionally not a call (see discussion below)
Collaborator

Is this missing parens?

Contributor Author

Thanks for the careful look.
It is interesting that without this line, things don't work.
And I don't actually want the run to execute.



@core.extern
def putmem_block(dst, src, nelems, pe, _builder=None):  # type: ignore[no-untyped-def]
Member

Can we also add a generic "torch.putmem_block" abstraction that can do dispatching? How hard is dynamic dispatch for cuda/triton kernel?

cc @wconstab

@fduwjj
Contributor

fduwjj commented Jun 11, 2025

Also, since #155573 is merged, you might need to rebase your PR on top of it :)

kwen2501 added a commit that referenced this pull request Jun 11, 2025
ghstack-source-id: c17af34
Pull-Request-resolved: #155506
@kwen2501
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request Jun 13, 2025
This is a requirement of most SHMEM backends; otherwise, allocations may be misaligned across ranks.

In this PR, we make the (total) input size and output size constant, even though the split sizes are generated randomly; a sketch of this follows after this commit message. (Previously we summed the splits to get the input size, which created misalignment in the SHMEM heap across ranks.)

Pull Request resolved: #155835
Approved by: https://github.com/fduwjj, https://github.com/fegin, https://github.com/Skylion007
ghstack dependencies: #155506
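A minimal sketch of the fixed-total approach described above (names and sizes are illustrative, not from the PR):

```
import torch

# Draw `n_splits` random split sizes that always sum to a fixed `total`,
# so every rank allocates the same symmetric-memory size.
total, n_splits = 1024, 4
cuts, _ = torch.sort(torch.randint(0, total + 1, (n_splits - 1,)))
splits = torch.diff(cuts, prepend=torch.tensor([0]), append=torch.tensor([total]))
assert splits.sum().item() == total
```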
pytorchmergebot pushed a commit that referenced this pull request Jun 14, 2025
No code adds entries to `ptr_to_symm_mem_`, so it is always empty.
This PR removes it and routes the functionality that relied on it through the `allocations_` map.

Pull Request resolved: #155968
Approved by: https://github.com/Skylion007
ghstack dependencies: #155506, #155835
pytorchmergebot pushed a commit that referenced this pull request Jun 14, 2025
`NVSHMEMSymmetricMemory.cu` and `nvshmem_extension.cu` are under the same compilation condition now (i.e. only when `USE_NVSHMEM=True`), see https://github.com/pytorch/pytorch/blob/main/caffe2/CMakeLists.txt#L1013-L1018.

Therefore there is no need to build an extra layer to hide the dependency.

Pull Request resolved: #155971
Approved by: https://github.com/Skylion007
ghstack dependencies: #155506, #155835, #155968
pytorchmergebot pushed a commit that referenced this pull request Jun 15, 2025
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

Call `nvshmem_free` when an `NVSHMEMAllocation` is destructed.

Use an `is_finalizing()` guard, as done in `CUDASymmetricMemory.cu`, to avoid the "driver shutting down" error (destruction fiasco).

Pull Request resolved: #155975
Approved by: https://github.com/ngimel
ghstack dependencies: #155506, #155835, #155968, #155971
pytorchmergebot pushed a commit that referenced this pull request Jun 17, 2025
The rank-to-global-rank exchange is a major overhead in `NVSHMEMSymmetricMemory` creation.
We should cache its result on a per-group basis (sketched after this commit message).

Before this change:
```
TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
exchanged_n_times: 18
```

After this change:
```
TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
exchanged_n_times: 1
```

Pull Request resolved: #156116
Approved by: https://github.com/fegin, https://github.com/ngimel
ghstack dependencies: #155506, #155835, #155968, #155971, #155975
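A minimal sketch of the per-group caching described above, using `dist.get_process_group_ranks` as a stand-in for the actual exchange (the cache and function names are illustrative):

```
import torch.distributed as dist

# Cache the rank-to-global-rank result per group so the expensive
# exchange runs once per group rather than on every creation.
_global_ranks_cache: dict = {}


def global_ranks_for(group: dist.ProcessGroup) -> list:
    name = group.group_name
    if name not in _global_ranks_cache:
        _global_ranks_cache[name] = dist.get_process_group_ranks(group)
    return _global_ranks_cache[name]
```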
pytorchmergebot pushed a commit that referenced this pull request Jun 17, 2025
This avoids a copy; we do not expect a caller to change its value.

Pull Request resolved: #156117
Approved by: https://github.com/fegin
ghstack dependencies: #155506, #155835, #155968, #155971, #155975, #156116
pytorchmergebot pushed a commit that referenced this pull request Jun 19, 2025
This lets us pick the default backend for SymmetricMemory without
fully relying on the env var `TORCH_SYMMMEM=CUDA | NVSHMEM`.

On the Python side, the following API is added (a usage sketch follows after this commit message):
`torch.distributed._symmetric_memory.is_nvshmem_available()`

Pull Request resolved: #156291
Approved by: https://github.com/Skylion007
ghstack dependencies: #155506, #155835, #155968, #155971, #155975, #156116, #156117
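A usage sketch of the new API (the selection logic shown is illustrative):

```
import torch.distributed._symmetric_memory as symm_mem

# Prefer the NVSHMEM backend when available, else fall back to CUDA.
backend = "NVSHMEM" if symm_mem.is_nvshmem_available() else "CUDA"
```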
@github-actions github-actions bot deleted the gh/kwen2501/166/head branch July 14, 2025 02:21