-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Fix slice scatter dtype consistency #160851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix slice scatter dtype consistency #160851
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160851
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 69bdfd0 with merge base 6737e2c ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "module: inductor" |
|
@eellison Thank you for your review. I have fixed the CI errors. Please let me know if any further modifications are needed, and kindly help run the CI checks again. |
|
@soulitzer Thank you for the review. I have gained a new understanding of this part. In Eager mode, the C++ implementation is directly invoked, while in Inductor mode, Python lowering performs type checks. I have now removed the assertion checks for Python lowering and fully delegated type handling to the C++ implementation to maintain logical consistency. The tests I conducted passed successfully. |
|
Thanks for the update, that sounds good. Could you add a test? |
|
@soulitzer Removing the assert x.get_dtype() == src.get_dtype() is not a good idea, so the solution has been reverted to the original approach: type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH. We have added tests, which have already passed. |
|
@soulitzer Testing type promotion |
Why not? We don't have that assumption in eager?
Hmm, but my earlier point was that we don't actually do type promotion in this op. |
|
@soulitzer |
|
@soulitzer The ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH test passes. I suspect it actually performs non-trivial work beyond simple type promotion, although the detailed mechanics remain unclear at this stage. |
|
I suspect its not doing type promotion based on this example: Are you testing that? |
YES
|
Sorry to clarify, what I mean to ask is if you have checked parity for this example between the inductor lowering and the op? I'm asking because I don't see it in the test case "test_slice_scatter_dtype_consistency", could you add it there? |
|
@soulitzer Thank you for your review. I apologize, as my previous test was incorrect. Removing the assertion x.get_dtype() == src.get_dtype() was the right decision. I have updated the test code accordingly. Thanks again for the review. |
test/inductor/test_torchinductor.py
Outdated
| return torch.slice_scatter(y, x, 0) | ||
|
|
||
| for dtype in [ | ||
| torch.int8, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit, but probably no need to test every single int type
one int, float64, and then complex is probably sufficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified the test
|
Thanks for the update. For the test failures, let's just skip in test/inductor/test_torchinductor_codegen_dynamic_shapes.py |
|
@soulitzer Thanks for the review.. Can we merge it? |
|
We may have to fix a test failure to pass CI. See my previous comment here #160851 (comment) |
|
❌ 🤖 pytorchbot command failed: Try |
|
|
||
| @register_lowering(aten.slice_scatter, type_promotion_kind=None) | ||
| def slice_scatter(x, src, dim=0, start=None, end=None, step=1): | ||
| assert x.get_dtype() == src.get_dtype() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should replace the assertion with:
src = to_dtype(src, x.get_dtype())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All set, I've updated it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
|
@pytorchbot merge |
|
Successfully rebased |
da3ee2e to
69bdfd0
Compare
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Fixes pytorch#147842 Fix torch.slice_scatter type inconsistency issue. I noticed previous PRs on this have stalled, so I'm opening this new PR. Pull Request resolved: pytorch#160851 Approved by: https://github.com/soulitzer
Fixes pytorch#147842 Fix torch.slice_scatter type inconsistency issue. I noticed previous PRs on this have stalled, so I'm opening this new PR. Pull Request resolved: pytorch#160851 Approved by: https://github.com/soulitzer
Fixes #147842
Fix torch.slice_scatter type inconsistency issue. I noticed previous PRs on this have stalled, so I'm opening this new PR.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben