Generalize catArray for contiguous inputs and dim != 0 by jamesr66a · Pull Request #17032 · pytorch/pytorch · GitHub

Conversation

@jamesr66a
Collaborator

I noticed that we were sinking a lot of time into cat operations in machine translation on CPU, and drilled down to us doing the cat element-by-element, even though all the inputs were contiguous. The reason was we were doing the cat along a dimension that was not 0, and that caused us to not use the fast memcpy branch. This PR generalizes that branch.
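The generalized fast path can be sketched at the Python level. This is a hypothetical illustration (the helper name `cat_contiguous` is mine, and the real kernel does the copies with `memcpy` in C rather than slice assignment): for contiguous inputs, cat along any `dim` reduces to a fixed number of contiguous block copies per input, where the number of copies is the product of the sizes before `dim` and each block spans `size[dim]` times the product of the sizes after `dim`.

```python
import torch

def cat_contiguous(tensors, dim):
    # Hypothetical sketch of the generalized fast path for contiguous inputs.
    sizes = list(tensors[0].shape)
    sizes[dim] = sum(t.shape[dim] for t in tensors)
    out = torch.empty(sizes, dtype=tensors[0].dtype)

    # outer: product of sizes before `dim` -- how many block copies per input.
    outer = 1
    for s in sizes[:dim]:
        outer *= s
    # inner: product of sizes after `dim` -- elements per row along `dim`.
    inner = 1
    for s in sizes[dim + 1:]:
        inner *= s

    # Each outer slice of the output is one contiguous row of this view.
    out_flat = out.view(outer, sizes[dim] * inner)
    offset = 0
    for t in tensors:
        block = t.shape[dim] * inner  # contiguous run length per outer slice
        # In the C kernel this slice assignment is a memcpy per outer slice.
        out_flat[:, offset:offset + block] = t.reshape(outer, block)
        offset += block
    return out
```

With `dim=0` this degenerates to `outer == 1`, i.e. the original single-memcpy-per-input branch; the PR's point is that the same copy structure works for any `dim` as long as the inputs are contiguous.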

Quick benchmark script:

```
import torch, time

tensors = [torch.rand(6, 2, 1024) for i in range(5)]

NITER = 1000
s = time.time()
for i in range(NITER):
    torch.cat(tensors, dim=1)
print('time per iter ', (time.time() - s) / NITER)
```

Before:

```
time per iter  8.089399337768554e-05
```

After:

```
time per iter  2.183413505554199e-05
```

@cpuhrsch
Contributor

Do our tests explicitly exercise this branch? If not, please add.

@jamesr66a
Collaborator Author

jamesr66a commented Feb 13, 2019

I believe this test exercises the branch:

https://github.com/pytorch/pytorch/blob/master/test/test_torch.py#L4203

EDIT: looks like the inputs may be noncontiguous, let me dig deeper

EDIT2: I put a print in the branch and ran that test and it printed out, so looks like it's tested.

@cpuhrsch
Contributor

Looks like it. Maybe double check and make sure the edge case you mentioned is covered as well.

In general we prefer to avoid modifying TH and instead port the function over to ATen. Could I ask you to spend a bit of time seeing how feasible that is?

@zdevito zdevito removed their request for review February 13, 2019 02:54
@gchanan
Contributor

gchanan commented Feb 13, 2019

I have a slight preference not to port this at the same time as we make the change. Doing both at once is harder to review, and if a bug is reported later it's less obvious where the problem was introduced.

int64_t outer = 1, inner = 1;

// Outer is the product of dimensions from the left up to (and not
// including the concatenation dimension). This becomes the number of times

nit: you want the ')' after including.
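Spelled out in Python (the helper name `split_bounds` is hypothetical, just to mirror the C snippet under review), the two loop bounds the comment describes are:

```python
from math import prod

def split_bounds(sizes, dim):
    # outer: product of dimensions strictly before the cat dimension --
    # the number of times the inner copy loop runs per input tensor.
    outer = prod(sizes[:dim])
    # inner: product of dimensions strictly after the cat dimension --
    # together with sizes[dim] this gives the contiguous run length copied.
    inner = prod(sizes[dim + 1:])
    return outer, inner
```

For the benchmark shapes above, `split_bounds([6, 2, 1024], 1)` gives `outer == 6` and `inner == 1024`, so each input contributes 6 contiguous runs of `2 * 1024` elements.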


@facebook-github-bot facebook-github-bot left a comment

@jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Feb 15, 2019
Pull Request resolved: pytorch/pytorch#17032

Differential Revision: D14090038

Pulled By: jamesr66a

fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
zou3519 added a commit to zou3519/pytorch that referenced this pull request Mar 12, 2019
Pull Request resolved: pytorch#17032

Differential Revision: D14090038

Pulled By: jamesr66a

fbshipit-source-id: 2c733a84915896008ac95f2233f44894bd2573de
@ezyang ezyang added the merged label Jun 25, 2019