Fix DLPack stream logic. by ysiraichi · Pull Request #150217 · pytorch/pytorch · GitHub

Conversation

@ysiraichi (Collaborator) commented Mar 28, 2025

Stack from ghstack (oldest at bottom):

This PR fixes the logic for dealing with CUDA and ROCm streams whenever
we are trying to create a DLPack capsule from a tensor.

In summary, this PR:

- Uses the legacy default stream if `tensor.__dlpack__(stream=None)` is
  called for a CUDA tensor.
- Errors if `tensor.__dlpack__(stream=2)` is called for a CUDA tensor:
  PyTorch doesn't support the per-thread default stream.
- Errors if `tensor.__dlpack__(stream=stream)`, where `stream` is 1 or
  2, is called for a CUDA tensor using ROCm.

For more details, see [the documentation][1].

[1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
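For illustration, a minimal sketch of the behavior this PR specifies (it assumes a CUDA build with at least one device):

```
import torch

x = torch.randn(4, device="cuda")

# stream=None: synchronize with the legacy default stream, then export.
capsule = x.__dlpack__(stream=None)

# stream=2 names the per-thread default stream, which PyTorch does not
# support, so the export is rejected.
try:
    x.__dlpack__(stream=2)
except BufferError as err:
    print(err)  # per-thread default stream is not supported.
```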

@pytorch-bot bot commented Mar 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150217

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5ffb2ee with merge base 7cc1a95:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Divigroup-RAP pushed a commit to Divigroup-RAP/PYTORCH that referenced this pull request Apr 22, 2025
ghstack-source-id: cc0e31c
Pull Request resolved: pytorch/pytorch#150217
@albanD (Collaborator) left a comment
Sounds ok even though this doesn't fix the multi-device case.

torch/_tensor.py (outdated diff)

-        elif stream is not None and stream != -1:
+        elif stream != -1:
             if self.device.type == "cuda":
                 # NB: This logic handles the special case values for default
Collaborator
No update to dlpack.py? :D

Collaborator Author
No need. If stream is None, we still need to synchronize, assuming the legacy default stream.
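To make the reply concrete, here is a rough sketch of the control flow under discussion (simplified from the quoted diff; not the exact PyTorch source):

```
# Inside Tensor.__dlpack__(self, stream=None) -- sketch only.
if stream != -1:  # -1 means "do not synchronize", per the Array API
    if self.device.type == "cuda":
        if stream is None:
            # Assume the legacy default stream, which DLPack numbers as 1.
            stream = 1
        # ... validate the stream value and synchronize as needed ...
```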

torch/_tensor.py (outdated)

        if is_cuda and stream == 2:
            raise BufferError("per-thread default stream is not supported.")

        assert is_cuda or (is_rocm and stream not in (1, 2)), (
Collaborator
Shouldn't this be a BufferError like above instead of AssertionError?

Collaborator Author
I don't think so, because this assertion checks something the standard explicitly states as "unsupported" or "disallowed", i.e. something the consumer should know about. Moreover, the standard also says that:

> Other errors are raised when export fails for other reasons (e.g., incorrect arguments passed or out of memory).

torch/_tensor.py (outdated diff)

         # Only synchronize on different streams
-        sync_stream = torch.cuda.current_stream()
-        if stream != sync_stream:
+        current_stream = torch.cuda.current_stream()
Collaborator
Do we care if self.device.index != torch.cuda.current_device()?

Collaborator Author
Good point. I think we should. I will add a check for that.
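As a hedged sketch of the check discussed above (a hypothetical helper, not the final implementation; it assumes `stream` has already been wrapped as a `torch.cuda.Stream`, e.g. via `torch.cuda.ExternalStream`):

```
import torch

def _sync_for_export(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
    # Query the current stream on the tensor's own device, not on whatever
    # device happens to be current, in case tensor.device.index differs
    # from torch.cuda.current_device().
    with torch.cuda.device(tensor.device.index):
        current_stream = torch.cuda.current_stream()
        # Only synchronize when the consumer's stream differs.
        if stream != current_stream:
            event = torch.cuda.Event()
            event.record(current_stream)
            stream.wait_event(event)
```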

ysiraichi added 4 commits May 24, 2025 11:55
@pytorchmergebot (Collaborator) commented:
Starting merge as part of PR stack under #150691

@pytorchmergebot (Collaborator) commented:
Starting merge as part of PR stack under #150691

pytorchmergebot pushed a commit that referenced this pull request Jul 20, 2025
This PR introduces the rest of the keyword-arguments added in DLPack
version 2023.12: `dl_device` and `copy`.

In summary, we handle these arguments in the C++ implementation of
`to_dlpack(...)` at _torch/csrc/Module.cpp_, by calling the
`maybeCopyTensor` function at _aten/src/ATen/DLConvertor.cpp_. It also
introduces the following changes:

- Add a new C++ API `torchDeviceToDLDevice()`, which is simply a
  refactoring of the `getDLDevice()` function at
  _aten/src/ATen/DLConvertor.cpp_.
- Add both keyword-arguments to the `from_dlpack()` function at
  _torch/utils/dlpack.py_ and to the `Tensor.__dlpack__()` dunder
  method.
Pull Request resolved: #150218
Approved by: https://github.com/albanD
ghstack dependencies: #150216, #150217
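As a hedged illustration of the new keyword-arguments (assuming the signatures described above; in DLPack a device is a `(device_type, device_id)` pair, and `kDLCPU == 1`):

```
import torch
from torch.utils.dlpack import from_dlpack

x = torch.arange(4, device="cuda")

# dl_device selects the target device for the exported capsule;
# copy=True forces a copy even when one would not otherwise be required.
capsule = x.__dlpack__(dl_device=(1, 0), copy=True)
y = from_dlpack(capsule)  # a CPU tensor backed by the copied buffer
```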
pytorchmergebot pushed a commit that referenced this pull request Jul 20, 2025
This PR follows the Array API documentation for [`__dlpack__`][1] and
[`from_dlpack`][2] by raising `BufferError` instead of `RuntimeError` for
buffer-related failures, e.g. incompatible dtype, strides, or device.

[1]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
[2]: https://data-apis.org/array-api/latest/API_specification/generated/array_api.from_dlpack.html#from-dlpack
Pull Request resolved: #150691
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #150216, #150217, #150218
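As a small hedged sketch (the failing tensor is left hypothetical), consumer code can now tell export incompatibilities apart from other runtime errors:

```
import torch

def try_export(t: torch.Tensor):
    try:
        return t.__dlpack__()
    except BufferError as err:
        # Incompatible dtype, strides, or device now raise BufferError
        # instead of RuntimeError.
        print(f"cannot export tensor: {err}")
        return None
```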
github-actions bot deleted the gh/ysiraichi/85/head branch August 19, 2025 02:16
pytorchmergebot pushed a commit that referenced this pull request Sep 24, 2025
…capture (#163242)

Many extensions (including pybind helpers) call `Tensor.__dlpack__()` without a stream argument. Before #150217, `stream=None` behaved like “no cross-stream sync” and was safe inside CUDA Graph capture. After #150217, `stream=None` maps to the legacy default stream, adding a cross-stream wait that invalidates capture when running on a non-default stream.

See this example

```
import torch

s = torch.cuda.Stream()
x = torch.randn(8, device="cuda")
g = torch.cuda.CUDAGraph()

with torch.cuda.stream(s):
    with torch.cuda.graph(g):
        _ = x + 1
        # After #150217 this synchronized with the legacy default stream,
        # which invalidates capture on a non-default stream.
        cap = x.__dlpack__()
        _ = torch.utils.dlpack.from_dlpack(cap)
```

This PR partially reverts #150217 so that stream=None again defaults to no synchronization.

Pull Request resolved: #163242
Approved by: https://github.com/ngimel
jainapurva pushed a commit that referenced this pull request Sep 29, 2025
…capture (#163242)
