Use different conv layout optimization heuristics for inference #114600
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114600
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (5 unrelated failures) As of commit 70140dc with merge base 56a95af:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were also failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…rence" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…rence" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…rence" While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench - `timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`. I used a modified script of the operator benchmarks [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench according to the existing classifications that shunting314 used - grouped convs, small channel convs, convolution with larger in-channel than out-channel. Only grouped convolutions benchmarked as a slowdown in inference. I updated the inference heuristic to multiply the flops of each conv with its predicted speedup/slowdown in channels last. With this heuristic the two previously regressing models no longer regress. Speeds up inference for torchbench ~8% and timm ~6%. The motivating model here was SDXL which now hits channels last and improves 10%. There were some models that were sped up in training when forcing channels last (along with a number of regressions), it's possible there is some speedup there to be had. We could also have more granular classification/predictions. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…rence" While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench - `timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`. I used a modified script of the operator benchmarks [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench according to the existing classifications that shunting314 used - grouped convs, small channel convs, convolution with larger in-channel than out-channel. Only grouped convolutions benchmarked as a slowdown in inference. I updated the inference heuristic to multiply the flops of each conv with its predicted speedup/slowdown in channels last. With this heuristic the two previously regressing models no longer regress. Speeds up inference for torchbench ~8% and timm ~6%. The motivating model here was SDXL which now hits channels last and improves 10%. There were some models that were sped up in training when forcing channels last (along with a number of regressions), it's possible there is some speedup there to be had. We could also have more granular classification/predictions. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…rence" While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench - `timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`. I used a modified script of the operator benchmarks [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench according to the existing classifications that shunting314 used - grouped convs, small channel convs, convolution with larger in-channel than out-channel. Only grouped convolutions benchmarked as a slowdown in inference. I updated the inference heuristic to multiply the flops of each conv with its predicted speedup/slowdown in channels last. With this heuristic the two previously regressing models no longer regress. Speeds up inference for torchbench ~8% and timm ~6%. The motivating model here was SDXL which now hits channels last and improves 10%. There were some models that were sped up in training when forcing channels last (along with a number of regressions), it's possible there is some speedup there to be had. We could also have more granular classification/predictions. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
| "test_kwargs_dynamic_shapes": TestFailure(("cpu",)), | ||
| # calling div on only symint args | ||
| "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(("cpu", "cuda")), | ||
| "test_conv_inference_heuristics_dynamic_shapes": TestFailure("cuda"), |
why?
We turn off the channels last optimization for dynamic shapes. That is blocked on persistent reduction perf not being enabled for dynamic shapes. CC @peterbell10 @shunting314 - was one of you working on persistent reductions with dynamic shapes?
…rence" While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench - `timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`. I used a modified script of the operator benchmarks [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench according to the existing classifications that shunting314 used - grouped convs, small channel convs, convolution with larger in-channel than out-channel. Only grouped convolutions benchmarked as a slowdown in inference. I updated the inference heuristic to multiply the flops of each conv with its predicted speedup/slowdown in channels last. With this heuristic the two previously regressing models no longer regress. Speeds up inference for torchbench ~8% and timm ~6%. The motivating model here was SDXL which now hits channels last and improves 10%. There were some models that were sped up in training when forcing channels last (along with a number of regressions). It's possible there is some speedup in training to be had with additional heuristics. We could also have more granular classification/predictions which might benefit both training and inference. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Is this speedup from using the new inference-specific heuristics vs. using the previous heuristics for both training and inference?
@shunting314 The speedup is for inference, relative to the main branch.
@pytorchbot merge
Merge failed. Reason: this PR needs a required label. To add a label, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Is there any reason why this speedup/slowdown occurs, and why the behavior differs between training and inference?
Stack from ghstack (oldest at bottom):
While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench -
`timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`.

I used a modified script of the operator benchmarks [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench, according to the existing classifications that @shunting314 used - grouped convs, small-channel convs, and convolutions with more in-channels than out-channels. Only grouped convolutions benchmarked as a slowdown in inference.
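As a rough illustration of the per-shape layout benchmarking described above, here is a minimal sketch using `torch.utils.benchmark` (the function name, defaults, and example shape are illustrative assumptions, not the actual gist script):

```python
import torch
from torch.utils import benchmark

def conv_layout_speedup(in_ch, out_ch, kernel, stride, groups, hw, bs=8,
                        dtype=torch.float16):
    """Return median contiguous time / median channels-last time for one conv shape."""
    conv = torch.nn.Conv2d(in_ch, out_ch, kernel, stride=stride, groups=groups,
                           bias=False, device="cuda", dtype=dtype)
    x = torch.randn(bs, in_ch, hw, hw, device="cuda", dtype=dtype)

    def median_time(fmt):
        m = conv.to(memory_format=fmt)
        inp = x.contiguous(memory_format=fmt)
        timer = benchmark.Timer("m(inp)", globals={"m": m, "inp": inp})
        return timer.blocked_autorange(min_run_time=1.0).median

    return median_time(torch.contiguous_format) / median_time(torch.channels_last)

# Hypothetical grouped-conv shape; a ratio < 1 means channels last is slower.
print(conv_layout_speedup(in_ch=224, out_ch=224, kernel=3, stride=1,
                          groups=112, hw=28))
```

Averaging this ratio over the shapes in each category (grouped, small-channel, shrinking-channel) gives the kind of per-category factors the flops-weighted heuristic below relies on.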
I updated the inference heuristic to multiply the flops of each conv by its predicted speedup/slowdown in channels last. With this heuristic, the two previously regressing models no longer regress.
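A minimal sketch of the flops-weighted scoring idea (the constants, field names, and thresholds are hypothetical placeholders, not the actual inductor heuristic):

```python
# Assumed per-category channels-last factors; real values would come from the
# benchmarking step above. A factor > 1 means channels last is faster.
GROUPED_FACTOR = 0.8         # grouped convs benchmarked as a slowdown
SMALL_CHANNEL_FACTOR = 1.2   # e.g. 3-channel stems
SHRINKING_FACTOR = 1.05      # more in-channels than out-channels
DEFAULT_FACTOR = 1.15

def conv_flops(bs, in_ch, out_ch, out_h, out_w, kh, kw, groups):
    # 2 * output elements * multiply-accumulates per output element
    return 2 * bs * out_ch * out_h * out_w * (in_ch // groups) * kh * kw

def predicted_factor(in_ch, out_ch, groups):
    if groups > 1:
        return GROUPED_FACTOR
    if in_ch <= 4:
        return SMALL_CHANNEL_FACTOR
    if in_ch > out_ch:
        return SHRINKING_FACTOR
    return DEFAULT_FACTOR

def prefer_channels_last(convs):
    """convs: list of dicts with the conv_flops fields above."""
    weighted = sum(conv_flops(**c) * predicted_factor(c["in_ch"], c["out_ch"],
                                                      c["groups"])
                   for c in convs)
    # Channels last wins when the flops-weighted average factor exceeds 1,
    # i.e. big convs that speed up outweigh small convs that slow down.
    return weighted > sum(conv_flops(**c) for c in convs)
```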
Speeds up inference for torchbench by ~8% and timm by ~6%. The motivating model here was SDXL, which now hits channels last and improves 10%.
There were some models that were sped up in training when forcing channels last (along with a number of regressions). It's possible there is some speedup in training to be had with additional heuristics. We could also have more granular classification/predictions which might benefit both training and inference.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler