Add rpc.api._barrier() by H-Huang · Pull Request #53423 · pytorch/pytorch · GitHub

Conversation


@H-Huang H-Huang commented Mar 5, 2021

Stack from ghstack:

closes #40166

This change exposes a new internal API, rpc.api._barrier(), which blocks the processes of workers running RPC until the specified group of workers completes this function.

Example:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"

world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier on {i} for {odd_num_workers}")
        rpc.api._barrier(odd_num_workers)
    else:
        print(f"start barrier on {i} for {even_num_workers}")
        rpc.api._barrier(even_num_workers)
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Output:

```
0
1
2
3
start barrier on 0 for ['worker0', 'worker2']
start barrier on 1 for ['worker1', 'worker3']
start barrier on 2 for ['worker0', 'worker2']
start barrier on 3 for ['worker1', 'worker3']
shutdown2
shutdown3
shutdown1
shutdown0
```

Differential Revision: D27737145

[ghstack-poisoned]
H-Huang added a commit that referenced this pull request Mar 5, 2021
ghstack-source-id: cca771d
Pull Request resolved: #53423
@facebook-github-bot added the cla signed and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels Mar 5, 2021

facebook-github-bot commented Mar 5, 2021

💊 CI failures summary and remediations

As of commit 1143616 (more details on the Dr. CI page):


  • 3/3 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test1 (1/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jun 02 19:11:29 ../../../../opt/conda/lib/python3.6/site-packages/caffe2/python/onnx/tests/c2_ref_test.py::TestCaffe2Basic::test_cast FAILED [ 22%]
Jun 02 19:11:29 
Jun 02 19:11:29 =================================== FAILURES ===================================
Jun 02 19:11:29 __________________________ TestCaffe2Basic.test_cast ___________________________
Jun 02 19:11:29 
Jun 02 19:11:29 self = <caffe2.python.onnx.tests.c2_ref_test.TestCaffe2Basic testMethod=test_cast>
Jun 02 19:11:29 
Jun 02 19:11:29     def test_cast(self):
Jun 02 19:11:29         X = np.random.randn(1, 2, 3).astype(np.float32)
Jun 02 19:11:29 
Jun 02 19:11:29         for to_type in ['INT8', caffe2_pb2.TensorProto.INT8,

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test2 (2/2)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jun 02 19:09:57 ../../../../opt/conda/lib/python3.6/site-packages/caffe2/python/onnx/tests/c2_ref_test.py::TestCaffe2Basic::test_cast FAILED [ 22%]

1 failure not recognized by patterns:

Job: CircleCI pytorch_linux_xenial_py3_clang5_asan_test2, Step: Run tests


@H-Huang H-Huang marked this pull request as draft March 5, 2021 23:22
@H-Huang changed the title from "add rpc.barrier" to "Add new rpc.barrier API" Mar 5, 2021
@H-Huang changed the title from "Add new rpc.barrier API" to "[WIP] Add new rpc.barrier API" Mar 6, 2021
@rohan-varma (Contributor) left a comment:

This would be quite useful to have, especially for RPC debuggability. Are there still plans to iterate on this? @H-Huang


H-Huang commented Apr 7, 2021

Yes, I'll finish the changes needed for this.

closes #40166

This change exposes a new API, rpc.barrier(), which blocks the main processes of all workers running RPC until the whole group completes this function. Optionally, rpc.barrier can take in a set of worker_names and only synchronize across those workers.

Example:
```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"

world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        rpc.barrier(set(odd_num_workers))
    else:
        print(f"start barrier {i}")
        rpc.barrier(set(even_num_workers))
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```


Currently in draft mode, as the partial barrier gets stuck and I need to follow up on the discussion.

[ghstack-poisoned]
H-Huang added a commit that referenced this pull request Apr 13, 2021
ghstack-source-id: 7684af9
Pull Request resolved: #53423
@H-Huang changed the title from "[WIP] Add new rpc.barrier API" to "Add new rpc.barrier API" Apr 13, 2021
H-Huang added a commit that referenced this pull request Apr 13, 2021
ghstack-source-id: 9729189
Pull Request resolved: #53423
@H-Huang H-Huang marked this pull request as ready for review April 13, 2021 15:36
@H-Huang H-Huang requested a review from wayi1 as a code owner April 13, 2021 15:36
closes #40166

This change exposes a new API, rpc.barrier(), which blocks the main processes of all workers running RPC until the whole group completes this function. Optionally, rpc.barrier can take in a set of worker_names and only synchronize across those workers.

Example:
```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"

world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        rpc.barrier(odd_num_workers)
    else:
        print(f"start barrier {i}")
        rpc.barrier(even_num_workers)
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Output:

```
0
1
2
3
start barrier 0
start barrier 3
start barrier 2
start barrier 1
shutdown2
shutdown3
shutdown1
shutdown0
```

Differential Revision: [D27737145](https://our.internmc.facebook.com/intern/diff/D27737145)

[ghstack-poisoned]
H-Huang added a commit that referenced this pull request Apr 13, 2021
ghstack-source-id: 4df7fd0
Pull Request resolved: #53423
@mrshenli (Contributor) left a comment:

Looping in @kiukchung to make sure this does not conflict with the control barrier idea. IIUC, rpc.barrier is independent of the control barrier @kiukchung proposed, and both can be useful. Just want to double check. :)

Comment on lines 210 to 211
worker_names (List[str], optional): The set of workers to synchronize. If ``None``, the
set will be all workers. Default is ``None``.
Contributor:

We might need one more indent on the second line. Shall we build the docs to verify this renders correctly?

Comment on lines 210 to 211
worker_names (List[str], optional): The set of workers to synchronize. If ``None``, the
set will be all workers. Default is ``None``.
Contributor:

If None, the set will be all workers. Default is None.

This will be a little tricky after we add dynamic membership, as that would require all processes to always have a consistent view of the current gang membership, which can be costly and hard. To avoid that potential problem, shall we require worker_names to be non-None for now?

Contributor:

Two questions:

  1. What if there are concurrent barriers, e.g., two concurrent barriers between [worker0, worker1] and [worker2, worker3]? Will this still work?
  2. If one of the workers specifies a wrong list, what is going to happen?

Member Author:

Great points. I will change worker_names to be a required argument. For the questions:

  1. Yes, this works.
  2. If the list has workers that don't exist, then it would hang and eventually time out.

Contributor:

If the list has workers that don't exist, then it would hang and eventually time out.

I was also concerned about the following case:

worker0: worker_names=["worker0", "worker1", "worker2"]
worker1: worker_names=["worker0", "worker1", "worker2"]
worker2: worker_names=["worker1", "worker2"]

Do we know what is going to happen in this case? I guess it will also time out, as worker0 serves as the leader for the first two and worker1 serves as the leader for worker2.

Member Author:

Yes, this situation would also time out, which I think should be the expected behavior of the API. With the current implementation, if one worker calls barrier(<list_of_workers>), it expects all the workers in <list_of_workers> to also call barrier(<list_of_workers>); otherwise it will hang (see the sketch below).
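
For illustration, a minimal runnable sketch of that contract (the two-worker setup, port, and mismatched lists are made up for this example; only rpc.api._barrier itself comes from this PR). worker0 names both workers, while worker1 names only itself, so worker0's barrier is expected to hang and eventually time out:

```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5679"

def worker(i):
    rpc.init_rpc(f"worker{i}", rank=i, world_size=2)
    if i == 0:
        # worker0 expects worker1 to join this barrier, but worker1 never
        # calls _barrier with a matching list, so this call hangs and
        # eventually fails with a timeout error.
        rpc.api._barrier(["worker0", "worker1"])
    else:
        # Mismatched list: worker1 only waits on itself, so it returns
        # immediately.
        rpc.api._barrier(["worker1"])
    rpc.shutdown()

if __name__ == "__main__":
    with mp.Pool(processes=2) as pool:
        pool.map(worker, range(2))
```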

    This will block until all local and remote RPC processes specified under worker_names
    reach this method to wait for all outstanding work to complete.

    Args:
Contributor:

Shall we add an example section to this API?
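
For reference, such an example section might look something like the following sketch (illustrative only, not the docstring that actually landed):

```python
@_require_initialized
def _barrier(worker_names):
    """
    Synchronizes local and remote RPC processes.

    This will block until all local and remote RPC processes specified under
    worker_names reach this method to wait for all outstanding work to
    complete.

    Args:
        worker_names (List[str]): The set of workers to synchronize.

    Example::
        >>> # Run on every worker in the group, with an identical list:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rpc.api._barrier(["worker0", "worker1"])
        >>> rpc.shutdown()
    """
```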

@mrshenli (Contributor) left a comment:

Changes LGTM in general. The main concern I have is that, once we expose this as a public API, we will need better coverage for corner cases and error handling. Could you please add some tests to make sure that when one of the workers specifies a wrong list of worker names, our error message is reasonable? Concurrent barriers might be harder to support; if we don't plan to support them in the first version, let's make that clear in the doc.

@rohan-varma do you know who might be interested in trying this API?

H-Huang added a commit that referenced this pull request Apr 13, 2021
ghstack-source-id: 320bd23
Pull Request resolved: #53423
@H-Huang changed the title from "Add new rpc.barrier API" to "Add rpc.api._barrier()" Apr 13, 2021
@kiukchung (Collaborator):

Looping in @kiukchung to make sure this does not conflict with the control barrier idea. IIUC, rpc.barrier is independent of the control barrier @kiukchung proposed, and both can be useful. Just want to double check. :)

Thanks for looping me in @mrshenli! Shouldn't conflict, but semantically torch.runtime.barrier() would do exactly the same thing, with the only advantage being that it does not depend on RPC (nor process groups) and instead depends on the elastic agent (a separate process), making it more robust.

@rohan-varma (Contributor):

@rohan-varma do you know who might be interested in trying this API?

Sorry, I missed this comment. I'm not aware of any immediate users, but I think it would be invaluable for folks debugging their applications when looking at possible correctness issues/race conditions.

worker_names = _ALL_WORKER_NAMES
assert (
    worker_name in worker_names
), "{worker_name} is not expected by leader.".format(worker_name=worker_name)
Contributor:

nit: f-string?
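
For reference, the f-string version suggested by this nit would look something like:

```python
assert (
    worker_name in worker_names
), f"{worker_name} is not expected by leader."
```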



@_require_initialized
def _barrier(worker_names):
Contributor:

Should we make it a public API?

Member Author:

I had discussed this with Shen a few weeks ago, as the PR was originally for a public API. But since we are unsure who will use it and how the API may change when introducing Elastic RPC, we decided to keep it internal for now.

    _all_gather(None, set(worker_names), timeout=DEFAULT_SHUTDOWN_TIMEOUT)
except RuntimeError as ex:
    logger.error(
        f"Failed to respond to 'Shutdown Proceed' in time, got error {ex}"
Contributor:

If this is being used for a non-shutdown case, this error message might be a bit confusing (it says "Shutdown Proceed")?

Member Author:

Good catch, I think I had copy-pasted from shutdown() 🙂, will change it.

info = rpc.get_worker_info()
all_worker_info = rpc._get_current_rpc_agent().get_worker_infos()
names = [worker.name for worker in all_worker_info]
rpc.api._barrier(names)
Contributor:

Ideally a test should assert something to verify the functionality, though admittedly that is a bit hard for a barrier. Could we do something like: before the barrier, assert some global count is 0; then send an RPC to each node to increment the global count; then call barrier; then assert the count == world_size? This would ensure the barrier properly waits for all outstanding work. Open to better ideas though.

Member Author:

That's a good point. I'll implement something along the lines of the count idea (a sketch follows below). I'll also add something to test the negative case where the barrier actually does time out.
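
A sketch of what that count-based test could look like (the helpers _count and _increment and the test method are hypothetical; rpc.rpc_sync and rpc.api._barrier are the real APIs):

```python
import torch.distributed.rpc as rpc

_count = 0  # hypothetical module-level counter, one per worker process

def _increment():
    global _count
    _count += 1

def _test_barrier_with_count(self, worker_names):
    # Before anyone sends work, every worker's counter is zero.
    self.assertEqual(_count, 0)
    # Each worker synchronously increments the counter on every peer.
    for name in worker_names:
        rpc.rpc_sync(name, _increment)
    # The barrier guarantees every worker has finished sending its
    # increments before any worker proceeds to the assertion below.
    rpc.api._barrier(worker_names)
    # Each worker has received exactly one increment from every worker.
    self.assertEqual(_count, len(worker_names))
```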

- sequence_id = _all_gather_sequence_id
- _all_gather_sequence_id += 1
+ concat_names = "".join(sorted(worker_names))
+ sequence_num = _all_gather_sequence_id.get(concat_names, 0)
Contributor:

I see that we properly clean up this key for _all_gather_sequence_id below when exiting the barrier, so that future calls with the same set of worker_names will start at 0 as expected.

However, is it possible to still run into races here? If there are concurrent calls to barrier() (i.e. multiple threads) with the same set of workers:

  • Thread 1 creates the key in _all_gather_sequence_id
  • Thread 1 releases the lock
  • Thread 2 calls with the same set of arguments, so it sees an already incremented sequence_num, which could cause the tracking of states to be incorrect.

Member Author:

The cleanup done at the end of barrier is for _all_gather_sequence_id_to_states. This variable keeps track of the barrier state for a single sequence ID. The cleanup is needed because there is an existing (small) memory leak where each call to _all_gather would leave behind an AllGatherStates object in the global dictionary.

The _all_gather_sequence_id is an ID representing a particular barrier operation. It is strictly increasing so that concurrent calls will still work. I'll rewrite your example above to demonstrate what is happening (see the sketch after this list):

  • worker1, Thread 1 performs barrier on ["worker1", "worker2"], retrieves _all_gather_sequence_id (value 0), and increments _all_gather_sequence_id
  • worker1, Thread 1 releases the lock
  • worker1, Thread 2 performs barrier on ["worker1", "worker2"], retrieves _all_gather_sequence_id (value 1), and increments _all_gather_sequence_id
  • Thread 1 and Thread 2 block waiting on the barrier from worker2
  • When worker2 calls barrier(["worker1", "worker2"]) the first time, it will allow worker1, Thread 1 to make progress. The second call to barrier will allow worker1, Thread 2 to make progress.
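
A minimal sketch of that bookkeeping (the lock name and the helper function are hypothetical; concat_names and the _all_gather_sequence_id.get(...) lookup come from the diff above):

```python
import threading

_sequence_id_lock = threading.Lock()  # hypothetical name for the guarding lock
_all_gather_sequence_id = {}          # maps concat_names -> next sequence number

def _next_sequence_num(worker_names):
    # The sorted concatenation of worker names keys the counter, so each
    # distinct set of workers gets its own strictly increasing sequence,
    # and concurrent barriers over the same set are ordered consistently.
    concat_names = "".join(sorted(worker_names))
    with _sequence_id_lock:
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
    return concat_names, sequence_num
```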

Contributor:

Sounds good. Could we add a unittest to verify this multithreaded behavior?
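
Such a unit test might be sketched as follows (a sketch only; the method name is hypothetical, and every worker would run it with the same worker_names):

```python
import threading
import torch.distributed.rpc as rpc

def _test_concurrent_barriers(self, worker_names, num_threads=2):
    # Every worker spawns the same number of threads, and each thread joins
    # one barrier round; the per-worker-set sequence numbers described above
    # keep concurrent rounds over the same set from interfering.
    threads = [
        threading.Thread(target=rpc.api._barrier, args=(worker_names,))
        for _ in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # all rounds should complete without hanging or timing out
```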

H-Huang added a commit that referenced this pull request May 6, 2021
ghstack-source-id: 165880c
Pull Request resolved: #53423
H-Huang added a commit that referenced this pull request May 28, 2021
ghstack-source-id: 07946ab
Pull Request resolved: #53423
@rohan-varma rohan-varma self-requested a review June 2, 2021 17:52
@rohan-varma (Contributor) left a comment:

Looks great! Thanks for adding the multithreading tests as we discussed.

    reach this method to wait for all outstanding work to complete.

    Args:
        worker_names (List[str]): The set of workers to synchronize.
Contributor:

This can be added in a follow-up diff, but should we also accept WorkerInfo or int (the worker id) here? The rest of the RPC APIs seem to accept either, so it might be good to add that.

Member Author:

Sounds good, will do.
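
A possible shape for that follow-up (the normalization helper is hypothetical; WorkerInfo and rpc._get_current_rpc_agent().get_worker_infos() are the real APIs already used in this PR's tests):

```python
import torch.distributed.rpc as rpc
from torch.distributed.rpc import WorkerInfo

def _to_worker_names(workers):
    # Hypothetical normalization: accept worker names (str), ranks (int),
    # or WorkerInfo objects, and resolve everything to a worker name.
    rank_to_name = {
        info.id: info.name
        for info in rpc._get_current_rpc_agent().get_worker_infos()
    }
    names = []
    for w in workers:
        if isinstance(w, WorkerInfo):
            names.append(w.name)
        elif isinstance(w, int):
            names.append(rank_to_name[w])
        else:
            names.append(w)
    return names
```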

H-Huang added a commit that referenced this pull request Jun 2, 2021
ghstack-source-id: 2e6945a
Pull Request resolved: #53423

H-Huang commented Jun 2, 2021

@H-Huang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@H-Huang merged this pull request in 7ee6836.

@facebook-github-bot facebook-github-bot deleted the gh/H-Huang/9/head branch June 6, 2021 14:16
deniskokarev pushed a commit to deniskokarev/pytorch that referenced this pull request Jun 9, 2021
Summary:
Pull Request resolved: pytorch#53423

closes pytorch#40166

This change exposes a new API, rpc.barrier(), which blocks the main processes of all workers running RPC until the whole group completes this function. Optionally, rpc.barrier can take in a set of worker_names and only synchronize across those workers.


Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D27737145

Pulled By: H-Huang

fbshipit-source-id: 369196bc62446f506d1fb6a3fa5bebcb0b09da9f
