-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Improve reflection_pad2d lowering for dynamic shapes #110988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Fixes #110696 Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110988
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit c4aa9cb with merge base 261cae7 ( UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Still needs tests, but a review will still be helpful |
torch/_inductor/index_propagation.py
Outdated
| def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: | ||
| if is_boolean_dtype(dtype): | ||
| if isinstance(value, sympy.Expr): | ||
| expr = value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ops.constant shouldn't be called with sympy.Expr. Instead it should be ops.index_expr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So it turns out that it's easy enough to change the ops.constant call in the lowering to ops.index_expr, but now I am wondering, is there really any material difference between these two calls? Like, it's not like you're going to get worse code if you call index_expr with a constant, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally I think they should be similar but in practice they are materially different. ops.constant will respect the dtype argument and generate code like tl.full([1, 1], 256, tl.int32) whereas ops.index_expr goes through the sympy expression printer which gives an integer literal 256 which ignores the specified dtype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in any case I fixed it, plz review lol
Fixes #110696 Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Fixes #110696 Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
| dtype=x.get_dtype(), | ||
| inner_fn=fn, | ||
| ranges=[*batch, sympy.Integer(h + top + bot), sympy.Integer(w + left + right)], | ||
| ranges=[*batch, sympy.sympify(h + top + bot), sympy.sympify(w + left + right)], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need sympy.sympify here ? h or w should either be integers or symbols which would work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if h, top and bot happen to be all int, we won't return a sympy.Integer, but I noticed that in inductor it is sometimes load bearing to return a sympy expression, not a plain int
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be fine to return an int here
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Fixes #110696
Signed-off-by: Edward Z. Yang ezyang@meta.com
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler