Improve CUDA softmax performance by apaszke · Pull Request #4973 · pytorch/pytorch

Conversation

@apaszke (Contributor) commented Jan 31, 2018

A simple fix with a very large performance benefit for smaller sizes. Below are some plots (dim_size = size of the softmaxed dimension, outer_size = batch size, z-axis = ratio of old time to new time). In general, as long as dim_size < 1024 you get at least a 2x speedup with this code, 4x if dim_size fits in 256, and as much as 12x for sizes around 100 and smaller.

[Two surface plots: ratio of old softmax time to new softmax time as a function of dim_size and outer_size.]
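For reference, a ratio like the one plotted can be reproduced with a simple timing harness. The sketch below is not part of the PR; it uses the libtorch C++ API, and the tensor sizes, iteration counts, and use of torch::cuda::synchronize are illustrative assumptions.

#include <torch/torch.h>
#include <chrono>
#include <cstdio>

int main() {
  // outer_size x dim_size; softmax is taken over the last dimension
  auto x = torch::randn({128, 256}, torch::kCUDA);
  for (int i = 0; i < 10; ++i) torch::softmax(x, /*dim=*/1);  // warm-up
  torch::cuda::synchronize();
  auto t0 = std::chrono::steady_clock::now();
  for (int i = 0; i < 1000; ++i) torch::softmax(x, /*dim=*/1);
  torch::cuda::synchronize();  // drain queued kernels before stopping the clock
  auto t1 = std::chrono::steady_clock::now();
  std::printf("%.3f us/iter\n",
              std::chrono::duration<double, std::micro>(t1 - t0).count() / 1000);
}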

I tried some other potential improvements, such as replacing the blockReduce function with a shuffle-based one, but they gave mixed results (-20% time in some cases, +20% in others).
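For context, the shuffle-based alternative refers to reducing within a warp through register shuffles instead of shared memory. A minimal sketch of the pattern follows (warpReduceSum is a hypothetical name; __shfl_down_sync is the real CUDA 9 intrinsic, and the whole block reduction would still need a shared-memory step to combine warps):

// Sum a value across the 32 lanes of a warp using register shuffles.
template <typename T>
__device__ T warpReduceSum(T val) {
  // Each step folds the upper half of the active lanes into the lower half.
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;  // lane 0 ends up holding the full warp sum
}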

Thanks to @nikitakit for reporting #4893 (which is fixed in this PR).

@apaszke apaszke closed this Feb 1, 2018
@apaszke apaszke reopened this Feb 1, 2018
@apaszke apaszke closed this Feb 1, 2018
@apaszke apaszke reopened this Feb 1, 2018
@apaszke apaszke requested a review from colesbury February 1, 2018 16:51
@apaszke (Contributor, Author) commented Feb 1, 2018

All the build failures are spurious and unrelated to this PR.

@colesbury (Member) left a comment


Nice speed-up!


inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(1024));
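The quoted lines are the start of the new block-size helper, which caps the launch at min(dim_size / ILP, 1024) threads instead of always using 1024-thread blocks. A plausible completion of the heuristic, as a sketch (the power-of-two rounding and the one-warp floor are assumptions, not necessarily the exact merged code):

#include <algorithm>       // std::min, std::max
#include <cstdint>         // uint64_t
#include <cuda_runtime.h>  // dim3

inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(1024));
  // Double block_size until it covers max_block_size, keeping it a power of two.
  while (block_size < max_block_size) block_size *= 2;
  // Launch at least one full warp; the kernel's reductions assume that.
  block_size = std::max(block_size, static_cast<uint64_t>(32));
  return dim3(block_size);
}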


@apaszke apaszke merged commit 8e22f84 into master Feb 2, 2018
@apaszke apaszke deleted the softmax_speedup branch February 2, 2018 12:24
@soumith soumith added the 0.3.1 label Feb 5, 2018
