Generalize catArray for contiguous inputs and dim != 0 #17032
Conversation
Do our tests explicitly exercise this branch? If not, please add.

I believe this test exercises the branch: https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L4203
EDIT: looks like the inputs may be noncontiguous, let me dig deeper.
EDIT 2: I put a print in the branch and ran that test, and it printed out, so it looks like the branch is tested.

Looks like it. Maybe double-check that the edge case you mentioned is covered as well. In general we prefer to avoid modifying TH and instead port the function over to ATen. Could I ask you to spend a bit of time seeing how feasible that is?

I have a slight preference not to port this at the same time as we make changes. It's harder to review, and it's less obvious where the problem is if there's a bug report.
```
int64_t outer = 1, inner = 1;

// Outer is the product of dimensions from the left up to (and not
// including the concatenation dimension). This becomes the number of times
```
nit: you want the ')' after including.
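For illustration, here is a minimal Python sketch of the outer/inner decomposition that comment describes, assuming all inputs are contiguous, share a dtype, and `dim` is non-negative. The function and variable names are hypothetical and this is not the actual TH implementation:

```
import torch

def cat_contiguous_sketch(tensors, dim):
    # Illustrative sketch of the contiguous fast path: concatenate along `dim`
    # using `outer` block copies per input instead of element-by-element loops.
    assert all(t.is_contiguous() for t in tensors)
    out_shape = list(tensors[0].shape)
    out_shape[dim] = sum(t.shape[dim] for t in tensors)
    out = torch.empty(out_shape, dtype=tensors[0].dtype)

    # outer: product of sizes to the left of `dim` -> number of copies per input
    outer = 1
    for s in out_shape[:dim]:
        outer *= s

    out_flat = out.view(outer, -1)   # each row is one "outer" block of the output
    offset = 0
    for t in tensors:
        inner = t.numel() // outer   # elements each input contributes per block
        # In the C fast path, this slice assignment corresponds to `outer`
        # memcpy calls of `inner * elementSize` bytes each.
        out_flat[:, offset:offset + inner] = t.view(outer, inner)
        offset += inner
    return out
```

For the benchmark shapes in the summary below (`torch.rand(6, 2, 1024)` concatenated along `dim=1`), `outer` is 6 and each input contributes `inner = 2 * 1024` elements per block, so the output is filled with 5 × 6 block copies rather than per-element loops.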
@jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
I noticed that we were sinking a lot of time into `cat` operations in machine translation on CPU, and drilled down to find that we were doing the cat element by element, even though all the inputs were contiguous. The reason was that we were doing the cat along a dimension other than 0, which prevented use of the fast `memcpy` branch. This PR generalizes that branch.
Quick benchmark script:
```
import torch, time
tensors = [torch.rand(6, 2, 1024) for i in range(5)]
NITER = 1000
s = time.time()
for i in range(NITER):
    torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```
Before:
```
time per iter 8.089399337768554e-05
```
After:
```
time per iter 2.183413505554199e-05
```
Pull Request resolved: pytorch/pytorch#17032
Differential Revision: D14090038
Pulled By: jamesr66a
fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
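As a quick, illustrative sanity check that the fast path preserves semantics for contiguous inputs and `dim != 0` (not part of the PR's test suite; shapes match the benchmark above):

```
import torch

tensors = [torch.rand(6, 2, 1024) for _ in range(5)]
out = torch.cat(tensors, dim=1)
# Slicing the result back out along dim=1 should recover each input exactly.
for i, t in enumerate(tensors):
    assert torch.equal(out[:, 2 * i:2 * (i + 1), :], t)
```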