KEMBAR78
Optimize increment summations [Latest Nov 15] by laithsakka · Pull Request #140822 · pytorch/pytorch · GitHub
Skip to content

Conversation

@laithsakka
Copy link
Contributor

@laithsakka laithsakka commented Nov 15, 2024

Summary:
wins
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on).

buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200

This diff optimizes construction expressions of the form
a+b+c... (all unique symbols).
which are very common in torchrec models.

How
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have a+b+c and we are adding (d) to it, we can do a binary search to know
the position of (d) and avoid optimizing the new expression by passing the new order.

Extensions:

  1. support constant terms.
  2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Differential Revision: D66008482

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @ezyang @SherlockNoMad @EikanWang @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140822

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 1 New Failure

As of commit a2b082e with merge base b740a1b (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor module: cpu CPU specific problem (e.g., perf, algorithm) release notes: fx release notes category labels Nov 15, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

@laithsakka
Copy link
Contributor Author

offline discussion todos:

  1. document why we are using getattr(self, "_optimized_summation", False), on the symNode.
  2. add micro benchmark.

laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 15, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Differential Revision: D66008482
laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 15, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Differential Revision: D66008482
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 15, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Differential Revision: D66008482
laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 15, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Differential Revision: D66008482
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

self.expr,
other.expr,
getattr(self, "_optimized_summation", False),
getattr(other, "_optimized_summation", False),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed this in person, where you defended the ad hoc getattr/setattr because it was only set and accessed from two places. I think my preferred way of documenting this sort of situation is to have a # Note [blah blah blah] in one location describing the situation (probably the comment below), and then referencing that note consistently at both use sites.

The important thing to document, which is not directly documented at either of these sites, is what exactly the invariant specified by optimized summation is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also important to annotate the field on SymNode, if only to make sure no one accidentally clobbers it if they are adding their own one off field. I am also still not all that happy with bodging it in SymNode but I will refrain from blocking on it unless I can think of a good alternative (besides rewriting Add from scratch).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think I would even be happy with "this is a subclass of Add and is identical to Add in all respects except it respects the optimized summation invariant". This would probably have some annoying side effects in other parts of the code but from a layering perspective it's much cleaner.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i tried the subclass it did not work, because we could get an add to co-live with custom add and they dont compare to be equal

rhs._args[0]
):
# (a0+a1) + (a2+a3) => (a0+a1+a2+a3)
return make_optimized(lhs._args + rhs._args)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it cheap to test the other way too? (You have a cliff here if I accidentally swap the orders of the arguments to successive add here)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return RShift(a, b)


def _binary_search_insert_arg(ordered_args, new_arg):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert len(ordered_args) != 0

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks functionally correct, see also my comments on the other PR.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 17, 2024
laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 19, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Reviewed By: ezyang

Differential Revision: D66008482
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 19, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Reviewed By: ezyang

Differential Revision: D66008482
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

laithsakka added a commit to laithsakka/pytorch that referenced this pull request Nov 19, 2024
Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Reviewed By: ezyang

Differential Revision: D66008482
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

@laithsakka
Copy link
Contributor Author

Address all comments

Summary:

**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on). 
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form 
a+b+c...  (all unique symbols). 
which are very common in torchrec models. 

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know 
the position of (d) and avoid optimizing the new expression by passing the new order. 


**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Test Plan:
add tests
add benchmark 
run tests

Reviewed By: ezyang

Differential Revision: D66008482
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D66008482

@laithsakka
Copy link
Contributor Author

@pytorchbot merge -f

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 20, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot merge: error: argument -f/--force: expected one argument

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Try @pytorchbot --help for more info.

@laithsakka
Copy link
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: Lint / lintrunner-noclang / linux-job

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Ryo-not-rio pushed a commit to Ryo-not-rio/pytorch that referenced this pull request Dec 2, 2024
Summary:
**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on).
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form
a+b+c...  (all unique symbols).
which are very common in torchrec models.

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know
the position of (d) and avoid optimizing the new expression by passing the new order.

**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Differential Revision: D66008482

Pull Request resolved: pytorch#140822
Approved by: https://github.com/ezyang
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Summary:
**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on).
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form
a+b+c...  (all unique symbols).
which are very common in torchrec models.

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know
the position of (d) and avoid optimizing the new expression by passing the new order.

**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Differential Revision: D66008482

Pull Request resolved: pytorch#140822
Approved by: https://github.com/ezyang
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fb-exported fx Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: dynamo release notes: fx release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants