Use different conv layout optimization heuristics for inference by eellison · Pull Request #114600 · pytorch/pytorch · GitHub

Conversation

@eellison (Contributor) commented Nov 27, 2023

Stack from ghstack (oldest at bottom):

While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench: `timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`.

I used a modified script of the operator benchmarks [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench, according to the existing classifications that @shunting314 used: grouped convs, small-channel convs, and convolutions with more in-channels than out-channels. Only grouped convolutions benchmarked as a slowdown in inference.
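A minimal sketch of what such a per-shape layout benchmark can look like (an illustration, not the actual script; the shapes, iteration counts, and CPU wall-clock timing are assumptions):

```python
import time
import torch

def layout_speedup(in_ch, out_ch, kernel, groups, hw, batch=2, iters=10):
    """Estimated speedup of channels-last over contiguous for one conv
    shape (>1.0 means channels-last is faster for this shape)."""
    conv = torch.nn.Conv2d(in_ch, out_ch, kernel, groups=groups)
    x = torch.randn(batch, in_ch, hw, hw)

    def avg_time(fmt):
        c = conv.to(memory_format=fmt)
        inp = x.to(memory_format=fmt)
        with torch.no_grad():
            c(inp)  # warmup so lazy init doesn't skew the timing
            start = time.perf_counter()
            for _ in range(iters):
                c(inp)
            return (time.perf_counter() - start) / iters

    return avg_time(torch.contiguous_format) / avg_time(torch.channels_last)
```

Averaging this ratio over the conv shapes found in torchbench, per class (grouped, small-channel, in > out), yields the per-class speedup estimates a heuristic can consume.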

I updated the inference heuristic to multiply the FLOPs of each conv by its predicted speedup/slowdown in channels last. With this heuristic, the two previously regressing models no longer regress.
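In spirit, the FLOPs-weighted decision looks something like the following sketch (the class thresholds and speedup factors here are illustrative assumptions, not the actual Inductor code):

```python
# Illustrative predicted channels-last speedup per conv class
# (>1.0: channels-last is faster; grouped convs regressed in the benchmarks).
PREDICTED_SPEEDUP = {"grouped": 0.8, "small_channels": 1.1,
                     "in_gt_out": 1.2, "default": 1.3}

def conv_flops(batch, in_ch, out_ch, out_h, out_w, kh, kw, groups=1):
    # Approximate multiply-accumulate count of a 2D convolution.
    return batch * out_h * out_w * out_ch * (in_ch // groups) * kh * kw

def classify(in_ch, out_ch, groups):
    if groups > 1:
        return "grouped"
    if in_ch <= 4:
        return "small_channels"  # e.g. the first conv on RGB input
    if in_ch > out_ch:
        return "in_gt_out"
    return "default"

def should_use_channels_last(convs):
    """convs: iterable of dicts describing each conv in the model.
    Weight each conv's FLOPs by its predicted speedup and compare the
    total estimated cost of the model in both layouts."""
    contiguous_cost = channels_last_cost = 0.0
    for c in convs:
        flops = conv_flops(c["batch"], c["in_ch"], c["out_ch"], c["out_h"],
                           c["out_w"], c["kh"], c["kw"], c["groups"])
        speedup = PREDICTED_SPEEDUP[classify(c["in_ch"], c["out_ch"], c["groups"])]
        contiguous_cost += flops
        channels_last_cost += flops / speedup
    return channels_last_cost < contiguous_cost
```

Under these assumed factors, a model whose FLOPs are dominated by grouped convs stays contiguous, while most convnets tip toward channels last.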

Speeds up inference for torchbench by ~8% and timm by ~6%. The motivating model here was SDXL, which now hits channels last and improves by 10%.

There were some models that were sped up in training when forcing channels last (along with a number of regressions). It's possible there is some speedup in training to be had with additional heuristics. We could also have more granular classification/predictions which might benefit both training and inference.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

@pytorch-bot bot commented Nov 27, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114600

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (5 Unrelated Failures)

As of commit 70140dc with merge base 56a95af:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eellison eellison requested review from Chillee, jansel and shunting314 and removed request for Chillee, jansel and shunting314 November 27, 2023 17:57
eellison added a commit that referenced this pull request Nov 27, 2023
@eellison eellison requested review from Chillee, jansel, jgong5 and shunting314 and removed request for jgong5 and shunting314 November 27, 2023 22:17
eellison added a commit that referenced this pull request Nov 28, 2023
"test_kwargs_dynamic_shapes": TestFailure(("cpu",)),
# calling div on only symint args
"test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_conv_inference_heuristics_dynamic_shapes": TestFailure("cuda"),
A reviewer (Contributor) asked:
why?

eellison (Contributor Author) replied:

We turn off the channels-last optimization for dynamic shapes. That is blocked on persistent reduction perf not being enabled for dynamic shapes. CC @peterbell10 @shunting314 — was one of you working on persistent reductions with dynamic shapes?

eellison added a commit that referenced this pull request Nov 28, 2023
@eellison eellison requested a review from jansel November 28, 2023 17:04
@shunting314 (Contributor) commented:

> Speeds up inference for torchbench ~8% and timm ~6%.

Is this speedup from using the new inference-specific heuristics vs. using the previous heuristics for both training and inference?

@eellison (Contributor Author) commented Nov 29, 2023

@shunting314 the speedup is for inference relative to main branch.

@eellison (Contributor Author):

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 29, 2023
@pytorchmergebot (Collaborator):

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@eellison (Contributor Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@facebook-github-bot facebook-github-bot deleted the gh/eellison/561/head branch December 2, 2023 15:28
@stncil commented Feb 13, 2024

Is there any reason why this speedup/slowdown occurs, and why the behavior differs between training and inference?
