[Reland2][DDP] Merge work and future_work in reducer by wayi1 · Pull Request #59574 · pytorch/pytorch

Conversation

@wayi1 (Contributor) commented Jun 7, 2021

Stack from ghstack:

Remove the `work` attribute from the `Reducer` class in favor of `future_work`.

Additionally, remove the `copy_grad_to_bucket` method, since it is now a one-line implementation, and create a new C++ comm hook called `_AllReduceCommHookWithDivFactor` to replace the plain allreduce and also support handling uneven inputs.

  1. Compared with the reverted [DDP] Merge work and future_work in reducer #58937, `_AllReduceCommHookWithDivFactor` in `default_comm_hooks.cpp` is updated to apply the division first and hence avoid FP16 overflow (see the sketch below).

  2. Compared with the reverted [Reland][DDP] Merge work and future_work in reducer #59520, `test_DistributedDataParallel_non_default_stream` is disabled on AMD, because applying the division first now hurts the gradient-averaging accuracy on AMD.
     See [07:48:26]:
     https://ci.pytorch.org/jenkins/job/pytorch-builds/job/pytorch-linux-bionic-rocm4.2-py3.6-test1/1129/console

#Original PR Issue: #41266

Differential Revision: [D28940800](https://our.internmc.facebook.com/intern/diff/D28940800/)

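The following is a minimal Python-level sketch of the divide-then-allreduce pattern described above. It is an illustration only: the actual change is the C++ hook `_AllReduceCommHookWithDivFactor` in `default_comm_hooks.cpp`, and the hook/bucket names below follow the public DDP comm-hook API of recent PyTorch releases (the exact `GradBucket` accessor name has varied across versions).

```python
import torch.distributed as dist

def allreduce_div_first_hook(process_group, bucket):
    # Divide the bucketed gradients by the world size BEFORE the allreduce,
    # so an FP16 sum across ranks never leaves the FP16 range.
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = group.size()

    tensor = bucket.buffer()   # flattened gradients for this bucket
    tensor.div_(world_size)    # divide first to avoid FP16 overflow

    fut = dist.all_reduce(tensor, group=group, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0])

# Hypothetical usage on a DistributedDataParallel model:
#   model.register_comm_hook(state=None, hook=allreduce_div_first_hook)
# With uneven inputs (the DDP join API), the effective divide factor may differ
# from the static world size; this sketch uses the static size for simplicity.
```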
@facebook-github-bot (Contributor) commented Jun 7, 2021

💊 CI failures summary and remediations

As of commit 65c1c2a (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-scanned failure(s)

1 failure not recognized by patterns:

Job: GitHub Actions / Label PRs & Issues / auto-label-rocm
Step: Unknown
Action: rerun


wayi1 pushed a commit that referenced this pull request Jun 7, 2021
ghstack-source-id: 130752393
Pull Request resolved: #59574
@facebook-github-bot added the oncall: distributed label Jun 7, 2021
@rohan-varma (Contributor) left a comment

LGTM to unblock for now; it would be great to file an issue to investigate why it fails on ROCm. Thanks!

@agolynski (Contributor) commented:
A couple of questions:

  1. Wondering why it is called `_AllReduceCommHookWithDivFactor`; are you planning to deprecate/remove it later?
  2. Are we changing the computation logic here, or is this just a refactoring PR (e.g., why did the overflow problem only surface in this PR)?

@wayi1 (Contributor, Author) commented Jun 7, 2021

@jithunnair-amd `test_DistributedDataParallel_non_default_stream` is disabled on AMD because, when we now compute the average gradients, we first divide the local gradient by the group size and then sum the local gradients up, in order to prevent overflow in the FP16 range. However, this has caused a non-trivial discrepancy in the averaged output on AMD. The same test passes on NVIDIA GPUs.

See [07:48:26]:
https://ci.pytorch.org/jenkins/job/pytorch-builds/job/pytorch-linux-bionic-rocm4.2-py3.6-test1/1129/console
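A small standalone illustration (not part of the PR; the values are made up to sit near the FP16 maximum of about 65504) of why the order of division and summation matters in FP16:

```python
import torch

world_size = 4
# Hypothetical per-worker gradients close to the FP16 maximum (~65504).
local_grads = [torch.full((3,), 20000.0, dtype=torch.float16) for _ in range(world_size)]

# Sum first, then divide: the running sum passes 65504 and overflows to inf.
summed = local_grads[0].clone()
for g in local_grads[1:]:
    summed += g
avg_sum_first = summed / world_size   # tensor([inf, inf, inf], dtype=torch.float16)

# Divide first, then sum (what the new hook does): every addend stays in range.
avg_div_first = sum(g / world_size for g in local_grads)   # tensor([20000., 20000., 20000.])

print(avg_sum_first, avg_div_first)
```

Dividing first also changes the per-addend rounding, which is presumably the accuracy difference the disabled ROCm test is sensitive to.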

@facebook-github-bot (Contributor) commented:
This pull request has been merged in 6575975.

deniskokarev pushed a commit to deniskokarev/pytorch that referenced this pull request Jun 9, 2021
Summary:
Pull Request resolved: pytorch#59574

ghstack-source-id: 130752393

Test Plan:
buck test caffe2/test/distributed:distributed_gloo_fork --  test_accumulate_gradients_no_sync
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_accumulate_gradients_no_sync
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_grad_div_uneven_inputs
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_fp16
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_fp16_grad_is_view

buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork --  test_DistributedDataParallel_non_default_stream

Reviewed By: rohan-varma

Differential Revision: D28940800

fbshipit-source-id: 1ba727ac951ebc1e7875dc1a1be8108a2c8d9462
@facebook-github-bot deleted the gh/SciPioneer/144/head branch June 11, 2021 14:17
facebook-github-bot pushed a commit that referenced this pull request Jul 9, 2021
…n when no comm hook is specified (#61379)

Summary:
Pull Request resolved: #61379

The optimization was accidentally removed in #59574.

This optimization saves a scan over all the input parameters by fusing the copy and div operations.

Now the default temporary hook is an allreduce by sum, and no extra division is done inside the hook.
ghstack-source-id: 133288529

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_accumulate_gradients_no_sync
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_grad_div_uneven_inputs
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_fp16
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_fp16_grad_is_view
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork --  test_DistributedDataParallel_non_default_stream

buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_sparse_gradient

buck test mode/dev-nosan caffe2/test/distributed:c10 -- test_ddp_checkpointing_once
buck test mode/dev-nosan caffe2/test/distributed:c10 -- test_ddp_checkpointing_twice

Reviewed By: rohan-varma

Differential Revision: D29597614

fbshipit-source-id: 2434e4fd4e6abad7871cfe47886fe97b6e4ba28f
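A rough single-process sketch of the fused copy-and-divide idea described in that commit (an illustration only; `bucket_view` and the shapes are hypothetical stand-ins for the reducer's flattened bucket buffer):

```python
import torch

world_size = 8
grad = torch.randn(1_000_000)
bucket_view = torch.empty_like(grad)   # stands in for a view into the flattened bucket

# Unfused: two passes over the data -- copy into the bucket, then divide in place.
bucket_view.copy_(grad)
bucket_view.div_(world_size)

# Fused: a single pass that divides while writing into the bucket view.
torch.div(grad, world_size, out=bucket_view)
```

With no comm hook registered, the division can happen during this copy, so the default hook can remain a plain allreduce-by-sum with no extra division (and no extra scan) inside the hook.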
@taozhiwei (Contributor) commented:
May I ask for your advice: in reducer.cpp, `bucket.future_work->wait()` is called instead of `work->wait()`. How does this ensure that the computation stream waits on the NCCL stream?
