Improve embedding_bag add kernel #19329
Conversation
@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 51acc47 to f9558a1
@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 02bf05d to 0659677
@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
wait, you don't use lengths here at all, do you?
This is a WIP. I'm trying to get the FBCode build disaster fixed before I start fixing the kernel implementation.
maybe add a TODO to just change the underlying kernel to take offsets, not lengths?
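As a sketch of that suggestion (pure Python, hypothetical helper names, not the actual kernel code): offsets are just the exclusive prefix sum of lengths, so a kernel that takes offsets directly loses no information.

```python
def lengths_to_offsets(lengths):
    """Prefix-sum per-bag lengths into start offsets (one entry per bag)."""
    offsets = []
    running = 0
    for n in lengths:
        offsets.append(running)
        running += n
    return offsets

def offsets_to_lengths(offsets, total):
    """Inverse: recover per-bag lengths from offsets plus the total index count."""
    ends = offsets[1:] + [total]
    return [end - start for start, end in zip(offsets, ends)]
```

For example, bags of lengths [2, 3, 1] correspond to offsets [0, 2, 5], and the round trip recovers the lengths.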
EmbeddingLookup supports the weighted version too (see the weights argument), so let's call it here.
It assumes that offsets is non-empty. That's probably usually the case, but it's better to work on an empty batch too (it seems we'd need to fix functional.py as well).
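A pure-Python sketch (illustrative only, not the actual ATen code) of sum-mode pooling that handles the empty-batch case explicitly rather than assuming offsets is non-empty:

```python
def embedding_bag_sum(weight, indices, offsets):
    """weight: list of embedding rows; indices: flat index list;
    offsets: start position of each bag within `indices`."""
    # Handle the empty batch up front instead of assuming offsets is non-empty.
    if not offsets:
        return []
    dim = len(weight[0])
    bags = []
    for b, start in enumerate(offsets):
        end = offsets[b + 1] if b + 1 < len(offsets) else len(indices)
        acc = [0.0] * dim
        for i in indices[start:end]:
            for d in range(dim):
                acc[d] += weight[i][d]
        bags.append(acc)
    return bags
```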
You can probably save even more compute by skipping the make_offset2bag vector in the case we specialize here. Otherwise we derive from offsets twice.
At least add a TODO here so that the next person knows it.
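For context, a minimal pure-Python sketch (hypothetical, not the ATen implementation) of what the offset2bag vector contains: for each flat index position, the bag it belongs to. A kernel that walks offsets directly never needs to materialize it.

```python
def make_offset2bag(offsets, num_indices):
    """Map each flat index position to the bag it falls in."""
    offset2bag = []
    for bag in range(len(offsets)):
        end = offsets[bag + 1] if bag + 1 < len(offsets) else num_indices
        # Every position from offsets[bag] up to end belongs to this bag.
        offset2bag.extend([bag] * (end - offsets[bag]))
    return offset2bag
```

With offsets [0, 2, 5] over 6 indices this yields [0, 0, 1, 1, 1, 2].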
Uh, I think I'm gonna do that. make_offset2bag takes a ton of runtime in my tests, and if we can get rid of it we can probably squeeze more perf out of this. Thanks for the tip.
Summary: This was actually getting pretty poor throughput with respect to memory bandwidth. I used this test to measure the memory bandwidth specifically for the AXPY call: https://gist.github.com/jamesr66a/b27ff9ecbe036eed5ec310c0a3cc53c5

I got ~8 GB/s before this change, but ~14 GB/s after this change. This seems to speed up the operator overall by around 1.3x (benchmark: https://gist.github.com/jamesr66a/c533817c334d0be432720ef5e54a4166):

== Before ==
time_per_iter 0.0001298875093460083
GB/s 3.082544287868467

== After ==
time_per_iter 0.00010104801654815674
GB/s 3.9623142905451076

The large difference between the local BW increase and the full-op BW increase likely indicates significant time is being spent elsewhere in the op, so I will investigate that.

EDIT: I updated this PR to include a call into caffe2/perfkernels. This is the progression:

before: time_per_iter 8.983819484710693e-05, GB/s 4.456723564864611
after, no axpy: time_per_iter 7.19951868057251e-05, GB/s 5.56126065872172
after, perfkernels: time_per_iter 5.6699180603027346e-05, GB/s 7.061548257694262
after, perfkernels, no grad: time_per_iter 4.388842582702637e-05, GB/s 9.122769670026413

Pull Request resolved: pytorch#19329
Differential Revision: D14969630
fbshipit-source-id: 359b48eac218463d4ff13bdf22d31c70bf35281d
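For reference, the GB/s figures above come from dividing bytes moved per iteration by time per iteration. A trivial sketch of that arithmetic (the 4 MB figure in the example is illustrative, not taken from the benchmark):

```python
def effective_bandwidth_gb_s(bytes_moved, time_per_iter):
    """Effective bandwidth: bytes moved per iteration over seconds per
    iteration, expressed in GB/s (10**9 bytes per second)."""
    return bytes_moved / time_per_iter / 1e9

# e.g. moving 4 MB per iteration in 1 ms corresponds to 4 GB/s
rate = effective_bandwidth_gb_s(4e6, 1e-3)
```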
self._test_EmbeddingBag(False, 'sum', True)
self._test_EmbeddingBag(False, 'mean', True)
for dtype in [torch.double, torch.float]:
    # TODO: figure out why backward on float breaks
I don't think it's "broken" on backward; it's just that the precision on the test needs to be adjusted because the gradients are so large. If this hypothesis is true, then test_embedding_bag should fail for mode=sum but not fail for mode={max,mean}.
@zou3519 IME it fails for all of sum, max, and mean
What is the magnitude of the gradient for "max" and "mean"? If it is on the order of 1, then there are definitely precision issues.
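To illustrate the precision hypothesis from this thread, here is a small pure-Python experiment that emulates single-precision accumulation by round-tripping through `struct` (a stand-in for a float32 kernel, not the actual test code). Summing many terms drifts measurably from the double-precision result, which is why a sum-mode test with large gradients needs a tolerance scaled to the magnitude:

```python
import struct

def to_f32(x):
    """Round a Python float to IEEE-754 single precision via a pack round-trip."""
    return struct.unpack('f', struct.pack('f', x))[0]

def sum_f32(values):
    """Naive accumulation with float32 rounding after every add,
    mimicking a single-precision kernel's running sum."""
    acc = 0.0
    for v in values:
        acc = to_f32(acc + to_f32(v))
    return acc

n = 10_000
f32_total = sum_f32([0.1] * n)   # float32-style accumulation
f64_total = sum([0.1] * n)       # Python floats are double precision
drift = abs(f32_total - f64_total)
```

The drift is small in absolute terms but far above double-precision noise, so a fixed tight tolerance would spuriously fail.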
const Tensor& offsets) {
  int64_t ddim = src.size(1);
  auto* scale_data = scale.data<float>();
  auto select_indices_data = select_indices.data<int64_t>();
Do we actually assert that these data types match somewhere?
@cpuhrsch i think this does it: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/EmbeddingBag.cpp#L214
Also there's a PR out to fix that location: #19432
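For illustration, a hypothetical sketch (names invented here, not ATen's API) of the kind of dtype guard being discussed: fail loudly before handing a raw buffer to a kernel expecting a different element type.

```python
def checked_data(tensor_like, expected_dtype):
    """tensor_like is a (dtype_name, buffer) pair in this toy model.
    Raise instead of silently reinterpreting the buffer."""
    dtype, buf = tensor_like
    if dtype != expected_dtype:
        raise TypeError(f"expected {expected_dtype}, got {dtype}")
    return buf
```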
#include <TH/THBlasUtils.h>
#include <caffe2/perfkernels/embedding_lookup.h>
How does the dispatch work? Keep in mind that this file isn't compiled with AVX.
It's a similar mechanism to aten/DISPATCH: it does the right thing with different compiler options and runtime cpuid dispatch between them: https://github.com/pytorch/pytorch/blob/master/caffe2/perfkernels/common.h
We should bring them together, but that's for another day.
This also calls caffe2::EmbeddingLookup for the GPU. Is the perf of the caffe2 kernel better or equal on GPU? (I remember we optimized the GPU kernel quite a bit on the ATen side.)
@jamesr66a - is it worthwhile to try a few more input shapes and numbers of threads for the benchmark?
@soumith does this actually get called on GPU? I thought the only callsites for …
@jamesr66a great, it was my bad not to check that.
@jamesr66a merged this pull request in d17c22d.
This PR makes it so that we call into the high-performance EmbeddingLookup function from C2 within embedding_bag for add mode. This is highly optimized and does dynamic dispatch based on cpuid. It also makes it so that we elide creating the offset2bag tensor in the fast-path case. That operation was contributing a significant portion of runtime to the operator.

Benchmark script (with a hack to make jit not DCE/CSE): https://gist.github.com/jamesr66a/73baed47400dcf2221bad996c4b57782
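A pure-Python sketch (illustrative, not the actual kernel) of the offset2bag elision described above: both paths compute the same per-bag sums, but the fast path walks offsets directly and never materializes the auxiliary vector.

```python
def bag_sums_via_offset2bag(values, offsets):
    """Slow path: first build offset2bag, then scatter-add through it."""
    n = len(values)
    offset2bag = []
    for b in range(len(offsets)):
        end = offsets[b + 1] if b + 1 < len(offsets) else n
        offset2bag.extend([b] * (end - offsets[b]))
    out = [0.0] * len(offsets)
    for pos, v in enumerate(values):
        out[offset2bag[pos]] += v
    return out

def bag_sums_direct(values, offsets):
    """Fast path: walk offsets directly; no offset2bag is ever built."""
    n = len(values)
    out = []
    for b, start in enumerate(offsets):
        end = offsets[b + 1] if b + 1 < len(offsets) else n
        out.append(sum(values[start:end]))
    return out
```

Skipping the extra pass and the extra allocation is where the fast-path savings come from.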
=== Baseline (master) ===
=== Test (this PR) ===