- 
                Notifications
    You must be signed in to change notification settings 
- Fork 25.7k
Refactor dtype propagation #139945
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor dtype propagation #139945
Conversation
[ghstack-poisoned]
| 🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139945
 Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 17fca7f with merge base 43afaa4 ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
 This comment was automatically generated by Dr. CI and updates every 15 minutes. | 
[ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the refactor.
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov [ghstack-poisoned]
| @pytorchbot merge | 
| Merge failedReason: Approvers from one of the following sets are needed: 
 | 
| @pytorchbot merge | 
| Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team | 
Adds the remaining unimplemented ops as well as an assertion failure if someone adds a new op without a dtype rule. We test all unique pointwise operators registered as lowerings which have an opinfo. There will be some follow ups for this to work well with both `codegen_upcast_to_fp32` as True and False. Pull Request resolved: #140057 Approved by: https://github.com/arui-meta, https://github.com/blaine-rister, https://github.com/ezyang ghstack dependencies: #139945
- Add in upcast_compute_type on creation of new tensors (loads, constants) - Fixes index_expr - right now we are sort of inconsistent in dtype and dont always respect the dtype specified. would be nice to fix but not doing in this pr. - bug fix in view dtype where we were always upcasting back to fp32 when input was in bf16/fp16. we should only be doing that if the output is also in bf16/fp16. - for masked, avoid calling dtype propagation and just use output dtype. Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it for testing turning off codegen_upcast_to_fp32. Follow ups: - We could consider requiring less explicit upcast_compute_types calls and do it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done it this pr. - Be more consistent on our index expr dtype printing. Pull Request resolved: #141495 Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang ghstack dependencies: #139945, #140057
| This appears to cause a compile time regression: https://fburl.com/scuba/torch_open_source_signpost/5tti7eql   | 
| sounds expected as test files has changed, but just adding it as another ref point   https://www.internalfb.com/intern/unidash/dashboard/pt2_diff_time_metrics | 
| Fix should be here: #141882, need to verify it though. | 
Should fix compile time regression, it was doing fairly expensive meta programming in init and being instantiated multiple times. Pull Request resolved: #141882 Approved by: https://github.com/ezyang ghstack dependencies: #139945, #140057, #141495
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype - Update dtype propagation to use `type_to_dtype` instead of instantiating tensor - Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ). Pull Request resolved: #141991 Approved by: https://github.com/drisspg ghstack dependencies: #139945, #140057, #141495, #141882
A couple changes. - Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with `pointwise_overrides_data` and the `boolean_ops` list. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later. - Factors out `get_promoted_dtype` which uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache. Tests get added later in the stack when everything is implemented. Pull Request resolved: pytorch#139945 Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang
Adds the remaining unimplemented ops as well as an assertion failure if someone adds a new op without a dtype rule. We test all unique pointwise operators registered as lowerings which have an opinfo. There will be some follow ups for this to work well with both `codegen_upcast_to_fp32` as True and False. Pull Request resolved: pytorch#140057 Approved by: https://github.com/arui-meta, https://github.com/blaine-rister, https://github.com/ezyang ghstack dependencies: pytorch#139945
- Add in upcast_compute_type on creation of new tensors (loads, constants) - Fixes index_expr - right now we are sort of inconsistent in dtype and dont always respect the dtype specified. would be nice to fix but not doing in this pr. - bug fix in view dtype where we were always upcasting back to fp32 when input was in bf16/fp16. we should only be doing that if the output is also in bf16/fp16. - for masked, avoid calling dtype propagation and just use output dtype. Turns on the runtime dtype verification for opinfo tests. The separate test file is still useful because we can use it for testing turning off codegen_upcast_to_fp32. Follow ups: - We could consider requiring less explicit upcast_compute_types calls and do it automatically. That would potentially make things easier but be less flexible in the future. Maybe I should have done it this pr. - Be more consistent on our index expr dtype printing. Pull Request resolved: pytorch#141495 Approved by: https://github.com/blaine-rister, https://github.com/arui-meta, https://github.com/ezyang ghstack dependencies: pytorch#139945, pytorch#140057
Should fix compile time regression, it was doing fairly expensive meta programming in init and being instantiated multiple times. Pull Request resolved: pytorch#141882 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#139945, pytorch#140057, pytorch#141495
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype - Update dtype propagation to use `type_to_dtype` instead of instantiating tensor - Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ). Pull Request resolved: pytorch#141991 Approved by: https://github.com/drisspg ghstack dependencies: pytorch#139945, pytorch#140057, pytorch#141495, pytorch#141882
- Set the dtype of "acc" appropriately so that epilogue fusion will have args with dtype - Update dtype propagation to use `type_to_dtype` instead of instantiating tensor - Throw if we have a string arg where we should have a proper CSEVariable, unless we're doing the Modification Subgraph thing which is nyi. everything else is appropriately typed (cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @drisspg ). Pull Request resolved: pytorch#141991 Approved by: https://github.com/drisspg ghstack dependencies: pytorch#139945, pytorch#140057, pytorch#141495, pytorch#141882
Stack from ghstack (oldest at bottom):
A couple changes.
Tries to reuse dtype propagation rules that were already registered in inductor. These were present both with
pointwise_overrides_dataand theboolean_opslist. Additionally, the registration of pointwise ops already specified dtype propagation rules. Saves those registrations and reuses them later.Factors out
get_promoted_dtypewhich uses functools.lru_cache to take in non - CSEVariable args because those will not work with the functools cache.Tests get added later in the stack when everything is implemented.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov