KEMBAR78
Add beginnings of torch::stable::accelerator by mikaylagawarecki · Pull Request #159679 · pytorch/pytorch · GitHub
Skip to content

Conversation

@mikaylagawarecki
Copy link
Contributor

@mikaylagawarecki mikaylagawarecki commented Aug 1, 2025

Adds

  • torch::stable::accelerator::DeviceGuard: std::unique_ptr to DeviceGuardOpauqe mostly copied from the below (but made generic)

    class AOTICudaGuard {
    public:
    AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) {
    CUDAGuardHandle ptr = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(
    aoti_torch_create_cuda_guard(device_index, &ptr));
    guard_.reset(ptr);
    }
    void set_index(int32_t device_index) {
    AOTI_TORCH_ERROR_CODE_CHECK(
    aoti_torch_cuda_guard_set_index(guard_.get(), device_index));
    }
    private:
    std::unique_ptr<CUDAGuardOpaque, DeleterFnPtr> guard_;
    };

    • constructor DeviceGuard(DeviceIndex) (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device)
    • set_index(DeviceIndex)
  • torch::stable::accelerator::Stream: std::shared_ptr to StreamOpaque

    • constructor Stream(StreamHandle stream) (similar to torch::stable::Tensor)
    • id() -> StreamId
  • getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream

Stack from ghstack (oldest at bottom):

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159679

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 5e8e158 with merge base 34ec5ed (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

mikaylagawarecki added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: 1530b24
Pull Request resolved: #159679
@github-actions
Copy link
Contributor

github-actions bot commented Aug 1, 2025

Attention! PyTorch one of the C-stable API file was changed

You MUST NOT change existing function declarations in this, as this header defines a stable C ABI. If you need to change the signature for a function, introduce a new v2 version of the function and modify code generation to target the new version of the function.


Caused by:

mikaylagawarecki added a commit that referenced this pull request Aug 1, 2025
ghstack-source-id: 2f2fa7b
Pull Request resolved: #159679
@mikaylagawarecki mikaylagawarecki changed the title Add beginning of torch::stable::accelerator Add beginnings of torch::stable::accelerator Aug 1, 2025
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: fae00fc
Pull Request resolved: #159679
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 083b8bc
Pull Request resolved: #159679

using DeviceIndex = int8_t;
using StreamId = int64_t;
class DeviceGuard {
Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something that I'm not sure about -- do the copy / move semantics need to match the real device guard although this just a wrapper that is holding a std::unique_ptr to DeviceGuardOpaque?

~DeviceGuard() = default;
/// Copy is disallowed
DeviceGuard(const DeviceGuard&) = delete;
DeviceGuard& operator=(const DeviceGuard&) = delete;
/// Move is disallowed, as DeviceGuard does not have an uninitialized state,
/// which is required for moves on types with nontrivial destructors.
DeviceGuard(DeviceGuard&& other) = delete;
DeviceGuard& operator=(DeviceGuard&& other) = delete;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need to copy the existing DeviceGuard. For example, I agree with deleting the default constructor for this API. We can disallow copy and move if it's easiest to maintain anyway.

mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: afc07a8
Pull Request resolved: #159679
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 23b8be1
Pull Request resolved: #159679
Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from 

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 655868b
Pull Request resolved: #159679
Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from 

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 94eaf15
Pull Request resolved: #159679
@albanD
Copy link
Collaborator

albanD commented Aug 4, 2025

FYI @guangyey @EikanWang in case you have some feedback here.

Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` copied from 

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Aug 4, 2025
ghstack-source-id: 2ede272
Pull Request resolved: #159679
Copy link
Collaborator

@guangyey guangyey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.

@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review August 5, 2025 17:31
if torch.cuda.is_available():
extra_compile_args["cxx"].append("-DUSE_CUDA")
extension = CUDAExtension

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm maybe this would be a good call to move the CUDA stuff into its own C++ file and build them separately..

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I don't see the issue with the current approach, so gonna keep it as is unless there's something specific you're concerned about! :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No issues yet, but if we have more "cuda only" code in the future it may be a good way to split the ever growing kernel.cpp

});
}

AOTITorchError aoti_torch_create_device_guard(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +1643 to +1644
c10::Stream* stream_ptr = new c10::Stream(stream);
*ret_stream = reinterpret_cast<StreamHandle>(stream_ptr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a new stream on the heap is gonna leak memory, we should be able to assign the stream in line 1642 into the pointer (tho idk the cast semantics).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahhh after reading the rest of the code, I see what the user flow is intended to be. That said, I'm not sure if we want the semantics of get_current_stream to create a new stream on the heap (this is likely not expected from the user POV), and if we do end up sticking with this, we need to loudly document this as.

I think we do not want to mess with the memory of the stream at all...cc @albanD for thoughts

Copy link
Contributor Author

@mikaylagawarecki mikaylagawarecki Aug 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm looks like @albanD had no thoughts :b

I'm not sure what to do here, this one looks different from

AOTITorchError aoti_torch_get_current_cuda_stream(
int32_t device_index,
void** ret_stream) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*(cudaStream_t*)(ret_stream) = at::cuda::getCurrentCUDAStream(device_index);
});
}

As that is just directly returning the id. However, the actual getCurrentStream returns a c10::Stream, so I'm trying to match the semantic in the stable ABI (so in future users can add other stream methods on a stream object)

c10::Stream getCurrentStream(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
c10::impl::VirtualGuardImpl impl(device_type);
return impl.getStream({device_type, device_index});
}

I think at::acceerator::getCurrentStream is returning c10::Stream by value which is why we will need to create a new object on the heap if we want to return it to the caller and transfer ownership to the stable::Stream

I can add a comment for sure, but would be curious how else we can tweak this in a way that's more user/memory friendly, would it make more sense to just return the .id() directly?


using DeviceIndex = int8_t;
using StreamId = int64_t;
class DeviceGuard {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't need to copy the existing DeviceGuard. For example, I agree with deleting the default constructor for this API. We can disallow copy and move if it's easiest to maintain anyway.


// Construct a stable::Stream from a StreamHandle
// Steals ownership from the StreamHandle
explicit Stream(StreamHandle stream)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see, so the expected use case is:

  • user calls getCurrentStream which creates a Stream
  • then they can call id on it

@janeyx99
Copy link
Contributor

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 12, 2025
@pytorchmergebot
Copy link
Collaborator

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@@ -0,0 +1,71 @@
#pragma once

#include <torch/csrc/inductor/aoti_runtime/utils.h>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note to self to remove when rebasing and use TORCH_ERROR_CODE_CHECK instead

Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #160453

Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
Adds 
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`
      
- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`





[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #160453

3 similar comments
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #160453

@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #160453

@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #160453

pytorchmergebot pushed a commit that referenced this pull request Aug 13, 2025
chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
Adds
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`

- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`

Pull Request resolved: #159679
Approved by: https://github.com/guangyey, https://github.com/janeyx99
chuanhaozhuge pushed a commit that referenced this pull request Aug 14, 2025
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
Adds
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`

- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`

Pull Request resolved: #159679
Approved by: https://github.com/guangyey, https://github.com/janeyx99
chuanhaozhuge pushed a commit that referenced this pull request Aug 18, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
Adds
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`

- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`

Pull Request resolved: pytorch#159679
Approved by: https://github.com/guangyey, https://github.com/janeyx99
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
@github-actions github-actions bot deleted the gh/mikaylagawarecki/332/head branch September 13, 2025 02:06
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Adds
- `torch::stable::accelerator::DeviceGuard`: `std::unique_ptr` to `DeviceGuardOpauqe` mostly copied from the below (but made generic)

   https://github.com/pytorch/pytorch/blob/50eac811a68e63e96ad56c11c983bfe298a0bb8a/torch/csrc/inductor/aoti_runtime/utils_cuda.h#L30-L46
    - constructor `DeviceGuard(DeviceIndex)` (**this matches aoti but defers from the actual c10 DeviceGuard constructor that takes in device**)
    - `set_index(DeviceIndex)`
- `torch::stable::accelerator::Stream`: `std::shared_ptr` to `StreamOpaque`
     - constructor `Stream(StreamHandle stream)` (similar to torch::stable::Tensor)
     - `id() -> StreamId`

- `getCurrentStream(DeviceIndex device_index) -> stable::accelerator::Stream`

Pull Request resolved: pytorch#159679
Approved by: https://github.com/guangyey, https://github.com/janeyx99
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants