[SymmMem] Experimental NVSHMEM integration by kwen2501 · Pull Request #151261 · pytorch/pytorch · GitHub

Conversation

@kwen2501
Contributor

@kwen2501 kwen2501 commented Apr 14, 2025

Stack from ghstack (oldest at bottom):

Adding NVSHMEM as a backend for SymmetricMemory; the implementation is in NVSHMEMSymmetricMemory.cu.

Moving some helper functions from CUDASymmetricMemory.cu to CUDASymmetricMemoryUtils.cpp so that they can be shared by NVSHMEMSymmetricMemory. These functions are mostly side-band exchange helpers (store_all_gather, IpcChannel, etc.).

Adding a TORCH_SYMMEM environment variable to control which implementation to use for CUDA tensors; currently supported values: CUDA (in-house impl) and NVSHMEM.

The NVSHMEM feature is gated by the build-time flag USE_NVSHMEM=1, and an NVSHMEM_HOME setting is required (TODO).

Ported most code from #146593.
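
For illustration, a minimal sketch of what a TORCH_SYMMEM lookup could look like. This is an assumption for exposition, not the PR's actual selection code, and the fallback default of "CUDA" is likewise assumed:

#include <cstdlib>
#include <string>

// Illustrative only: read TORCH_SYMMEM and fall back to the in-house CUDA
// implementation when the variable is unset.
std::string query_symm_mem_backend() {
  const char* val = std::getenv("TORCH_SYMMEM");
  return val ? std::string(val) : std::string("CUDA");  // "CUDA" or "NVSHMEM"
}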

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels on Apr 14, 2025
@pytorch-bot

pytorch-bot bot commented Apr 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151261

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 01f7e72 with merge base 9c864f9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
ghstack-source-id: f54274a
Pull Request resolved: pytorch/pytorch#151261
@kwen2501 kwen2501 changed the title from "[WIP] experimental NVSHMEM integration" to "[SymmMem] Experimental NVSHMEM integration" on Apr 28, 2025
Comment on lines +980 to +1014

# Use env var for these for now for prototyping purposes
set(USE_NVSHMEM $ENV{USE_NVSHMEM} CACHE BOOL "Enable NVSHMEM support")
set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir")

if(USE_NVSHMEM)
  set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include")
  set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib")

  include_directories(${NVSHMEM_INCLUDE_DIR})

  # Linking with nvshmem requires the source binary to be built with -rdc
  # which is not viable for libtorch_cuda. So we isolate the linking of
  # nvshmem in nvshmem_extension.
  add_library(nvshmem_extension SHARED
    "${TORCH_SRC_DIR}/csrc/distributed/c10d/nvshmem_extension.cu"
    "${TORCH_SRC_DIR}/csrc/distributed/c10d/NVSHMEMSymmetricMemory.cu"
    "${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryUtils.cpp"
    "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp"
  )
  set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
  target_compile_options(nvshmem_extension PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-rdc=true>)
  target_compile_options(nvshmem_extension PRIVATE "-U__CUDA_NO_HALF_OPERATORS__")
  target_link_libraries(nvshmem_extension PRIVATE
    ${NVSHMEM_LIB_DIR}/libnvshmem.a
    ${NVSHMEM_LIB_DIR}/nvshmem_bootstrap_uid.so
  )
  target_link_libraries(nvshmem_extension PRIVATE mlx5)
  target_link_libraries(torch_cuda PRIVATE nvshmem_extension)
  install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib)
  install(
    FILES ${NVSHMEM_LIB_DIR}/nvshmem_bootstrap_uid.so
    DESTINATION lib
  )
endif()
Contributor Author

Hi @atalman @malfet do you mind having a look at this section? Today we gate the NVSHMEM feature with build-time flag USE_NVSHMEM=1. We may change it later.


m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_reduce_scatter_out(Tensor input, str group_name, Tensor(a!) out) -> Tensor(a!)");
Contributor

Why do these two APIs have the nvshmem_ prefix? Do we plan to unify nvshmem_reduce_scatter_out with reduce_scatter_out?

Contributor Author

Correct. I expect we can remove the prefix by the release.

dim3 grid_dim(32), block_dim(544);
auto stream = at::cuda::getCurrentCUDAStream();
nvshmemx_barrier_on_stream(team, stream);
nvshmemx_collective_launch(
Contributor

Just for my learning, we need to use nvshmemx_collective_launch because nvshmem_all_reduce_kernel uses some nvshmem APIs. Is this correct?

Contributor Author

From documentation:

The nvshmemx_collective_launch function must be used to launch CUDA kernels on the GPU when the CUDA kernels use NVSHMEM synchronization or collective APIs (e.g., nvshmem_wait, nvshmem_barrier, nvshmem_barrier_all, or any other collective operation). CUDA kernels that do not use synchronizing NVSHMEM APIs (or that do not use NVSHMEM APIs at all), are not required to be launched by this API. This call is collective across the PEs in the NVSHMEM job.
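
A minimal sketch of that launch pattern (the kernel name, arguments, and body here are hypothetical, not the PR's code; only the grid/block sizes mirror the snippet above):

#include <cuda_runtime.h>
#include <nvshmem.h>
#include <nvshmemx.h>

// Any kernel that calls synchronizing NVSHMEM device APIs must go through
// nvshmemx_collective_launch instead of a plain <<<...>>> launch.
__global__ void all_reduce_kernel(float* buf, size_t numel) {
  // ... device-side NVSHMEM calls (e.g. nvshmem_barrier_all()) would go here ...
}

void launch_all_reduce(float* buf, size_t numel, cudaStream_t stream) {
  dim3 grid(32), block(544);
  void* args[] = {&buf, &numel};
  // Collective across all PEs in the NVSHMEM job; 0 bytes of dynamic shared memory.
  nvshmemx_collective_launch(
      (const void*)all_reduce_kernel, grid, block, args, 0, stream);
}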

true;
}

// Query environment variable to get the backend used for CUDA Symmetric Memory.
Contributor

What's your plan for documenting this to users?
a) How long-lived is it? (IIUC, it's long-lived, since we have to support the 'CUDA' backend for AMD.)
b) What are the user-facing tradeoffs? Why would I choose CUDASymmetricMemory vs NVSHMEMSymmetricMemory in a given scenario?

Contributor Author

@kwen2501 commented Apr 29, 2025

a) The env itself could be longer-term, due to the possibility of an "NCCL" backend. The "CUDASymmetricMemory" impl would stay as long as needed, for the reason you mentioned.
b) Intra- vs inter-node capability is the main difference from a user's point of view.

Contributor

(b) is still not clear to me. Does the nvshmem backend support intra+inter-node, or is it purely intra-node?

If nvshmem supports a superset of the CUDA backend except for AMD support, then I would argue we delete the env and autodetect AMD to switch to the CUDA backend; otherwise, default to the nvshmem backend?

Contributor Author

The env is also for NCCLSymmetricMemory vs NVSHMEMSymmetricMemory.

@huydhn
Contributor

huydhn commented Apr 29, 2025

@pytorchbot drci

# which is not viable for libtorch_cuda. So we isolate the linking of
# nvshmem in nvshmem_extension.
add_library(nvshmem_extension SHARED
"${TORCH_SRC_DIR}/csrc/distributed/c10d/nvshmem_extension.cu"
Contributor

I was going to suggest having a dedicated folder for the nvshmem backend files so it's clear to readers that they are built into a separate compilation unit.

However, maybe it's an annoyingly long path? csrc/distributed/c10d/symm_mem/common and /nvshmem

Contributor Author

Yeah, totally agree with a dedicated folder. The c10d folder is a mix of everything today; will refactor it together with other libraries later.

Contributor

Maybe we can have a dedicated symmetric-memory folder?

#define CUDART_SUPPORTS_MULTICAST
#endif

namespace {
Contributor

You should really do the code-motion part in a previous PR so it's easier to tell what the real changes are.


static size_t store_comm_seq_id = 0;

template <typename T>
Contributor

Three functions moved to the utils file; not checking the details, assuming it's just a direct move.

namespace c10d::symmetric_memory {

bool device_has_multicast_support(int device_idx) {
if (c10::utils::check_env("TORCH_SYMM_MEM_DISABLE_MULTICAST") == true) {
Contributor

Planning to document these env vars? Any idea if they are going to be long-lived?

Contributor Author

We can probably clean this env.


/* Start of CUDASymmetricMemory implementation */

constexpr size_t signal_pad_size = 2048;
const std::string store_comm_prefix = "CUDASymmetricMemory";
Contributor

Just curious: since this is declared as 'extern' in the header, what does that mean when there is a definition in both the torch and nvshmem backends? Won't the torch definition always exist and take precedence even when we are using the nvshmem backend? (Did you confirm in tests that the nvshmem backend prefix is being used?)

Contributor Author

You are right that usually the compiler does not allow defining the same variable in multiple source files.
Here we are relying on the rule that a namespace-scope const has internal (static) linkage by default, meaning the variable is not visible to other translation units.
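
A minimal illustration of that linkage rule (not the PR's code; the second prefix value is hypothetical):

#include <string>

// In one translation unit (e.g. CUDASymmetricMemory.cu): a namespace-scope
// const object has internal linkage by default, so this definition stays
// private to its own .cu/.cpp file.
const std::string store_comm_prefix = "CUDASymmetricMemory";

// In another translation unit, a same-named definition with a different value
// can coexist without a multiple-definition error, because each file sees
// only its own copy:
//   const std::string store_comm_prefix = "NVSHMEMSymmetricMemory";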

Contributor Author

@wconstab thanks for your suggestion. I made a little class StoreExchange to encapsulate those store_ methods and the prefix. Then each SymmMem implementation can locally instantiate a StoreExchange instance to avoid namespace & prefix contention.
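
A hedged sketch of that encapsulation (names and methods here are illustrative, not the exact StoreExchange interface landed in the PR):

#include <cstddef>
#include <string>
#include <utility>

// Each SymmetricMemory backend instantiates its own exchanger with a distinct
// key prefix, so side-band exchanges through the store no longer collide.
class StoreExchange {
 public:
  explicit StoreExchange(std::string prefix) : prefix_(std::move(prefix)) {}

  // Keys look like "<prefix>/<seq>/<rank>"; the sequence id is bumped per
  // exchange so repeated rounds never reuse a key.
  std::string key_for(int rank) const {
    return prefix_ + "/" + std::to_string(seq_id_) + "/" + std::to_string(rank);
  }

  void next_round() { ++seq_id_; }

 private:
  std::string prefix_;
  size_t seq_id_ = 0;
};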

const at::Tensor& local_input,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::string reduce_op,
// NOLINTNEXTLINE(performance-unnecessary-value-param)
Contributor

Nit: would it be clearer to do (void)reduce_op in the meta kernel? It's not a big deal though.

Contributor Author

I think the dispatcher would complain about the signature mismatch.

Contributor

No, you keep the signature.

In theory this is the preferred way to tell the compiler/linter that you're intentionally not using a provided value, and IIUC it should silence linter errors, but... your way is fine too:

void foo(int x){
   (void)x;
   return;
}

Contributor Author

Ah I see what you mean now. Thanks for the suggestion!
In this specific case, the linter wants me to change std::string to const std::string& to avoid a copy in argument passing. But const std::string& is unfortunately not supported by torch dispatcher signature definition (if I am not mistaken).

virtual int get_rank() = 0;
virtual int get_world_size() = 0;

virtual std::vector<int> get_rank_to_global_rank() {
Contributor

Is this NYI or NotImplemented? Will we implement it for the CUDA backend? Should this be documented? Why is this different from the existing c10d rank-mapping helpers?

Contributor Author

@kwen2501 commented Apr 30, 2025

Yes, it stands for unimplemented.
Good question whether other backends need it. Perhaps I should move them into NVSHMEMSymmetricMemory for now.
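
A small illustration of the "unimplemented by default" pattern being discussed (hypothetical class name; not the PR's exact code):

#include <stdexcept>
#include <vector>

class SymmetricMemoryLike {
 public:
  virtual ~SymmetricMemoryLike() = default;
  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;

  // Backends that need the rank mapping (e.g. NVSHMEM) override this;
  // everyone else inherits the throwing default.
  virtual std::vector<int> get_rank_to_global_rank() {
    throw std::runtime_error(
        "get_rank_to_global_rank is not supported by this backend");
  }
};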

}
}

IpcChannel::IpcChannel()
Contributor

I didn't review the IpcChannel code yet; maybe someone else can take a close look there.

Contributor Author

It is a move (of existing code), sorry.

m.def(
"memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");

m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
Contributor

Still want to ask about the nvshmem_ prefix. Do we imagine that in the future different backends will export different broadcast ops, or should we just have one broadcast and error out if the underlying backend doesn't support it?

Contributor Author

The latter I guess.

return team;
}

at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) {
Contributor

one minor nit: the CUDA backend puts the ops (e.g., one_shot_all_reduce) in CUDASymmetricMemoryOps.cu. Do we want to follow that convention?

Contributor Author

It is a matter of renaming this file. I think we can do it later (after we settle the matter of whether NVSHMEM should be a SymmetricMemory class by itself or an extension to CUDASymmetricMemory).

Contributor

@fegin left a comment

We can do the API renaming and file moving in later PRs. Please fix the linter error before landing.

@kwen2501
Contributor Author

@fegin Thanks. No lint error AFAIK.

@kwen2501
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label on Apr 30, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

Contributor

@fduwjj left a comment

I think in general this looks good, but we really want to separate refactoring from new code. Landing a large PR increases the risk of your PR getting reverted, which has happened many times this half.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

@kwen2501
Contributor Author

kwen2501 commented May 1, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request May 2, 2025
Add an all-to-all impl based on NVSHMEM's on-stream API `nvshmemx_alltoallmem_on_stream`.

Pull Request resolved: #151498
Approved by: https://github.com/fegin, https://github.com/fduwjj
ghstack dependencies: #151261
pytorchmergebot pushed a commit that referenced this pull request May 2, 2025
Merge in/out splits into one tensor

Multi-block

Use sync instead of barrier

Use nvshmemx_collective_launch

Rotate blocks among peer

write back input splits

Parallel scan works

Use scan for output offsets

Use at most 16 blocks

Pull Request resolved: #151819
Approved by: https://github.com/ngimel, https://github.com/fduwjj
ghstack dependencies: #151261, #151498
@jcao-ai

jcao-ai commented May 5, 2025

@kwen2501 Great job! I wonder whether it is possible to enable both backends in a single runtime (in the future)? It may make sense in some finer-grained scenarios (both inter-node and intra-node).

@github-actions github-actions bot deleted the gh/kwen2501/141/head branch June 15, 2025 02:20