[Inductor] improve the stride preservation logic of user-visible outputs #136732
Conversation
…put_idxs" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
…put_idxs" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
|
@tianyu-l this should address the inconsistent striding issue in full-graph capture with async-TP.
Can you add a more explicit test? Take a graph, run a custom FX graph pass on it that changes strides, ensure that the output still has the same stride.
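A minimal sketch of the kind of test being requested here, assuming Inductor's `post_grad_custom_post_pass` config hook; the pass body and the model are illustrative, and a real pass would also maintain fake-tensor metadata on the nodes it inserts:

```python
import torch
import torch._inductor.config as inductor_config

def change_output_strides(graph: torch.fx.Graph) -> None:
    # Clone every graph output contiguously: values are preserved, but the
    # strides of any non-contiguous output change. (A real pass would also
    # propagate node.meta["val"]; omitted here for brevity.)
    output_node = next(n for n in graph.nodes if n.op == "output")
    new_outputs = []
    with graph.inserting_before(output_node):
        for out in output_node.args[0]:
            new_outputs.append(
                graph.call_function(
                    torch.ops.aten.clone.default,
                    args=(out,),
                    kwargs={"memory_format": torch.contiguous_format},
                )
            )
    output_node.args = (tuple(new_outputs),)

def f(x):
    # The transpose makes the user-visible output non-contiguous in eager.
    return x.transpose(0, 1) * 2

x = torch.randn(4, 8)
eager_out = f(x)

with inductor_config.patch(post_grad_custom_post_pass=change_output_strides):
    compiled_out = torch.compile(f)(x)

# With stride preservation, the compiled output keeps the eager strides even
# though the custom pass tried to make it contiguous.
assert compiled_out.stride() == eager_out.stride()
torch.testing.assert_close(compiled_out, eager_out)
```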
…put_idxs" ## Context Currently, there's no guarantee that post_grad passes won't change the strides of user-visible outputs. This breaks certain guarantees in Inductor. To fix this, we can track the strides of user-visible outputs before the post_grad passes, and fix them up with `inductor_prims.force_stride_order` after the post_grad pass. However, there's another issue - currently, user-visible outputs are tracked by their names. There's also no guarantee that post_grad pass won't change the output names (in fact, this is likely to happen). So we need to track the track user-visible outputs with a more stable manner. ## This PR Instead of tracking user-visible outputs by their name, track them by their indices (this is also consistent with `static_input_idxs`). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
Nice! I had discussed this issue (post-grad passes changing strides) with @zou3519, so it's great to see a fix for it.
We also want to do this for custom operators. Similar to @Chillee's comment, maybe it makes sense to put the metadata on the output node of the graph, corresponding to its inputs. We could put metadata on custom ops corresponding to their inputs in the same way.
cc @zou3519 for thoughts
I'm not sure if the metadata should go in the graph's metadata or directly on the output node. Putting the metadata on the node lines up with what we want to do with custom ops, but then it runs the risk that someone may not preserve the metadata. Putting the output metadata in the graph means that it'll always be there. Also, one interesting case is what happens if the strides contain symints. Maybe this is already handled, but it seems tricky to ensure that the symints get codegenned correctly.
The "node" here means the output node. The output node is more or less guaranteed to remain unchanged, I think.
And @zou3519: symints should already be part of the lowering. Maybe there are fixes to be had, but they shouldn't be changing as part of this PR.
Output node sounds good to me then
Thanks for the input @eellison @zou3519!
It seems that in
nvm, I see we meant the output node.
…put_idxs" ## Context Currently, there's no guarantee that post_grad passes won't change the strides of user-visible outputs. This breaks certain guarantees in Inductor. To fix this, we can track the strides of user-visible outputs before the post_grad passes, and fix them up with `inductor_prims.force_stride_order` after the post_grad pass. However, there's another issue - currently, user-visible outputs are tracked by their names. There's also no guarantee that post_grad pass won't change the output names (in fact, this is likely to happen). So we need to track the track user-visible outputs with a more stable manner. ## This PR Instead of tracking user-visible outputs by their name, track them by their indices (this is also consistent with `static_input_idxs`). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
…put_idxs" ## Context Currently, there's no guarantee that post_grad passes won't change the strides of user-visible outputs. This breaks certain guarantees in Inductor. To fix this, we can track the strides of user-visible outputs before the post_grad passes, and fix them up with `inductor_prims.force_stride_order` after the post_grad pass. However, there's another issue - currently, user-visible outputs are tracked by their names. There's also no guarantee that post_grad pass won't change the output names (in fact, this is likely to happen). So we need to track the track user-visible outputs with a more stable manner. ## This PR Instead of tracking user-visible outputs by their name, track them by their indices (this is also consistent with `static_input_idxs`). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec [ghstack-poisoned]
…isible outputs" ## Context Previously, the stride preservation of user-visible nodes worked as follows: - After joint-graph tracing, we recorded the **names** of user-visible nodes and passed them to GraphLowering. - In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in `user_visible_outputs`. - We obtained the original strides by checking `node.meta["val"].stride()`. However, there's a problem with this approach: the nodes in output_node.args[0] and their strides could change between the completion of joint-graph tracing and the consumption of `user_visible_outputs` (e.g., during post-grad passes), making it unreliable. ## This PR - After joint graph tracing: - Record the original strides for all nodes in `output_nodes.args[0]` as `output_node.meta["original_output_strides"]` (recording for all nodes in case we need the info for other purposes such as debugging). - Record the indices of user-visible outputs as `output_node.meta["user_visible_output_idxs"]`. - Remove the original plumbing of `user_visible_outputs`. cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned]
@Chillee added
…isible outputs" ## Context Previously, the stride preservation of user-visible nodes worked as follows: - After joint-graph tracing, we recorded the **names** of user-visible nodes and passed them to GraphLowering. - In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in `user_visible_outputs`. - We obtained the original strides by checking `node.meta["val"].stride()`. However, there's a problem with this approach: the nodes in output_node.args[0] and their strides could change between the completion of joint-graph tracing and the consumption of `user_visible_outputs` (e.g., during post-grad passes), making it unreliable. ## This PR - After joint graph tracing: - Record the original strides for all nodes in `output_nodes.args[0]` as `output_node.meta["original_output_strides"]` (recording for all nodes in case we need the info for other purposes such as debugging). - Record the indices of user-visible outputs as `output_node.meta["user_visible_output_idxs"]`. - Remove the original plumbing of `user_visible_outputs`. cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned]
…isible outputs" ## Context Previously, the stride preservation of user-visible nodes worked as follows: - After joint-graph tracing, we recorded the **names** of user-visible nodes and passed them to GraphLowering. - In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in `user_visible_outputs`. - We obtained the original strides by checking `node.meta["val"].stride()`. However, there's a problem with this approach: the nodes in output_node.args[0] and their strides could change between the completion of joint-graph tracing and the consumption of `user_visible_outputs` (e.g., during post-grad passes), making it unreliable. ## This PR - After joint graph tracing: - Record the original strides for all nodes in `output_nodes.args[0]` as `output_node.meta["original_output_strides"]` (recording for all nodes in case we need the info for other purposes such as debugging). - Record the indices of user-visible outputs as `output_node.meta["user_visible_output_idxs"]`. - Remove the original plumbing of `user_visible_outputs`. cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned]
…isible outputs" ## Context Previously, the stride preservation of user-visible nodes worked as follows: - After joint-graph tracing, we recorded the **names** of user-visible nodes and passed them to GraphLowering. - In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in `user_visible_outputs`. - We obtained the original strides by checking `node.meta["val"].stride()`. However, there's a problem with this approach: the nodes in output_node.args[0] and their strides could change between the completion of joint-graph tracing and the consumption of `user_visible_outputs` (e.g., during post-grad passes), making it unreliable. ## This PR - After joint graph tracing: - Record the original strides for all nodes in `output_nodes.args[0]` as `output_node.meta["original_output_strides"]` (recording for all nodes in case we need the info for other purposes such as debugging). - Record the indices of user-visible outputs as `output_node.meta["user_visible_output_idxs"]`. - Remove the original plumbing of `user_visible_outputs`. cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned]
…isible outputs" ## Context Previously, the stride preservation of user-visible nodes worked as follows: - After joint-graph tracing, we recorded the **names** of user-visible nodes and passed them to GraphLowering. - In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in `user_visible_outputs`. - We obtained the original strides by checking `node.meta["val"].stride()`. However, there's a problem with this approach: the nodes in output_node.args[0] and their strides could change between the completion of joint-graph tracing and the consumption of `user_visible_outputs` (e.g., during post-grad passes), making it unreliable. ## This PR - After joint graph tracing: - Record the original strides for all nodes in `output_nodes.args[0]` as `output_node.meta["original_output_strides"]` (recording for all nodes in case we need the info for other purposes such as debugging). - Record the indices of user-visible outputs as `output_node.meta["user_visible_output_idxs"]`. - Remove the original plumbing of `user_visible_outputs`. cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned]
|
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@ezyang have you confirmed whether this PR caused the regression? If so, I guess the efficiency of `is_user_visible_output()` needs improving.
The change in compile time looks pretty stable here. I have not repro'd it locally, but it's probably not hard.
…utput handling (#139420)

This PR fixes a compilation time regression, manifested in timm_models/hrnet_w18, that was caused by #136732. The regression is reproducible locally. The compilation time is a bit noisy, but it's still possible to tell the difference:

```
Before the offending PR
compilation_latency mean=176.022 seconds
compilation_latency mean=176.564 seconds

On the offending PR
compilation_latency mean=180.096 seconds
compilation_latency mean=179.101 seconds

On the fix
compilation_latency mean=173.153 seconds
compilation_latency mean=174.182 seconds
```

(I think the fix being faster than the baseline is due to noise.)

The cause of the regression is an inefficiency in `is_user_visible_output()`. Specifically, it used `output_node.args[0].index(node)` to obtain the output index for each node (and it was called twice per node). The offending PR assumed that `len(output_node.args[0])` is small, but the benchmark proved this false: it was 1900+ for timm_models/hrnet_w18. The fix is to precompute `user_visible_output_strides` once by iterating only over the nodes in `output_node.args[0]`.

Pull Request resolved: #139420
Approved by: https://github.com/ezyang
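An illustrative before/after of the hot path described above; the function names mirror the commit message, but the bodies are a sketch rather than the exact PR diff:

```python
# Before: a linear scan of output_node.args[0] for every node queried, which
# becomes quadratic when the graph has ~1900 outputs.
def is_user_visible_output(node, output_node):
    outputs = output_node.args[0]
    return (
        node in outputs
        and outputs.index(node) in output_node.meta["user_visible_output_idxs"]
    )

# After: one pass over the outputs builds a node -> original-stride map, so
# every later query is an O(1) dict lookup.
def compute_user_visible_output_strides(output_node):
    strides = {}
    for idx, node in enumerate(output_node.args[0]):
        if idx in output_node.meta["user_visible_output_idxs"]:
            strides[node] = output_node.meta["original_output_strides"][idx]
    return strides
```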

Stack from ghstack (oldest at bottom):
## Context

Previously, the stride preservation of user-visible nodes worked as follows:

- After joint-graph tracing, we recorded the **names** of user-visible nodes and passed them to GraphLowering.
- In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in `user_visible_outputs`.
- We obtained the original strides by checking `node.meta["val"].stride()`.

However, there's a problem with this approach: the nodes in `output_node.args[0]` and their strides could change between the completion of joint-graph tracing and the consumption of `user_visible_outputs` (e.g., during post-grad passes), making it unreliable.

## This PR

After joint-graph tracing:

- Record the original strides for all nodes in `output_node.args[0]` as `output_node.meta["original_output_strides"]` (recording them for all nodes in case we need the info for other purposes, such as debugging).
- Record the indices of user-visible outputs as `output_node.meta["user_visible_output_idxs"]`.
- Remove the original plumbing of `user_visible_outputs`.

A sketch of the recording step follows below.

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec
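A minimal sketch of that recording step, assuming the metadata layout named above; `record_output_metadata` and `user_visible_idxs` are illustrative names, not the PR's exact code:

```python
import torch

def record_output_metadata(gm: torch.fx.GraphModule, user_visible_idxs) -> None:
    # Runs right after joint-graph tracing, before any post-grad pass.
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    original_strides = []
    for out in output_node.args[0]:
        val = out.meta.get("val") if isinstance(out, torch.fx.Node) else None
        original_strides.append(
            tuple(val.stride()) if isinstance(val, torch.Tensor) else None
        )
    # Indices stay valid across post-grad passes even when node names change,
    # which is why user visibility is keyed by index rather than by name.
    output_node.meta["original_output_strides"] = original_strides
    output_node.meta["user_visible_output_idxs"] = list(user_visible_idxs)
```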