Improve embedding_bag add kernel by jamesr66a · Pull Request #19329 · pytorch/pytorch · GitHub

Conversation

@jamesr66a (Collaborator) commented Apr 17, 2019

This PR makes embedding_bag call into the high-performance EmbeddingLookup function from Caffe2 (C2) for add mode. That kernel is highly optimized and dispatches dynamically based on cpuid. The PR also elides creating the offset2bag tensor on the fast path; that operation was contributing a significant portion of the operator's runtime.
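For context, here is a minimal scalar sketch of the kind of kernel the fast path delegates to. The parameter shape loosely mirrors caffe2/perfkernels/embedding_lookup.h (embedding dim, bag count, index count, table, indices, per-bag lengths, optional per-index weights), but the function name and scalar body are illustrative stand-ins, not the optimized perfkernels code:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar stand-in mirroring the shape of caffe2::EmbeddingLookup as discussed
// in this thread (block_size = embedding dim, one length per bag, optional
// per-index weights). The real perfkernels version is generated per-ISA and
// selected at runtime via cpuid; this loop is only a reference.
void embedding_lookup_ref(
    int64_t block_size,      // embedding dimension
    int64_t output_size,     // number of bags
    int64_t index_size,      // total number of indices
    const float* input,      // [data_size, block_size] embedding table
    const int64_t* indices,  // [index_size]
    const int* lengths,      // [output_size], indices per bag
    const float* weights,    // optional per-index scale, may be nullptr
    float* out) {            // [output_size, block_size]
  int64_t pos = 0;
  for (int64_t bag = 0; bag < output_size; ++bag) {
    float* out_row = out + bag * block_size;
    for (int64_t d = 0; d < block_size; ++d) out_row[d] = 0.f;
    for (int i = 0; i < lengths[bag] && pos < index_size; ++i, ++pos) {
      const float* in_row = input + indices[pos] * block_size;
      const float w = weights ? weights[pos] : 1.f;
      for (int64_t d = 0; d < block_size; ++d) out_row[d] += w * in_row[d];
    }
  }
}

int main() {
  // 4 embeddings of dim 3, two bags: {0, 2} and {1, 3, 3}.
  std::vector<float> table = {0, 0, 0,  1, 1, 1,  2, 2, 2,  3, 3, 3};
  std::vector<int64_t> indices = {0, 2, 1, 3, 3};
  std::vector<int> lengths = {2, 3};
  std::vector<float> out(2 * 3);
  embedding_lookup_ref(3, 2, 5, table.data(), indices.data(), lengths.data(),
                       nullptr, out.data());
  std::printf("bag0: %.0f, bag1: %.0f\n", out[0], out[3]);  // 2 and 7
  return 0;
}
```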

Benchmark script (with a hack to make jit not DCE/CSE): https://gist.github.com/jamesr66a/73baed47400dcf2221bad996c4b57782

=== Baseline (master) ===

time_per_iter 8.726580142974853e-05
GB/s 4.588097438402839

=== Test (this PR) ===

time_per_iter 2.5180697441101074e-05
GB/s 15.900433295643158

@facebook-github-bot (Contributor) left a comment

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@jamesr66a jamesr66a changed the title Don't use AXPY for embedding_bag add [WIP] Improve embedding_bag add kernel Apr 17, 2019
@jamesr66a jamesr66a force-pushed the no_axpy branch 4 times, most recently from 51acc47 to f9558a1 Compare April 17, 2019 22:25
@facebook-github-bot (Contributor) left a comment

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@jamesr66a jamesr66a force-pushed the no_axpy branch 2 times, most recently from 02bf05d to 0659677 Compare April 18, 2019 03:50
@facebook-github-bot (Contributor) left a comment

@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

wait, you don't use lengths here at all, do you?

Collaborator Author

This is a WIP. I'm trying to get the FBCode build disaster fixed before I start fixing the kernel implementation.

Collaborator

Maybe add a TODO to just change the underlying kernel to take offsets, not lengths?
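For the next reader, a small self-contained sketch of the offsets-to-lengths conversion this implies (the helper name and code are illustrative, not from this PR); it also covers the empty-offsets case raised later in the review:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical helper (not code from this PR): embedding_bag receives bag
// start offsets, while the perfkernels lookup consumes per-bag lengths, so
// the fast path has to convert between the two representations.
std::vector<int> offsets_to_lengths(const std::vector<int64_t>& offsets,
                                    int64_t index_size) {
  std::vector<int> lengths(offsets.size());
  for (size_t i = 0; i + 1 < offsets.size(); ++i) {
    lengths[i] = static_cast<int>(offsets[i + 1] - offsets[i]);
  }
  if (!offsets.empty()) {  // empty batch: no bags, nothing to fill in
    lengths.back() = static_cast<int>(index_size - offsets.back());
  }
  return lengths;
}

int main() {
  auto lengths = offsets_to_lengths({0, 2, 5}, 7);
  std::printf("%d %d %d\n", lengths[0], lengths[1], lengths[2]);  // 2 3 2
  return 0;
}
```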

Collaborator

EmbeddingLookup supports a weighted version too (see the weights argument), so let's call it here.

@jamesr66a jamesr66a changed the title [WIP] Improve embedding_bag add kernel Improve embedding_bag add kernel Apr 18, 2019
Collaborator

It assumes that offsets is non-empty. That's probably usually the case, but it would be better to handle an empty batch too (it seems we'd need to fix functional.py as well).

Collaborator

You can probably save even more compute by skipping the make_offset2bag vector in the case we specialize here. Otherwise we compute data derived from offsets twice.

At the very least, add a TODO here so that the next person knows about it.

Collaborator Author

I think I'm going to do that. make_offset2bag takes a ton of runtime in my tests, and if we can get rid of it we can probably squeeze more perf out of this. Thanks for the tip.
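For context on what is being elided, a rough illustrative sketch (not the ATen implementation) of an offset2bag mapping and why building it costs an extra pass:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only (not the ATen implementation): offset2bag maps each index
// position to the bag it belongs to, e.g. offsets = [0, 2, 5] over 7 indices
// gives [0, 0, 1, 1, 1, 2, 2]. Building it is an extra O(num_indices) pass
// plus memory traffic, which the perfkernels fast path avoids by walking bags
// via per-bag lengths directly.
std::vector<int64_t> make_offset2bag_sketch(const std::vector<int64_t>& offsets,
                                            int64_t num_indices) {
  std::vector<int64_t> offset2bag(num_indices, 0);
  for (size_t bag = 0; bag < offsets.size(); ++bag) {
    int64_t end = (bag + 1 < offsets.size()) ? offsets[bag + 1] : num_indices;
    for (int64_t i = offsets[bag]; i < end; ++i) {
      offset2bag[i] = static_cast<int64_t>(bag);
    }
  }
  return offset2bag;
}

int main() {
  for (int64_t b : make_offset2bag_sketch({0, 2, 5}, 7)) {
    std::printf("%lld ", static_cast<long long>(b));  // 0 0 1 1 1 2 2
  }
  std::printf("\n");
  return 0;
}
```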

Summary:
This was actually getting pretty poor throughput with respect to memory bandwidth. I used this test to measure the memory bandwidth specifically for the AXPY call: https://gist.github.com/jamesr66a/b27ff9ecbe036eed5ec310c0a3cc53c5

And I got ~8 GB/s before this change, but ~14 GB/s after this change.

This seems to speed up the operator overall by around 1.3x (benchmark: https://gist.github.com/jamesr66a/c533817c334d0be432720ef5e54a4166):

== Before ==

time_per_iter 0.0001298875093460083
GB/s 3.082544287868467

== After ==

time_per_iter 0.00010104801654815674
GB/s 3.9623142905451076

The large difference between the local BW increase and the full-op BW increase likely indicates significant time is being spent elsewhere in the op, so I will investigate that.

EDIT: I updated this PR to include a call into caffe2/perfkernels. This is the progression:

== Before ==
time_per_iter 8.983819484710693e-05
GB/s 4.456723564864611

== After (no AXPY) ==
time_per_iter 7.19951868057251e-05
GB/s 5.56126065872172

== After (perfkernels) ==
time_per_iter 5.6699180603027346e-05
GB/s 7.061548257694262

== After (perfkernels, no grad) ==
time_per_iter 4.388842582702637e-05
GB/s 9.122769670026413
Pull Request resolved: pytorch#19329

Differential Revision: D14969630

fbshipit-source-id: 359b48eac218463d4ff13bdf22d31c70bf35281d
self._test_EmbeddingBag(False, 'sum', True)
self._test_EmbeddingBag(False, 'mean', True)
for dtype in [torch.double, torch.float]:
# TODO: figure out why backward on float breaks
Contributor

I don't think it's "broken" on backward; it's just that the precision on the test needs to be adjusted because the gradients are so large. If this hypothesis is true, then test_embedding_bag should fail for mode=sum but not fail for mode={max,mean}.

Collaborator Author

@zou3519 IME it fails for all of sum, max, and mean.

@zou3519 (Contributor) commented Apr 19, 2019

What is the magnitude of the gradient for "max" and "mean"? If it is on the order of 1, then there are definitely precision issues.
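For what it's worth, a tiny self-contained illustration of the precision hypothesis (generic float32 behavior, not taken from the test itself):

```cpp
#include <cstdio>

// Generic float32 behavior: once accumulated values get large, float32
// (~7 significant decimal digits) can no longer represent small increments,
// so absolute tolerances tuned for small gradients start failing.
int main() {
  float big = 16777216.0f;    // 2^24
  float bumped = big + 1.0f;  // rounds back to 2^24 in float32
  std::printf("big + 1 == big ? %s\n", bumped == big ? "yes" : "no");
  return 0;
}
```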

    const Tensor& offsets) {
  int64_t ddim = src.size(1);
  auto* scale_data = scale.data<float>();
  auto select_indices_data = select_indices.data<int64_t>();
Contributor

Do we actually assert that these data types match somewhere?

Collaborator Author

@cpuhrsch I think this does it: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/EmbeddingBag.cpp#L214

Also there's a PR out to fix that location: #19432


#include <TH/THBlasUtils.h>

#include <caffe2/perfkernels/embedding_lookup.h>
Contributor

How does the dispatch work? Keep in mind that this file isn't compiled with AVX.

Collaborator

It's a similar mechanism to aten/DISPATCH; it does the right thing with different compiler options per variant and runtime cpuid dispatch between them: https://github.com/pytorch/pytorch/blob/master/caffe2/perfkernels/common.h

We should bring them together, but that's for another day.
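For readers unfamiliar with that file, a minimal sketch of the general pattern (illustrative names and a trivial kernel, not the actual macros in caffe2/perfkernels/common.h): each ISA variant is compiled in its own translation unit with matching flags, and a thin wrapper picks one at runtime from CPU features:

```cpp
#include <cstdio>

// Scalar fallback, built without any AVX flags.
static void scale_base(const float* x, float* y, int n, float a) {
  for (int i = 0; i < n; ++i) y[i] = a * x[i];
}

// In the real layout this would live in a separate file built with
// -mavx2 -mfma; it reuses the scalar body here so the sketch stays
// self-contained and compilable anywhere.
static void scale_avx2(const float* x, float* y, int n, float a) {
  scale_base(x, y, n, a);
}

// Thin wrapper: choose the best available variant at runtime.
static void scale_dispatch(const float* x, float* y, int n, float a) {
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
  if (__builtin_cpu_supports("avx2")) return scale_avx2(x, y, n, a);
#endif
  scale_base(x, y, n, a);
}

int main() {
  float x[4] = {1, 2, 3, 4}, y[4];
  scale_dispatch(x, y, 4, 2.0f);
  std::printf("%.1f %.1f %.1f %.1f\n", y[0], y[1], y[2], y[3]);
  return 0;
}
```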

@soumith (Member) commented Apr 19, 2019

This also calls caffe2::EmbeddingLookup for the GPU. Is the perf of the Caffe2 kernel better than or equal to ours on GPU? (I remember we optimized the GPU kernel quite a bit on the ATen side.)

@cpuhrsch (Contributor)
@jamesr66a, is it worthwhile to try a few more input shapes and thread counts for the benchmark?

@jamesr66a (Collaborator, Author)
@soumith does this actually get called on the GPU? I thought the only call sites for index_select_add and index_select_scale_add were in _embedding_bag_cpu.

@soumith (Member) commented Apr 19, 2019

@jamesr66a great, it was my bad not to check that.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Apr 20, 2019
@jamesr66a merged this pull request in d17c22d.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019