NUMA binding integration with elastic agent and torchrun by raghavhrishi · Pull Request #149334 · pytorch/pytorch · GitHub

Conversation

@raghavhrishi
Contributor

@raghavhrishi raghavhrishi commented Mar 17, 2025

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Mar 17, 2025
@pytorch-bot

pytorch-bot bot commented Mar 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149334

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 066a805 with merge base ee72338:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Mar 17, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.


@kwen2501
Contributor

Thanks @raghavhrishi! Can you please sign the CLA?

        return numactlargs


class CoreComplex(Numa):
Collaborator

@sanchitintel sanchitintel Mar 19, 2025

This option seems to be specific to AMD x86_64 processors, which have the concept of a core complex whose cores share L3 cache.
On Intel x86_64 processors, the L3 cache is typically at the granularity of a socket.
L1 & L2 caches are private to each physical core.

Would it be okay to disable this option on Intel x86_64 machines (I'm guessing users would only select it by mistake there), or to explain the behavior with a warning if it is used on one? @jingxu10, can you please share your opinion?

Thanks!
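
For reference, the set of CPUs sharing a last-level cache can be read from sysfs on Linux; a minimal sketch (illustration only, not the PR's code), assuming index3 corresponds to the L3 cache as it typically does on x86_64:

```
def l3_shared_cpu_list(cpu: int) -> str:
    # CPUs sharing the last-level cache with `cpu`; on AMD this is a core
    # complex (e.g. "0-7"), while on many Intel parts it spans the whole socket.
    path = f"/sys/devices/system/cpu/cpu{cpu}/cache/index3/shared_cpu_list"
    with open(path) as f:
        return f.read().strip()
```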

Contributor Author

@raghavhrishi raghavhrishi Mar 21, 2025

A warning message can be added when the core-complex option is used, and a note added to the help page (in the --numa_binding option description) so that users are aware of this behavior.
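
For illustration, a minimal sketch of such a warning (the detection heuristic and helper name are assumptions, not the PR's actual code):

```
import warnings


def warn_if_core_complex_on_intel(strategy: str) -> None:
    # Hypothetical helper: core-complex binding targets AMD CCXs, whose cores
    # share an L3 cache; on Intel x86_64 the L3 is typically socket-wide.
    if strategy != "core-complex":
        return
    try:
        with open("/proc/cpuinfo") as f:
            is_intel = "GenuineIntel" in f.read()
    except OSError:
        return
    if is_intel:
        warnings.warn(
            "--numa_binding=core-complex assumes AMD-style core complexes; "
            "on Intel x86_64 it may effectively bind at socket granularity."
        )
```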

Contributor Author

This has been updated in the recent commit.

Collaborator

@sanchitintel sanchitintel Apr 3, 2025

Shall we consider the E/P core case?

Thanks for your inputs, @jingxu10!
It looks like some variants of new data-center-grade Xeon processors may have E-cores as well, so we should probably consider them too.

@leslie-fang-intel, please share your inputs. Thanks!

Contributor Author

Thanks for sharing your thoughts – it's a good idea. It could potentially be a follow-up Pull Request once we’ve had the time to consider the design and how best to integrate it.
cc: @arpitsardhana

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Mar 20, 2025
resultCpuList = []
for i in range(resultCpuLen):
    if (cpusSharedCacheVal >> i) & 1 == 1:
        resultCpuList.append(i)
@ashesh2512 ashesh2512 Apr 9, 2025

@raghavhrishi First, great job pushing on this feature!

Referencing your example in #148689, for the exclusive binding option,

If Rank 0 and Rank 1 are both affined to NUMA Node 0, the cores would be split as follows:
Rank 0:
numactl --physcpubind=0-3 --membind=0
Rank 1:
numactl --physcpubind=4-7 --membind=0

This assumes that a contiguous indexing of CPUs would result in the most optimal binding. Could you please confirm whether this is indeed an assumption in this PR? In my experience, there are many node architectures where linear indexing of CPUs is not the norm; see Frontier, for example: https://docs.olcf.ornl.gov/systems/frontier_user_guide.html#frontier-compute-nodes

If linear indexing of CPUs is indeed assumed, would it be possible to have a user option to specify --physcpubind or pass in the CPU/GPU topology?

Contributor Author

@raghavhrishi raghavhrishi Apr 11, 2025

@ashesh2512 Thanks for your comment!

Non-linear core indexing might have edge cases in scenarios where there are only two NUMA Nodes available for binding, and multiple ranks (e.g., 4) are affined to the same NUMA Node. In such cases, linear indexing might be necessary to address the issue effectively.

The exclusive binding strategy uses the system's topology information to determine the NUMA node associated with each rank. Ranks affined to the same NUMA node are then assigned distinct sets of cores via physcpubind, so their core sets never overlap, and cross-NUMA binding is avoided.

As a potential enhancement, we could consider adding an option for users to specify the cores they wish to use in a follow-up pull request after reviewing the design.

cc: @arpitsardhana
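
For illustration, a minimal sketch of the splitting behavior described above (names are made up; not the PR's actual code):

```
def exclusive_physcpubind(node_cpus: list[int], ranks_on_node: int, slot: int) -> str:
    # Give each of the `ranks_on_node` ranks sharing a NUMA node a disjoint,
    # contiguous slice of that node's CPUs (matching the numactl example above).
    per_rank = len(node_cpus) // ranks_on_node
    chunk = node_cpus[slot * per_rank : (slot + 1) * per_rank]
    return ",".join(str(c) for c in chunk)


# Two ranks on NUMA node 0 with CPUs 0-7:
#   exclusive_physcpubind(list(range(8)), 2, 0) -> "0,1,2,3"
#   exclusive_physcpubind(list(range(8)), 2, 1) -> "4,5,6,7"
```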

@ashesh2512 ashesh2512 Apr 11, 2025

@raghavhrishi Thanks, I think that an option for users to specify the cores they wish to use would be ideal in a follow up PR. I could help with that.

For context, on one of the architectures I work with, a single compute node (8 GPUs per node) has the following CPU/GPU affinity. Ideally, the user would be able to bind a process to one or multiple cores and set the GPU index in PyTorch accordingly.

NUMA 0:
hardware threads 000-007, 064-071 | GPU 4
hardware threads 008-015, 072-079 | GPU 5

NUMA 1:
hardware threads 016-023, 080-087 | GPU 2
hardware threads 024-031, 088-095 | GPU 3

NUMA 2:
hardware threads 032-039, 096-103 | GPU 6
hardware threads 040-047, 104-111 | GPU 7

NUMA 3:
hardware threads 048-055, 112-119 | GPU 0
hardware threads 056-063, 120-127 | GPU 1
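
To make the request concrete, a hypothetical sketch of a user-supplied binding table for the node above and how it could drive a numactl prefix (the table structure and helper name are illustrative only, not an existing torchrun option):

```
# (CPU list, NUMA node) per GPU, taken from the affinity listing above.
GPU_TO_CPUS_AND_NODE = {
    0: ("48-55,112-119", 3), 1: ("56-63,120-127", 3),
    2: ("16-23,80-87", 1),   3: ("24-31,88-95", 1),
    4: ("0-7,64-71", 0),     5: ("8-15,72-79", 0),
    6: ("32-39,96-103", 2),  7: ("40-47,104-111", 2),
}


def numactl_prefix_for_gpu(gpu: int) -> list[str]:
    cpus, node = GPU_TO_CPUS_AND_NODE[gpu]
    return ["numactl", f"--physcpubind={cpus}", f"--membind={node}"]
```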

@raghavhrishi raghavhrishi requested a review from jeffdaily as a code owner April 26, 2025 04:50
@kwen2501 kwen2501 requested a review from kiukchung April 28, 2025 22:38
requirements.txt Outdated
psutil
pyyaml
requests
pynvml
Contributor

@atalman @malfet This seems to introduce a dependency. wdyt?

Contributor

Gentle ping @malfet @atalman

Contributor

Nope, we deliberately decided not to depend on pynvml, as one can very easily rewrite everything one needs with ctypes.

Moreover, it's a big no-go for something like ROCm or XPU.
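
For context, a minimal sketch of the ctypes approach being described here (illustration only; error handling kept to the bare minimum):

```
import ctypes


def gpu_count_via_ctypes() -> int:
    # Query the device count directly from NVML without pynvml.
    # Returns 0 if the NVML library is not present on the system.
    try:
        nvml = ctypes.CDLL("libnvidia-ml.so.1")
    except OSError:
        return 0
    if nvml.nvmlInit_v2() != 0:  # 0 == NVML_SUCCESS
        return 0
    count = ctypes.c_uint()
    rc = nvml.nvmlDeviceGetCount_v2(ctypes.byref(count))
    nvml.nvmlShutdown()
    return count.value if rc == 0 else 0
```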

Comment on lines 49 to 74
def get_gpu_count(self):
    # Initialize NVML
    pynvml.nvmlInit()
    # Get the number of GPU devices
    device_count = pynvml.nvmlDeviceGetCount()
    # Shutdown NVML
    pynvml.nvmlShutdown()
Contributor

Is there a device-generic way?
There should be some methods in torch.accelerator package now.
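
For reference, a device-generic sketch along these lines (an assumption: it requires a PyTorch build recent enough to ship torch.accelerator, with a CUDA-only fallback otherwise):

```
import torch


def accelerator_count() -> int:
    # Device-generic sketch; falls back to CUDA counting on older versions.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.device_count()
    return torch.cuda.device_count() if torch.cuda.is_available() else 0
```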

Contributor

What do you think of this?

Comment on lines 59 to 70
# returns an array indexed by GPU id, mapping each GPU to its NUMA node id
def get_numa_nodes(self):
Contributor

nit: would appreciate an example of the return.

Contributor Author

Suppose we have 4 GPUs, and they are connected to the following NUMA nodes:

GPU 0 → NUMA Node 0

GPU 1 → NUMA Node 0

GPU 2 → NUMA Node 1

GPU 3 → NUMA Node 1

Then the function would return:

[0, 0, 1, 1]

Comment on lines 75 to 103
for busID in pciBusIDs:
    pciFields = busID.split(":")
    pciDir = f"{pciFields[0][-4:]}:{pciFields[1]}:{pciFields[2]}"
    numaFile = NUMA_CMD.format(value=pciDir.lower())
    try:
        with open(numaFile) as numa_node_text:
            node = int(numa_node_text.read())
            numaNodes.append(node)
    except FileNotFoundError:
        print(f"The file {numaFile} does not exist.")
Contributor

nit: can you comment on this block?
Also, would it be worth it for NVML to add an API that returns the needed value?

Contributor Author

For each GPU's PCI bus ID, the function constructs the sysfs path to the device's NUMA node file and reads the NUMA node associated with it. It returns a list of NUMA nodes, one per GPU.
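
A self-contained sketch of that sysfs lookup (illustration only; the PR's helper differs in naming and batching):

```
def numa_node_for_pci_device(pci_bus_id: str) -> int:
    # NVML-style bus ID "00000000:3B:00.0" -> sysfs dir "0000:3b:00.0";
    # the kernel reports -1 when the system has no NUMA information.
    domain, bus, devfn = pci_bus_id.split(":", 2)
    path = f"/sys/bus/pci/devices/{domain[-4:]}:{bus}:{devfn}".lower() + "/numa_node"
    try:
        with open(path) as f:
            return int(f.read().strip())
    except FileNotFoundError:
        return -1
```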

Comment on lines 101 to 112
# returns, for the given core, a bitmap of its sibling cores
def get_thread_siblings(self, cpu):
Contributor

What is the function for?

Contributor Author

@raghavhrishi raghavhrishi May 18, 2025

get_thread_siblings identifies which other CPUs (cores) are on the same NUMA node as the current CPU.
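
For reference, a minimal sketch of reading the kernel's thread-sibling mask from sysfs (on Linux this mask covers the hyperthread siblings that share a physical core):

```
def thread_siblings_mask(cpu: int) -> int:
    # The sysfs file holds a comma-separated hex CPU mask, e.g. "00000000,00000101".
    path = f"/sys/devices/system/cpu/cpu{cpu}/topology/thread_siblings"
    with open(path) as f:
        return int(f.read().strip().replace(",", ""), 16)
```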

@stas00
Contributor

stas00 commented Apr 29, 2025

Super!!! Thank you for implementing this, @raghavhrishi

This issue could be closed as well when merged: #115305

@pdesupinski
Contributor

Excited for this @raghavhrishi. Any update on the timeline?

@raghavhrishi
Contributor Author

@kwen2501: I've addressed the comments in the PR. Please let me know if there's anything else needed to proceed with the merge.
cc: @arpitsardhana

@raghavhrishi raghavhrishi requested a review from kwen2501 June 14, 2025 05:28
@kwen2501
Contributor

kwen2501 commented Jun 16, 2025

Thanks for the improvements.

In general, I am wondering if there is a device-agnostic way to do this. But I understand the torch.accelerator APIs are not there yet (e.g., calculating the distance between a GPU and a CPU). cc @albanD. So perhaps it may be okay as is in the PR for now.

If we'd like to avoid a direct dependency on pynvml (which this PR currently adds to requirements.txt), can we put a check in torchrun to see whether pynvml is available? If it is, use the code here; if not, fall back to launching without bindings (doesn't hurt?).
cc @malfet @atalman

I will defer to @kiukchung and @d4l3k for final decision.
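
A minimal sketch of the soft-dependency check being suggested (illustration only; the wrapper name is made up, and it delegates to the PR's helper shown in the snippet below):

```
import logging

logger = logging.getLogger(__name__)

try:
    import pynvml  # optional: only needed for NUMA binding
    _HAS_PYNVML = True
except ImportError:
    _HAS_PYNVML = False


def maybe_numa_binding_cmd(strategy):
    # Hypothetical wrapper: compute numactl arguments only if pynvml is
    # importable; otherwise fall back to launching without bindings.
    if strategy is None:
        return None
    if not _HAS_PYNVML:
        logger.warning("pynvml not available; launching without NUMA bindings")
        return None
    return update_with_numa_binding_pytorch(strategy)  # helper from this PR
```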

numa_cmd = None
py_executable = os.getenv("PYTHON_EXEC", sys.executable)
if args.numa_binding:
    numa_cmd = update_with_numa_binding_pytorch(args.numa_binding)
Member

Can we implement this at the elastic agent level? Putting this logic here means only CLI users get NUMA control, not users of the programmatic API.

@d4l3k
Member

d4l3k commented Jun 23, 2025

@raghavhrishi do you have bandwidth to update this PR? There's still some refactoring required to get this into a good state.

The two main things are:

  1. make nvml a soft dependency
  2. refactor the integration in torchelastic to operate at the agent level (where we launch subprocesses/multiprocessing) rather than in the arg parsing/wrapper script

Primarily asking since we'd like to land this support and have someone who might be interested in pushing this over the line

We could also land this in pieces -- i.e. land the helper utilities and then follow up with a cleaner torchelastic integration

albanD
albanD previously requested changes Jun 23, 2025
"Can be used to override custom logging behavior.",
)

parser.add_argument(
Collaborator

cc @EikanWang didn't someone from your end already send a PR to do NUMA binding in torchrun? That vaguely rings a bell to me.

Contributor

@jingxu10 jingxu10 Jun 23, 2025

Hi @albanD, do you mean the script https://github.com/pytorch/pytorch/blob/main/torch/backends/xeon/run_cpu.py, #133835, or a separate recent PR?
I discussed this with Nikita before; the code changes in #133835 involve too many things, so I'll split it into smaller PRs later.

Collaborator

Oh yes, https://github.com/pytorch/pytorch/blob/main/torch/backends/xeon/run_cpu.py is what I had in mind. @raghavhrishi is there any link between the two?

Contributor Author

@albanD I think the file you mentioned is different from this PR's implementation.

setup.py Outdated
"networkx",
"jinja2",
"fsspec",
"pynvml>=11.4.1",
Collaborator

Adding such a hard dependency is definitely not ok without much deeper considerations.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 24, 2025
@pdesupinski pdesupinski force-pushed the raghavhrishi/numa-binding-torchrun branch from dc8b4f6 to b8ba819 Compare July 24, 2025 19:42
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jul 24, 2025
@facebook-github-bot
Contributor

@pdesupinski has imported this pull request. If you are a Meta employee, you can view this in D78319234.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 24, 2025
@pdesupinski pdesupinski force-pushed the raghavhrishi/numa-binding-torchrun branch from b8ba819 to 066a805 Compare July 25, 2025 03:10
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jul 25, 2025
@facebook-github-bot
Contributor

@pdesupinski has imported this pull request. If you are a Meta employee, you can view this in D78319234.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 25, 2025
@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
Implements #148689

Pull Request resolved: #149334
Approved by: https://github.com/d4l3k

Co-authored-by: Paul de Supinski <pdesupinski@gmail.com>
pytorchmergebot pushed a commit that referenced this pull request Aug 12, 2025
# Context
This is an extension of #149334.

# This PR
Add support for NUMA bindings with Callable entrypoints, such as `do_train` instead of `/usr/local/bin/python`.

Most notably, we utilize a hack in order to force `Process.start()` to use custom NUMA bindings for each subprocess. Please search for `HACK:` in the code to see a description of the implementation we chose, and #160006 for discussion of alternatives and why this is necessary.

Other changes:
* Remove unnecessary `--preferred` option from all binding strategies. By default, Linux already allocates memory to the NUMA node local to the CPU which triggered the allocation. (See [MPOL_LOCAL](https://man7.org/linux/man-pages/man2/set_mempolicy.2.html).)
* Refactor so that the main API is `maybe_wrap_command_with_numa_bindings`, which computes bindings for a single rank at a time, rather than `maybe_wrap_with_numa_bindings` which computed bindings for all ranks at once. This allowed for more code sharing between `Callable` and `str` entrypoints.

# Test Plan
## Automated
`$ pytest test/test_numa_binding.py`

## Manual
Using [this benchmark](https://gist.github.com/pdesupinski/bbe01ade455d86e989794f2c612e2d91), ran

```
$ PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -m torch.distributed.run --standalone --nproc-per-node=8 --numa-binding=node --run-path mlp_train.py 2>&1 | tee node_callable.txt && PYTHONUNBUFFERED=1 LOGLEVEL=INFO perf stat -e ls_dmnd_fills_from_sys.dram_io_far,ls_dmnd_fills_from_sys.dram_io_near -- python -u -m torch.distributed.run --standalone --nproc-per-node=8 --run-path mlp_train.py 2>&1 | tee none_callable.txt
```

and observed
* 6.6% remote memory accesses with 'node' bindings
* 11.6% remote without bindings

I also ran similar with `str` entrypoints as before just to be sure it's still working.

NOTE: [--run-path triggers the code to be run inside a `Callable`.](https://github.com/pytorch/pytorch/blob/017259f9c65b6fad55fb9597d7077e2543eaae46/torch/distributed/run.py#L870)

Pull Request resolved: #160163
Approved by: https://github.com/d4l3k
chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
pytorchmergebot pushed a commit that referenced this pull request Aug 19, 2025
…60848)

# Context
Another fix to enable broad rollout of #149334.

The implementation assumes that the trainer process with local rank `n` only uses device `cuda:n`. However, there are sometimes jobs with more than one GPU per process, in which case our assumption could be incorrect and actually lead to worse memory locality.

# This PR
As titled.

Pull Request resolved: #160848
Approved by: https://github.com/kiukchung
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
open source
release notes: distributed (c10d) (release notes category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
