-
Notifications
You must be signed in to change notification settings - Fork 25.7k
[inductor][scheduler] reorder scheduler nodes after fusion to reduce peak memory #134874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134874
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 9e3ece0 with merge base 803ce50 ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
e421103
to
4ffc8a0
Compare
@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
4ffc8a0
to
fed3d59
Compare
@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
fed3d59
to
3b6055f
Compare
@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
d1f6710
to
8787c12
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks! leaving some comments regarding organization
c9061f5
to
6860243
Compare
5094109
to
ed625aa
Compare
@pytorchbot merge -f "unrelated failures" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR has caused CI failures in linux-aarch64 workflow. I've raised an issue #136464 |
Fixes #136464 introduced in #134874 Pull Request resolved: #136474 Approved by: https://github.com/malfet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool ! Awesome to memory improvements already internally. Would be great if you could do a benchmark run, let me know if you're not aware how to. There were no major issues, just smaller things.
I would be curious if you have done any testing of the various methods. How they compare in terms of compilation time or memory and if any of the methods was strictly worse than the others.
I didnt get to the last few methods yet but submitting this review anyway and will do others in follow up.
selected_node = min( | ||
nodes_to_schedule, | ||
key=lambda node: ( | ||
max(live_memory + node.mpi_node.size, max_memory), | ||
node.mpi_node.size - node.mpi_node.memory_to_free, | ||
node.mpi_node.index, | ||
), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use a heap here ? instead of sorting on each iteration. for elements which are changed you can remove and reinsert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually thought about using a heap here but
live_memory
is updating from iteration to iteration,node.mpi_node.memory_to_free
could be changing from iteration to iteration since how much memory can be freed after scheduling a node depends on what nodes have already scheduled.
So if we were to use a heap, elements in the heap would need to be updated from iteration to iteration and it is not efficient to do so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the reason this isn't as bad is because it's O(n) * O(max_concurrently_schedulable_nodes).
But, just to discuss - the live memory is the same for every node. If live_memory + node.mpi_node.size < max_memory
, does it actually matter what we schedule ? I'm not sure we need the max here: max(live_memory + node.mpi_node.size, max_memory)
.
I think you could conceivably avoid updating the live memory and just update the memory_to_free ? anyways not needed but thinking aloud.
benchmark runHere is the spreadsheet with all benchmark models from torchbench, timm, and huggingface.
compilation timeTimed and logged the runtime for huggingface models and uploaded the results here. The additional compile time is less than 1 second for all models; and much less than 1 second for most models. |
@xuanzhang816 in case you're not aware, we also have a dashboard that you can use to get numbers. There's more details on https://pytorch.org/docs/stable/torch.compiler_performance_dashboard.html. local works as well but compile time on dashboard can be more stable and requires a bit less work locally. |
Fixes pytorch#136464 introduced in pytorch#134874 Pull Request resolved: pytorch#136474 Approved by: https://github.com/malfet
@eellison I am addressing the comments locally one by one and resolving them here to keep track of which one I have done. I will create a new PR after I go through everything. |
…#136596) [Inductor UT] Generalize device-bias code introduced from #134874 and fix unexpected success test cases. Fix #136595 Pull Request resolved: #136596 Approved by: https://github.com/EikanWang, https://github.com/jansel Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
@eellison To do this, I will need to setup ghstack, right? |
selected_node = min( | ||
nodes_to_schedule, | ||
key=lambda node: ( | ||
max(live_memory + node.mpi_node.size, max_memory), | ||
node.mpi_node.size - node.mpi_node.memory_to_free, | ||
node.mpi_node.index, | ||
), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the reason this isn't as bad is because it's O(n) * O(max_concurrently_schedulable_nodes).
But, just to discuss - the live memory is the same for every node. If live_memory + node.mpi_node.size < max_memory
, does it actually matter what we schedule ? I'm not sure we need the max here: max(live_memory + node.mpi_node.size, max_memory)
.
I think you could conceivably avoid updating the live memory and just update the memory_to_free ? anyways not needed but thinking aloud.
max(live_memory + node.mpi_node.size, max_memory), | ||
node.mpi_node.size - node.mpi_node.memory_to_free, | ||
node.mpi_node.index, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Surprising to me that the greedy algorithm is not to order by net change in memory.. This will choose the largest node that is below max memory.. then net difference.. would assume it would just order by net difference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess the reason this isn't as bad is because it's O(n) * O(max_concurrently_schedulable_nodes).
yes!!
If
live_memory + node.mpi_node.size < max_memory
, does it actually matter what we schedule?
In this case, whichever you schedule would not affect the peak memory immediately at the current step, but can have an effect on later steps.
I'm not sure we need the max here:
max(live_memory + node.mpi_node.size, max_memory)
.
The algorithm (given in the paper) selects the next node to schedule based on two numbers for every node:
(1) memory after output allocation, call it mem1
, which is live_memory + node.mpi_node.size
;
(2) memory after the node is executed and its last used buffers are freed, call it mem2
.
Among all nodes whose mem1
is below max_memory
, it selects the node with the lowest mem2
; and if such nodes do not exist, it selects the node with lowest mem1
.
The first element in the tuple -- max(live_memory + node.mpi_node.size, max_memory)
-- is evaluated to be max_memory
for all nodes whose mem1
is below max_memory
. For tiebreaking, we then use the second element in the tuple, which is effectively mem2
:
mem2
is live_memory + node.mpi_node.size - node.mpi_node.memory_to_free
. Since live_memory
is the same for all nodes, we simply use node.mpi_node.size - node.mpi_node.memory_to_free
for comparison.
Hope this makes sense.
I think you could conceivably avoid updating the live memory and just update the memory_to_free ? anyways not needed but thinking aloud.
Does the explanation above help with this?
) Addressing additional comments given in PR #134874 Pull Request resolved: #137205 Approved by: https://github.com/eellison
Motivations:
A topological order of the scheduler nodes that optimize the liveness of buffers can reduce the peak memory utilization. This has been observed and studied e.g., here and here.
Solutions:
Results:
On some models we can reduce the peak memory significantly:
In the presence of compiler based AC, peak memory can be further reduced:
Here is an internal use case.
Other infos:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang