Reduce or eliminate thrust usage in pytorch · Issue #57505 · pytorch/pytorch · GitHub
Reduce or eliminate thrust usage in pytorch #57505

@ngimel

We have an open issue, #52663 (and links therein), about problems with thrust/cub usage in libraries and applications that link to libtorch. In addition, starting with the 1.8 release we have had an issue within pytorch itself (because of the cuda split build): randperm didn't work (1.8), or randperm followed by sort didn't work (1.8.1).
At the root of those issues is cub's and/or thrust's use of local static variables in templated functions, which doesn't work across multiple libraries. The reason those issues appeared only for cuda 11 builds, and weren't a problem earlier, is that with cuda 11 thrust mostly goes through cub implementations and reuses cub's static local caches.
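
To make the pattern concrete, here is a minimal single-file sketch of a templated function with a local static cache. It is not cub or thrust code: `cached_config`, `library_a_entry`, and `library_b_entry` are hypothetical names, and the two "libraries" are only simulated by two functions. The assumption being illustrated is that both callers resolve to the same instantiation, so whichever one runs first decides what the other sees.

```cpp
#include <cassert>

// Hypothetical stand-in for a header-only helper that computes a value once
// and keeps it in a function-local static.
template <typename T>
int cached_config(int computed_by_caller) {
    static int cache = -1;  // one object per template instantiation
    if (cache == -1) {
        cache = computed_by_caller;  // only the first caller fills the cache
    }
    return cache;
}

// Two simulated libraries that would each compute a different value,
// for example because they were built with different settings.
int library_a_entry() { return cached_config<float>(80); }
int library_b_entry() { return cached_config<float>(70); }
```

If `library_a_entry()` runs first, `library_b_entry()` also gets 80 rather than the 70 it computed itself, while a different instantiation such as `cached_config<double>(70)` has its own cache and returns 70.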
To stop the bleeding, we've put all the pytorch functions using cub/thrust into libtorch_cuda_cu (previously randperm was in libtorch_cuda_cpp), and wrapped cub in its own namespace (#55292). Unfortunately, thrust cannot be wrapped in its own namespace, so that's not a complete solution. With this change, scripts that use only pytorch shouldn't error out (fixing the previous randperm + sort issues), and extensions using cub/thrust can also be linked to libtorch, as long as the code calling into libtorch doesn't call functions that use thrust in their internal implementations. To make this solution robust, we should avoid using thrust altogether (or at least the thrust functions that create or query static local cache entries).
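
A rough sketch of why wrapping a header-only library in a private namespace helps: the wrapped copy gets different symbol names, so its local statics are separate objects from the ones an extension sees. Everything here is hypothetical (`FAKE_LIB_BODY`, `fakelib`, `app::detail`), and the macro only stands in for including the same header twice; it is not how #55292 is implemented.

```cpp
#include <cassert>

// Hypothetical header-only library body, written as a macro so this single
// file can "include" it twice. A real library would be an actual header.
#define FAKE_LIB_BODY              \
    namespace fakelib {            \
    inline int& counter() {        \
        static int value = 0;      \
        return value;              \
    }                              \
    }

// Unwrapped copy, as a third-party extension would see it.
FAKE_LIB_BODY

// Wrapped copy, placed inside a private namespace of the application.
namespace app { namespace detail {
FAKE_LIB_BODY
} }

// Each side increments its own, separate static.
int bump_extension() { return ++fakelib::counter(); }
int bump_app() { return ++app::detail::fakelib::counter(); }
```

Calling `bump_extension()` twice and then `bump_app()` once yields 1, 2, and 1: the two counters do not share state, because `fakelib::counter` and `app::detail::fakelib::counter` are distinct functions with distinct statics.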
An additional argument for removing thrust calls is that they typically synchronize even when it's not needed.
To that end, we are migrating thrust usages to cub:

We should also scope which other functions use static local caches and migrate them too. Finally, we should not add new thrust usages.

Files where thrust sort functions remain (not exhaustive):

cc @ngimel @malfet @zasdfgbnm
