[Inductor] improve the stride preservation logic of user-visible outputs by yifuwang · Pull Request #136732 · pytorch/pytorch · GitHub

Conversation

@yifuwang
Collaborator

@yifuwang yifuwang commented Sep 26, 2024

Stack from ghstack (oldest at bottom):

Context

Previously, the stride preservation of user-visible nodes worked as follows:

  • After joint-graph tracing, we recorded the names of user-visible nodes and passed them to GraphLowering.
  • In GraphLowering, we determined whether we needed to preserve the striding for a certain node by checking if the node's name was in user_visible_outputs.
  • We obtained the original strides by checking node.meta["val"].stride().

However, there's a problem with this approach: the nodes in output_node.args[0] and their strides could change between the completion of joint-graph tracing and the consumption of user_visible_outputs (e.g., during post-grad passes), making it unreliable.

This PR

  • After joint graph tracing:
    • Record the original strides for all nodes in output_node.args[0] as output_node.meta["original_output_strides"] (recording for all nodes in case we need the info for other purposes such as debugging; see the sketch after this list).
    • Record the indices of user-visible outputs as output_node.meta["user_visible_output_idxs"].
  • Remove the original plumbing of user_visible_outputs.
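
A minimal sketch of what this recording could look like right after joint-graph tracing (the helper name and the `user_visible_names` argument are illustrative assumptions; only the two `meta` keys come from this PR):

```python
import torch
from torch.fx import GraphModule, Node


def record_output_metadata(gm: GraphModule, user_visible_names: set[str]) -> None:
    # Hypothetical helper, not the actual implementation in this PR.
    # Find the node whose op is "output"; its args[0] holds the graph outputs.
    output_node = next(n for n in gm.graph.nodes if n.op == "output")

    original_strides = []
    user_visible_idxs = []
    for idx, out in enumerate(output_node.args[0]):
        val = out.meta.get("val") if isinstance(out, Node) else None
        # Record strides for every output so the info is available later
        # (e.g., for debugging), not just for user-visible ones.
        original_strides.append(val.stride() if isinstance(val, torch.Tensor) else None)
        if isinstance(out, Node) and out.name in user_visible_names:
            user_visible_idxs.append(idx)

    output_node.meta["original_output_strides"] = original_strides
    output_node.meta["user_visible_output_idxs"] = user_visible_idxs
```

Because the metadata lives on the output node itself and outputs are referenced by index rather than by name, later graph passes can rename or replace intermediate nodes without invalidating it.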

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec

@pytorch-bot

pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136732


✅ No Failures

As of commit 091c939 with merge base 5ea6777:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yifuwang
Collaborator Author

@tianyu-l this should address the inconsistent striding issue in full-graph capture with async-TP.

Collaborator

@Chillee Chillee left a comment

Can you add a more explicit test? Take a graph, run a custom FX graph pass on it that changes strides, ensure that the output still has the same stride.
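
For reference, a rough sketch of such a test (the model, the pass, and the use of Inductor's `post_grad_custom_post_pass` hook are illustrative assumptions, not the actual test added in this PR):

```python
import torch
from torch._inductor import config as inductor_config


def stride_changing_pass(graph: torch.fx.Graph) -> None:
    # Illustrative post-grad pass: route every graph output through a
    # contiguous clone, which changes the strides of non-contiguous outputs.
    output_node = next(n for n in graph.nodes if n.op == "output")
    new_outputs = []
    with graph.inserting_before(output_node):
        for out in output_node.args[0]:
            if isinstance(out, torch.fx.Node):
                out = graph.call_function(
                    torch.ops.aten.clone.default,
                    args=(out,),
                    kwargs={"memory_format": torch.contiguous_format},
                )
            new_outputs.append(out)
    output_node.args = (tuple(new_outputs),)


def f(x):
    # The output is non-contiguous; its eager strides should be preserved.
    return x.transpose(0, 1) + 1


x = torch.randn(8, 16)
eager_out = f(x)

# Assumption: this config hook runs the pass on the post-grad ATen graph.
inductor_config.post_grad_custom_post_pass = stride_changing_pass
compiled_out = torch.compile(f)(x)

# With proper stride preservation, the compiled output keeps its eager strides.
assert compiled_out.stride() == eager_out.stride()
```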

…put_idxs"


## Context

Currently, there's no guarantee that the post-grad passes won't change the strides of user-visible outputs. This breaks certain guarantees in Inductor.

To fix this, we can track the strides of user-visible outputs before the post-grad passes and fix them up with `inductor_prims.force_stride_order` afterwards.

However, there's another issue: currently, user-visible outputs are tracked by their names, and there's also no guarantee that the post-grad passes won't change the output names (in fact, this is likely to happen). So we need to track user-visible outputs in a more stable manner.

## This PR

Instead of tracking user-visible outputs by their name, track them by their indices (this is also consistent with `static_input_idxs`).
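
A rough sketch of the consumption side of this plan (the helper is hypothetical and only detects outputs whose strides drifted; the actual fix-up op named above, `inductor_prims.force_stride_order`, is not spelled out here because its exact signature is not part of this description):

```python
import torch
from torch.fx import GraphModule, Node


def outputs_needing_stride_fixup(
    gm: GraphModule,
    original_output_strides: list,
    user_visible_output_idxs: set[int],
) -> list[int]:
    # Hypothetical check run after the post-grad passes: compare the strides
    # recorded before those passes against the current meta["val"] strides.
    output_node = next(n for n in gm.graph.nodes if n.op == "output")

    drifted = []
    for idx, out in enumerate(output_node.args[0]):
        if idx not in user_visible_output_idxs or not isinstance(out, Node):
            continue
        val = out.meta.get("val")
        if isinstance(val, torch.Tensor) and original_output_strides[idx] is not None:
            if tuple(val.stride()) != tuple(original_output_strides[idx]):
                # These outputs would be routed through
                # inductor_prims.force_stride_order (or an equivalent fix-up).
                drifted.append(idx)
    return drifted
```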

Contributor

@eellison eellison left a comment

Nice! I had discussed this issue (post grad passes changing strides) with @zou3519 so great to see a fix for it.

We also want to do this for custom operators. Similar to @Chillee's comment, maybe it makes sense to put the metadata on the output node of the graph, corresponding to its inputs. We could put metadata on custom ops corresponding to their inputs in the same way.

cc @zou3519 for thoughts

@zou3519 zou3519 self-requested a review September 29, 2024 20:36
@zou3519
Contributor

zou3519 commented Sep 30, 2024

I'm not sure if the metadata should go in the graph's metadata or directly on the output node. Putting the metadata on the node lines up with what we want to do with custom ops, but then it runs the risk that someone may not preserve the metadata. Putting the output metadata in the graph means that it'll always be there.

Also, one interesting case is what happens if the strides contain symints. Maybe this is already handled, but it seems tricky to ensure that the symints get codegenned correctly.

@eellison
Contributor

> I'm not sure if the metadata should go in the graph's metadata or directly on the output node. Putting the metadata on the node lines up with what we want to do with custom ops, but then it runs the risk that someone may not preserve the metadata. Putting the output metadata in the graph means that it'll always be there.

"Node" here means the output node. The output node is more or less guaranteed to remain unchanged, I think.

@eellison
Contributor

And @zou3519, symints should already be part of the lowering. Maybe there are fixes to be had, but that shouldn't change as part of this PR.

@zou3519
Contributor

zou3519 commented Sep 30, 2024

Output node sounds good to me then

@yifuwang
Collaborator Author

yifuwang commented Sep 30, 2024

Thanks for the input @eellison @zou3519!

> Also, one interesting case is what happens if the strides contain symints.

It seems that in GraphLowering, we simply skip the fixing up for such tensors: https://github.com/pytorch/pytorch/blob/main/torch/_inductor/graph.py#L1426
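
In other words (a paraphrase of the behavior at that link, not the actual code there): the fix-up is only applied when every stride is a concrete integer, e.g. something like:

```python
import torch


def can_fix_up_strides(strides) -> bool:
    # Skip tensors whose strides contain symbolic values (SymInts);
    # only concrete integer strides get the layout fix-up.
    return not any(isinstance(s, torch.SymInt) for s in strides)
```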

> Output node sounds good to me then

Just to confirm: we want the metadata to be on the output node (i.e., the node whose op is "output"), not on the args of the output node, right?

nvm, I see we meant the output node.

@yifuwang yifuwang mentioned this pull request Oct 24, 2024
@pytorch-bot pytorch-bot bot added the oncall: distributed label Oct 24, 2024
@yifuwang
Collaborator Author

> Can you add a more explicit test? Take a graph, run a custom FX graph pass on it that changes strides, ensure that the output still has the same stride.

@Chillee added `test_stride_preservation_with_stride_modifying_fx_pass`.

yifuwang pushed a commit that referenced this pull request Oct 24, 2024
yifuwang pushed a commit that referenced this pull request Oct 25, 2024
@yifuwang
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Oct 26, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@yifuwang
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@ezyang
Contributor

ezyang commented Oct 28, 2024

Slight compile time regression on hrnet_w18 training:

[compile-time dashboard screenshot]

@yifuwang
Collaborator Author

@ezyang have you confirmed whether this PR caused the regression? If so, I guess the efficiency of `is_user_visible_output()` could be improved.

@ezyang
Contributor

ezyang commented Oct 29, 2024

The change in compile time looks pretty stable here. I haven't repro'd it locally, but it's probably not hard.

pytorchmergebot pushed a commit that referenced this pull request Nov 1, 2024
…utput handling (#139420)

This PR fixes a compilation time regression manifested in timm_models/hrnet_w18 caused by #136732.

The regression is reproducible locally. The compilation time is a bit noisy, but it's still possible to tell the difference.

```
Before the offending PR

compilation_latency mean=176.022 seconds
compilation_latency mean=176.564 seconds

On the offending PR

compilation_latency mean=180.096 seconds
compilation_latency mean=179.101 seconds

On the fix

compilation_latency mean=173.153 seconds
compilation_latency mean=174.182 seconds
```

(I think the fix being faster than the baseline is due to noise)

The cause of the regression is an inefficiency in `is_user_visible_output()`. Specifically, it used `output_node.args[0].index(node)` to obtain the output idx for each node (and we called this twice per node). The offending PR assumed that `len(output_node.args[0])` is rather small, but the benchmark proved this false (it was 1900+ for timm_models/hrnet_w18).

The fix is to precompute `user_visible_output_strides` once by iterating only over the nodes in `output_node.args[0]`.

Pull Request resolved: #139420
Approved by: https://github.com/ezyang
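
The shape of that change, roughly (the helper and variable names mirror the description above; this is not the actual diff):

```python
# Before: called while lowering each node; the .index() scan over the output
# list makes this O(len(outputs)) per call, O(n^2) overall.
def is_user_visible_output(node, output_node, user_visible_output_idxs):
    outputs = output_node.args[0]
    return node in outputs and outputs.index(node) in user_visible_output_idxs


# After: a single pass over output_node.args[0] builds a node -> stride map
# up front, so lookups during lowering are O(1).
def compute_user_visible_output_strides(output_node, user_visible_output_idxs):
    original_strides = output_node.meta["original_output_strides"]
    user_visible_output_strides = {}
    for idx, out in enumerate(output_node.args[0]):
        if idx in user_visible_output_idxs:  # ideally a set for O(1) membership
            user_visible_output_strides[out] = original_strides[idx]
    return user_visible_output_strides
```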
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024
@github-actions github-actions bot deleted the gh/yifuwang/127/head branch November 29, 2024 02:11