We have an open issue #52663 (and links therein) complaining about thrust/cub usage in libraries and applications linking to libtorch. In addition, with the 1.8 release we had an issue within pytorch itself (because of the cuda split build), where randperm didn't work (1.8), or randperm followed by sort didn't work (1.8.1).
At the root of those issues is cub's and/or thrust's use of local static variables in templated functions, which doesn't work across multiple libraries. The reason these issues appeared only in cuda 11 builds, and weren't a problem earlier, is that with cuda 11 thrust mostly goes through cub implementations and reuses cub's static local caches.
To stop the bleeding, we've put all the pytorch functions using cub/thrust into libtorch_cuda_cu (previously randperm was in libtorch_cuda_cpp), and wrapped cub in its own namespace (#55292). Unfortunately, thrust cannot be wrapped in its own namespace, so that is not a complete solution. With this change, scripts using only pytorch shouldn't error out (fixing the previous randperm + sort issues), and extensions using cub/thrust can be linked to libtorch, as long as code calling into libtorch doesn't call functions that use thrust in their internal implementations. To make this solution robust, we should avoid using thrust altogether (or at least the thrust functions that create or query static local cache entries).
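For reference, the wrapping relies on cub's namespace hook macros (the `at_cub` name below is illustrative; the macros are cub's real customization points). Defining them before any cub include moves all cub symbols, including its static caches, into a library-private namespace, so they cannot collide with an extension's copy of cub. Thrust exposes no equivalent hooks, which is why the same trick does not work for thrust:

```
// Must be defined before the first cub include (sketch, not pytorch's exact setup).
#define CUB_NS_PREFIX namespace at_cub {
#define CUB_NS_POSTFIX }
#include <cub/cub.cuh>
// cub entry points are now at_cub::cub::DeviceRadixSort, etc.,
// distinct from any other library's copy of cub.
```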
An additional argument for removing thrust calls is that thrust typically synchronizes even when synchronization is not needed.
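Concretely, cub's device-wide entry points use a two-phase, stream-ordered call pattern that never blocks the host, whereas a plain thrust call may synchronize to manage its temporary storage. A hedged sketch of the cub pattern (error handling elided; buffer names are illustrative):

```cuda
#include <cub/cub.cuh>

// Enqueue a radix sort on `stream` without any host synchronization.
void sort_keys_async(const int* d_in, int* d_out, int n, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: with d_temp == nullptr, cub only computes the workspace size.
  cub::DeviceRadixSort::SortKeys(d_temp, temp_bytes, d_in, d_out, n,
                                 0, sizeof(int) * 8, stream);
  // Stream-ordered allocation (CUDA 11.2+); a caching allocator such as
  // pytorch's would serve the same purpose on older toolkits.
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Pass 2: the actual sort, enqueued on `stream`; the host keeps going.
  cub::DeviceRadixSort::SortKeys(d_temp, temp_bytes, d_in, d_out, n,
                                 0, sizeof(int) * 8, stream);
  cudaFreeAsync(d_temp, stream);
}
```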
To that effect, we are migrating thrust usages to cub:
- randperm: Replace thrust with cub in randperm #53841
- sort: Partially migrate from THC to ATen, replace the thrust path with cub #54626; Implement torch.sort with cub::DeviceSegmentedRadixSort #56821
- index_put: Migrate thrust->cub for index put #55693 (cuda 11.3+ only, previous versions of cub are too slow)
- unique (no dim): [pytorch] fewer cuda sync in unique by using cub instead of thrust #57323
- masked_scatter: Migrate masked_scatter to use cub instead of thrust #56750
We should scope which other functions use static local caches and migrate them too. Also, we should not add new thrust usages.
Files where thrust sort functions remain (not exhaustive):
- TensorModeKernel.cu
- Embedding.cu (Migrate Embedding thrust sort to cub sort #62495, Embedding thrust->cub: unique #63042, Use cub 1.15's latest scan-by-key algorithm to replace thrust for Embedding.cu and EmbeddingBag.cu #66580)
- EmbeddingBackwardKernel.cu (EmbeddingBackward exclusive_scan thrust->cub #66566)
- EmbeddingBag.cu (EmbeddingBag sort thrust->cub #64498, Use cub 1.15's latest scan-by-key algorithm to replace thrust for Embedding.cu and EmbeddingBag.cu #66580)
- LegacyThrustHelper.cu (hosts thrust functions until we can switch to cub 11.3; cub 11.3 fixes a perf bug that prevents cub use for relatively small sorts. This also affects Embedding.cu and EmbeddingBag.cu.)
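Lists like the one above can be refreshed with a grep over the CUDA sources. The real target would be the checkout's CUDA kernel directory (path assumed); the snippet below demonstrates the pattern on a throwaway directory so it is self-contained:

```shell
# Sketch: list .cu files that still reference thrust::
# In a real checkout the target would be aten/src/ATen/native/cuda (assumed path).
tmp=$(mktemp -d)
mkdir -p "$tmp/native/cuda"
printf 'thrust::sort(policy, keys, keys + n);\n' > "$tmp/native/cuda/TensorModeKernel.cu"
printf '// already migrated to cub\n'            > "$tmp/native/cuda/Randperm.cu"
grep -rl --include='*.cu' 'thrust::' "$tmp/native/cuda"
rm -rf "$tmp"
```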