Add rpc.api._barrier() #53423
Conversation
💊 CI failures summary and remediations. As of commit 1143616 (more details on the Dr. CI page):

🕵️ 2 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages:

| Job | Step | Action |
|---|---|---|
|  | Run tests | 🔁 rerun |
This would be quite useful to have, especially for RPC debuggability. Are there still plans to iterate on this? @H-Huang
Yes, I'll finish the changes needed for this.
closes #40166

This change exposes a new API, rpc.barrier(), which blocks the main processes of all workers running RPC until the whole group completes this function. Optionally, rpc.barrier can take in a set of worker_names and only synchronize across those worker names.

Example:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"
world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        rpc.barrier(set(odd_num_workers))
    else:
        print(f"start barrier {i}")
        rpc.barrier(set(even_num_workers))
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Currently in draft mode, as the partial barrier gets stuck and needs follow-up discussion.

[ghstack-poisoned]
closes #40166

This change exposes a new API, rpc.barrier() which blocks the main processes of all workers running RPC until the whole group completes this function. Optionally rpc.barrier can take in a set of worker_names and only synchronize across those worker names.

Example:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"
world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        rpc.barrier(odd_num_workers)
    else:
        print(f"start barrier {i}")
        rpc.barrier(even_num_workers)
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Output:

```
0
1
2
3
start barrier 0
start barrier 3
start barrier 2
start barrier 1
shutdown2
shutdown3
shutdown1
shutdown0
```

Differential Revision: [D27737145](https://our.internmc.facebook.com/intern/diff/D27737145)

[ghstack-poisoned]
Looping in @kiukchung to make sure this does not conflict with the control barrier idea. IIUC, rpc.barrier is unrelated to the control barrier @kiukchung proposed and both can be useful. Just want to double check. :)
torch/distributed/rpc/__init__.py (Outdated)

```
worker_names (List[str], optional): The set of workers to synchronize. If ``None``, the
set will be all workers. Default is ``None``.
```
We might need one more indent on the second line. Shall we build the docs to verify this renders correctly?
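For illustration, here is a sketch of the indentation being suggested (the signature and summary line are paraphrased from the snippet, not the exact source); Sphinx needs the continuation line of a parameter description to be indented past the parameter name:

```python
def _barrier(worker_names=None):
    """
    Block until all local and remote RPC processes specified under worker_names
    reach this method to wait for all outstanding work to complete.

    Args:
        worker_names (List[str], optional): The set of workers to synchronize. If ``None``,
            the set will be all workers. Default is ``None``.
    """
```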
torch/distributed/rpc/__init__.py (Outdated)

```
worker_names (List[str], optional): The set of workers to synchronize. If ``None``, the
set will be all workers. Default is ``None``.
```
> If ``None``, the set will be all workers. Default is ``None``.
This will be a little tricky after we add dynamic membership, as that would require all processes to always have a consistent view of the current gang membership, which can be costly and hard. To avoid that potential problem, shall we require worker_names to be non-None for now?
Two questions:
- What if there are concurrent barriers, e.g., two concurrent barriers between [worker0, worker1] and [worker2, worker3]? Will this still work?
- If one of the workers specifies a wrong list, what will happen?
Great points. I will change worker_names to be a required argument. For the questions:
- Yes, this works.
- If the list has workers that don't exist, then it would hang and eventually time out.
> If the list has workers that don't exist, then it would hang and eventually time out.
I was also concerned about the following case:

worker0: worker_names=["worker0", "worker1", "worker2"]
worker1: worker_names=["worker0", "worker1", "worker2"]
worker2: worker_names=["worker1", "worker2"]

Do we know what will happen in this case? I guess it will also time out, as worker0 serves as the leader for the first two and worker1 serves as the leader for worker2.
Yes, this situation would also time out, which I think should be the expected behavior of the API. With the current implementation, if one worker calls barrier(<list_of_workers>), it expects all the workers in <list_of_workers> to also call barrier(<list_of_workers>); otherwise it will hang.
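To make that concrete, here is a minimal standalone sketch modeled on the example in the PR description (the port, world size, and grouping are illustrative, and the final internal name rpc.api._barrier is used); per the discussion above, the mismatched groups are expected to hang until the barrier times out:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5679"
world_size = 3

def worker(i):
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i in (0, 1):
        # worker0 and worker1 wait for all three workers...
        rpc.api._barrier(["worker0", "worker1", "worker2"])
    else:
        # ...but worker2 only waits for two, so the groups never agree and the
        # barrier hangs until it eventually times out.
        rpc.api._barrier(["worker1", "worker2"])
    rpc.shutdown()

if __name__ == "__main__":
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```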
torch/distributed/rpc/__init__.py (Outdated)

```
This will block until all local and remote RPC processes specified under worker_names
reach this method to wait for all outstanding work to complete.

Args:
```
Shall we add an example section to this API?
Changes LGTM in general. The main concern I have is that, once we expose this as a public API, we will need better coverage for corner cases and error handling. Could you please add some tests to make sure that when one of the workers specifies a wrong list of worker names, our error message is reasonable? Concurrent barriers might be harder to support; if we don't plan to support them in the first version, let's make that clear in the doc.
@rohan-varma do you know who might be interested in trying this API?
Thanks for looping me in @mrshenli! Shouldn't conflict, but semantically…
Sorry, missed this comment. I'm not aware of any immediate users, but I think it would be pretty invaluable for folks debugging their applications when looking at possible correctness issues/race conditions.
torch/distributed/rpc/api.py (Outdated)

```python
worker_names = _ALL_WORKER_NAMES
assert (
    worker_name in worker_names
), "{worker_name} is not expected by leader.".format(worker_name=worker_name)
```
nit: f-string?
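For reference, the f-string form the nit is pointing at would look something like this (a sketch with dummy values, not the committed line):

```python
worker_names = {"worker0", "worker1"}  # dummy values for illustration
worker_name = "worker0"
assert worker_name in worker_names, f"{worker_name} is not expected by leader."
```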
```python
@_require_initialized
def _barrier(worker_names):
```
Should we make it a public API?
I had discussed this with Shen a few weeks ago, as the PR originally exposed a public API. But since we are unsure who will use it and how the API may change when introducing Elastic RPC, we decided to keep it internal for now.
torch/distributed/rpc/api.py (Outdated)

```python
    _all_gather(None, set(worker_names), timeout=DEFAULT_SHUTDOWN_TIMEOUT)
except RuntimeError as ex:
    logger.error(
        f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
```
If this is being used for a non-shutdown case, this error might be a bit confusing (it says "Shutdown Proceed")?
Good catch, I think I had copy-pasted this from shutdown() 🙂, will change it.
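One possible neutral rewording, shown as a hypothetical standalone helper (the actual wording chosen in the PR is not visible in this thread):

```python
import logging

logger = logging.getLogger(__name__)

def _report_barrier_failure(ex):
    # Hypothetical helper name; the point is only that the message no longer
    # mentions "Shutdown Proceed" when the barrier is used outside shutdown.
    logger.error(f"Failed to complete the RPC barrier in time, got error {ex}")
```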
```python
info = rpc.get_worker_info()
all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
names = [worker.name for worker in all_worker_info]
rpc.api._barrier(names)
```
Ideally a test should assert something to verify the functionality, though admittedly that is a bit hard for a barrier. Could we do something like: before the barrier, assert some global count is 0; then send an RPC to each node to increment the global count; then call barrier; then assert the count == world_size? This will ensure the barrier properly waits for all outstanding work. Open to better ideas though.
That's a good point. I'll implement something along the lines of the count idea. I'll also add something to test the negative case where the barrier actually times out.
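A rough standalone sketch of that count-based check (the port, helper names, and use of rpc_sync are illustrative assumptions, not the test that was actually added):

```python
import os
import threading

import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5680"
world_size = 4

_count = 0
_count_lock = threading.Lock()

def _increment_count():
    global _count
    with _count_lock:
        _count += 1

def worker(i):
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    names = [f"worker{r}" for r in range(world_size)]
    # Synchronously bump the counter on every worker (including ourselves).
    for name in names:
        rpc.rpc_sync(name, _increment_count)
    # Everyone waits here; once the barrier completes, each worker must have
    # executed one increment from every worker.
    rpc.api._barrier(names)
    assert _count == world_size
    rpc.shutdown()

if __name__ == "__main__":
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```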
```diff
-sequence_id = _all_gather_sequence_id
-_all_gather_sequence_id += 1
+concat_names = "".join(sorted(worker_names))
+sequence_num = _all_gather_sequence_id.get(concat_names, 0)
```
I see that we properly clean up this key for _all_gather_sequence_id below when exiting the barrier, so that future calls with the same set of worker_names will start at 0 as expected.

However, is it possible to still run into races here? If there are concurrent calls to barrier() (i.e. multiple threads) with the same set of workers:

1. Thread 1 creates the key in _all_gather_sequence_id
2. Thread 1 releases the lock
3. Thread 2 calls with the same set of arguments, so it sees an already incremented sequence_num, which could cause the tracking of states to be incorrect.
The cleanup done at the end of barrier is for _all_gather_sequence_id_to_states. This var keeps track of the barrier state for a single sequence id. This cleanup is needed because there is an existing (small) memory leak where each call to _all_gather will leave behind an AllGatherStates object in the global dictionary.

The _all_gather_sequence_id is an ID representing a particular barrier operation. It is strictly increasing, so concurrent calls will still work. I'll rewrite your example above to demonstrate what is happening (a sketch of this bookkeeping follows after the list):

- worker1, Thread 1 performs barrier on ["worker1", "worker2"]: retrieves _all_gather_sequence_id (value 0), increments _all_gather_sequence_id
- worker1, Thread 1 releases the lock
- worker1, Thread 2 performs barrier on ["worker1", "worker2"]: retrieves _all_gather_sequence_id (value 1), increments _all_gather_sequence_id
- Thread 1 and Thread 2 block waiting on the barrier from worker2
- When worker2 calls barrier(["worker1", "worker2"]) the first time, it will allow worker1, Thread 1 to make progress. The second call to barrier will allow worker1, Thread 2 to make progress.
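A hedged sketch of that keyed bookkeeping (names such as _next_barrier_id and _lock are illustrative, not the actual helpers in torch/distributed/rpc/api.py):

```python
import threading

_lock = threading.Lock()
# Maps a group key (the sorted worker names joined together) to the next
# sequence number to hand out for that group.
_all_gather_sequence_id = {}

def _next_barrier_id(worker_names):
    concat_names = "".join(sorted(worker_names))
    with _lock:
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
    # Every participant sorts the same names and counts the same way, so this
    # string identifies the same barrier instance on every worker.
    return concat_names + str(sequence_num)

# Two concurrent barriers on the same group get distinct IDs:
assert _next_barrier_id(["worker2", "worker1"]) == "worker1worker20"
assert _next_barrier_id(["worker1", "worker2"]) == "worker1worker21"
```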
Sounds good. Could we add a unittest to verify this multithreaded behavior?
closes #40166

This change exposes a new internal API, rpc.api._barrier() which blocks the processes of workers running RPC until the specified group of workers completes this function.

Example:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"
world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier on {i} for {odd_num_workers}")
        rpc.api._barrier(odd_num_workers)
    else:
        print(f"start barrier on {i} for {even_num_workers}")
        rpc.api._barrier(even_num_workers)
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Output:

```
0
1
2
3
start barrier on 0 for ['worker0', 'worker2']
start barrier on 1 for ['worker1', 'worker3']
start barrier on 2 for ['worker0', 'worker2']
start barrier on 3 for ['worker1', 'worker3']
shutdown2
shutdown3
shutdown1
shutdown0
```

Differential Revision: [D27737145](https://our.internmc.facebook.com/intern/diff/D27737145)

[ghstack-poisoned]
Looks great! Thanks for adding the multithreading tests as we discussed.
```
reach this method to wait for all outstanding work to complete.

Args:
    worker_names (List[str]): The set of workers to synchronize.
```
This can be added in a follow-up diff, but should we also accept WorkerInfo or int worker ids here? The rest of the RPC APIs seem to accept either, so it might be good to add that.
Sounds good, will do.
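A hypothetical sketch of what that follow-up could look like, normalizing mixed identifiers to names before running the barrier (the helper name is made up; get_worker_infos() is the same agent call used in the test snippet above):

```python
import torch.distributed.rpc as rpc
from torch.distributed.rpc import WorkerInfo

def _to_worker_names(workers):
    # Hypothetical helper: accept worker names (str), WorkerInfo objects,
    # or integer worker ids, and resolve everything to names.
    id_to_name = {
        info.id: info.name
        for info in rpc._get_current_rpc_agent().get_worker_infos()
    }
    names = []
    for w in workers:
        if isinstance(w, WorkerInfo):
            names.append(w.name)
        elif isinstance(w, int):
            names.append(id_to_name[w])
        else:
            names.append(w)  # assume it is already a worker name
    return names
```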
@H-Huang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Pull Request resolved: pytorch#53423

closes pytorch#40166

This change exposes a new API, rpc.barrier() which blocks the main processes of all workers running RPC until the whole group completes this function. Optionally rpc.barrier can take in a set of worker_names and only synchronize across those worker names.

Example:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"
world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        rpc.barrier(set(odd_num_workers))
    else:
        print(f"start barrier {i}")
        rpc.barrier(set(even_num_workers))
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D27737145

Pulled By: H-Huang

fbshipit-source-id: 369196bc62446f506d1fb6a3fa5bebcb0b09da9f
Stack from ghstack:
closes #40166
This change exposes a new internal API, rpc.api._barrier(), which blocks the processes of workers running RPC until the specified group of workers completes this function.
Example and output: see the final description update above.
Differential Revision: D27737145