[inductor][scheduler] reorder scheduler nodes after fusion to reduce peak memory by xuanzhang816 · Pull Request #134874 · pytorch/pytorch · GitHub

Conversation

@xuanzhang816
Contributor

@xuanzhang816 xuanzhang816 commented Aug 30, 2024

Motivations:
A topological order of the scheduler nodes that optimizes the liveness of buffers can reduce peak memory utilization. This has been observed and studied before, e.g., here and here.

Solutions:

  1. implement a peak memory estimator via liveness analysis (see the sketch after this list)
  2. implement a few memory-aware topological sorting algorithms and pick the order with the lowest estimated peak memory
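
A minimal sketch of what the liveness-based estimator in (1) can look like: a single pass over a candidate order that tracks live bytes. The names SchedNode, size, and last_use_of are illustrative assumptions, not the actual inductor data structures:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SchedNode:
    name: str
    size: int  # bytes allocated for this node's outputs
    last_use_of: List["SchedNode"] = field(default_factory=list)  # buffers whose last reader is this node


def estimate_peak_memory(order: List[SchedNode]) -> int:
    live = 0
    peak = 0
    for node in order:
        live += node.size       # outputs are allocated when the node runs
        peak = max(peak, live)  # measure the peak before anything is freed
        # buffers whose last use is this node become dead and are freed
        live -= sum(buf.size for buf in node.last_use_of)
    return peak
```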

Results:
On some models we can reduce the peak memory significantly:

| model | batch size | peak_memory baseline | peak_memory new | ratio (baseline / new) |
|---|---|---|---|---|
| alexnet | 128 | 1.17 | 0.99 | 1.19 |
| vgg16 | 64 | 4.10 | 3.57 | 1.15 |
| DebertaV2ForQuestionAnswering | 1 | 11.60 | 10.56 | 1.10 |

In the presence of compiler-based activation checkpointing (AC), peak memory can be further reduced:

| model | batch size | peak_memory baseline | peak_memory new | ratio (baseline / new) |
|---|---|---|---|---|
| AlbertForMaskedLM | 4 | 6.87 | 6.43 | 1.07 |
| AlbertForQuestionAnswering | 4 | 8.69 | 7.76 | 1.12 |
| MobileBertForQuestionAnswering | 128 | 4.67 | 3.90 | 1.20 |

Here is an internal use case.

Other info:

  • Model runtime is neutral because the reordering happens after fusion, so the memory saving comes for free.
  • Compile time overhead is minimal, as the algorithm is linear in the number of edges of the inductor graph. For all huggingface benchmark models, the additional compile time is less than 1 second.
  • There is no peak memory regression, since we only adopt a new order if the estimator predicts a lower peak memory (a sketch of this guard follows this list). The estimator is unaware of operators' working memory, but for large models the working memory should be negligible, and we have not observed any significant regression across our tests.
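
A hedged sketch of that "only adopt if better" guard, reusing the estimate_peak_memory sketch from the Solutions section; candidate generation and the real scheduler types are elided, and the function name pick_order is illustrative:

```python
# Keep the original order unless a candidate reordering has a strictly lower
# estimated peak memory, so the pass can never regress the estimate.
def pick_order(original, candidates, estimate_peak_memory):
    best_order = original
    best_peak = estimate_peak_memory(original)
    for order in candidates:
        peak = estimate_peak_memory(order)
        if peak < best_peak:  # require a strict improvement; ties keep the current best
            best_order, best_peak = order, peak
    return best_order
```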

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot

pytorch-bot bot commented Aug 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134874

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9e3ece0 with merge base 803ce50:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

@xuanzhang816 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@xuanzhang816 xuanzhang816 force-pushed the orm_implementation branch 2 times, most recently from e421103 to 4ffc8a0 Compare September 4, 2024 18:26
@xuanzhang816 xuanzhang816 force-pushed the orm_implementation branch 2 times, most recently from d1f6710 to 8787c12 Compare September 12, 2024 16:22
@xuanzhang816 xuanzhang816 marked this pull request as ready for review September 12, 2024 16:23
@xuanzhang816 xuanzhang816 requested a review from yf225 September 12, 2024 16:23
@xuanzhang816 xuanzhang816 changed the title from "Orm implementation" to "[inductor][scheduler] reorder scheduler nodes after fusion to reduce peak memory" Sep 12, 2024
Contributor

@yf225 yf225 left a comment

Thanks! Leaving some comments regarding organization.

@yf225 yf225 requested a review from eellison September 16, 2024 22:07
@xuanzhang816
Contributor Author

@pytorchbot merge -f "unrelated failures"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@robert-hardwick
Collaborator

This PR has caused CI failures in the linux-aarch64 workflow. I've raised issue #136464.

pytorchmergebot pushed a commit that referenced this pull request Sep 24, 2024
Fixes #136464 introduced in #134874

Pull Request resolved: #136474
Approved by: https://github.com/malfet
Contributor

@eellison eellison left a comment

Very cool! Awesome to see memory improvements already internally. It would be great if you could do a benchmark run; let me know if you're not aware how to. There were no major issues, just smaller things.

I would be curious whether you have done any testing of the various methods: how they compare in terms of compilation time or memory, and whether any of the methods was strictly worse than the others.

I didn't get to the last few methods yet, but I'm submitting this review anyway and will do the others in a follow-up.

Comment on lines +346 to +353
    selected_node = min(
        nodes_to_schedule,
        key=lambda node: (
            max(live_memory + node.mpi_node.size, max_memory),
            node.mpi_node.size - node.mpi_node.memory_to_free,
            node.mpi_node.index,
        ),
    )
Contributor

Could we use a heap here instead of sorting on each iteration? For elements which are changed, you can remove and reinsert them.

Contributor Author

@xuanzhang816 xuanzhang816 Sep 26, 2024

I actually thought about using a heap here, but

  • live_memory is updated from iteration to iteration, and
  • node.mpi_node.memory_to_free can change from iteration to iteration, since how much memory can be freed after scheduling a node depends on which nodes have already been scheduled.

So if we were to use a heap, its elements would need to be updated from iteration to iteration, and it is not efficient to do so.

Contributor

I guess the reason this isn't as bad is because it's O(n) * O(max_concurrently_schedulable_nodes).

But, just to discuss: the live memory is the same for every node. If live_memory + node.mpi_node.size < max_memory, does it actually matter what we schedule? I'm not sure we need the max here: max(live_memory + node.mpi_node.size, max_memory).

I think you could conceivably avoid updating the live memory and just update the memory_to_free? Anyway, not needed, just thinking aloud.

etaf added a commit that referenced this pull request Sep 25, 2024
fix an unexpected success test case.

ghstack-source-id: 523ef6a
Pull Request resolved: #136596
@xuanzhang816
Contributor Author

@eellison

benchmark run

Here is the spreadsheet with all benchmark models from torchbench, timm, and huggingface.

  • I ran each model two or three times and used the memory number from the last run for comparison.
  • I tested two settings: (1) default, and (2) compiler-based auto activation checkpointing with the memory budget set to 0.4 (a config sketch follows this list).
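
For reference, a sketch of how setting (2) can be enabled. I believe the relevant knob is activation_memory_budget in torch._functorch.config, but treat the exact flag as an assumption rather than the precise configuration used for these runs:

```python
import torch
from torch._functorch import config as functorch_config

# Assumed knob: compiler-based auto activation checkpointing with the memory
# budget set to 0.4 (a fraction of the memory needed to save all activations).
functorch_config.activation_memory_budget = 0.4

@torch.compile
def loss_fn(x, w):
    return torch.nn.functional.relu(x @ w).sum()

x = torch.randn(1024, 1024, requires_grad=True)
w = torch.randn(1024, 1024, requires_grad=True)
loss_fn(x, w).backward()  # the partitioner decides what to recompute under the budget
```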

compilation time

I timed and logged the compile-time overhead for the huggingface models and uploaded the results here. The additional compile time is less than 1 second for all models, and much less than 1 second for most models.

@eellison
Contributor

@xuanzhang816 in case you're not aware, we also have a dashboard that you can use to get numbers. bunny pt2dash . To submit a run of your own go to: https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml . In some cases you'll want to rebase to a commit that has been recently benchmarked.

There are more details at https://pytorch.org/docs/stable/torch.compiler_performance_dashboard.html. Running locally works as well, but compile time on the dashboard can be more stable and it requires a bit less work locally.

BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
@xuanzhang816
Contributor Author

@eellison I am addressing the comments locally one by one and resolving them here to keep track of which one I have done. I will create a new PR after I go through everything.

pytorchmergebot pushed a commit that referenced this pull request Sep 26, 2024
…#136596)

[Inductor UT] Generalize device-bias code introduced from #134874 and fix unexpected success test cases.
Fix #136595

Pull Request resolved: #136596
Approved by: https://github.com/EikanWang, https://github.com/jansel

Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
@xuanzhang816
Contributor Author

@xuanzhang816 in case you're not aware, we also have a dashboard that you can use to get numbers. bunny pt2dash . To submit a run of your own go to: https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml . In some cases you'll want to rebase to a commit that has been recently benchmarked.

There are more details at https://pytorch.org/docs/stable/torch.compiler_performance_dashboard.html. Running locally works as well, but compile time on the dashboard can be more stable and it requires a bit less work locally.

@eellison To do this, I will need to set up ghstack, right?

@xuanzhang816
Contributor Author

@eellison I am addressing the comments locally one by one and resolving them here to keep track of which one I have done. I will create a new PR after I go through everything.

I submitted the restructured code here -- #137205

Comment on lines +349 to +351
max(live_memory + node.mpi_node.size, max_memory),
node.mpi_node.size - node.mpi_node.memory_to_free,
node.mpi_node.index,
Contributor

It's surprising to me that the greedy algorithm does not order by net change in memory. This will choose the largest node that is below max memory, then the net difference; I would have assumed it would just order by net difference.

Contributor Author

@xuanzhang816 xuanzhang816 Oct 8, 2024

I guess the reason this isn't as bad is because it's O(n) * O(max_concurrently_schedulable_nodes).

yes!!

If live_memory + node.mpi_node.size < max_memory, does it actually matter what we schedule?

In this case, whichever node you schedule does not affect the peak memory at the current step, but it can affect later steps.

I'm not sure we need the max here: max(live_memory + node.mpi_node.size, max_memory).

The algorithm (given in the paper) selects the next node to schedule based on two numbers for every node:

(1) memory after output allocation, call it mem1, which is live_memory + node.mpi_node.size;
(2) memory after the node is executed and the buffers it is the last use of are freed, call it mem2.

Among all nodes whose mem1 is below max_memory, it selects the node with the lowest mem2; if no such node exists, it selects the node with the lowest mem1.

The first element in the tuple -- max(live_memory + node.mpi_node.size, max_memory) -- is evaluated to be max_memory for all nodes whose mem1 is below max_memory. For tiebreaking, we then use the second element in the tuple, which is effectively mem2:

mem2 is live_memory + node.mpi_node.size - node.mpi_node.memory_to_free. Since live_memory is the same for all nodes, we simply use node.mpi_node.size - node.mpi_node.memory_to_free for comparison.

Hope this makes sense.
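
For concreteness, the same selection rule can be written out explicitly in terms of mem1 and mem2. This is a sketch that mirrors the attribute names in the quoted code, not the actual scheduler implementation; it selects the same node as the min over the tuple key above:

```python
# Among nodes whose post-allocation memory (mem1) stays within max_memory,
# pick the one with the smallest post-free memory (mem2); if none fits,
# pick the one with the smallest mem1 (then mem2). Remaining ties are
# broken by the node's original index, as in the quoted code.
def select_node(nodes_to_schedule, live_memory, max_memory):
    def mem1(node):
        return live_memory + node.mpi_node.size

    def mem2(node):
        return mem1(node) - node.mpi_node.memory_to_free

    within_budget = [n for n in nodes_to_schedule if mem1(n) <= max_memory]
    if within_budget:
        return min(within_budget, key=lambda n: (mem2(n), n.mpi_node.index))
    return min(nodes_to_schedule, key=lambda n: (mem1(n), mem2(n), n.mpi_node.index))
```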

I think you could conceivably avoid updating the live memory and just update the memory_to_free? Anyway, not needed, just thinking aloud.

Does the explanation above help with this?

pytorchmergebot pushed a commit that referenced this pull request Oct 25, 2024

Addressing additional comments given in PR #134874

Pull Request resolved: #137205
Approved by: https://github.com/eellison