KEMBAR78
Support with statement on torch.Stream by guangyey · Pull Request #140138 · pytorch/pytorch · GitHub
Skip to content

Conversation

@guangyey
Copy link
Collaborator

@guangyey guangyey commented Nov 8, 2024

Stack from ghstack (oldest at bottom):

Motivation

We propose to support Python with statement on torch.Stream. This is a benefit for all accelerators when writing device-agnostic code. The device-specific stream will also be supported because they are generally derived from torch.Stream.

With this PR, we can do like this

s1= torch.Stream()
# Set s1 to the current stream
torch.accelerator.set_stream(s1)
with torch.Stream() as s2:
    # Inside with statement, we set s2 to the current stream
    assert torch.accelerator.current_stream() == s2
# Here the current stream should be s1
assert torch.accelerator.current_stream() == s1

cc @albanD @EikanWang

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140138

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 1 New Failure, 1 Unrelated Failure

As of commit fb3f110 with merge base 3beb700 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

guangyey added a commit that referenced this pull request Nov 8, 2024
ghstack-source-id: cd25eb6
Pull Request resolved: #140138
@guangyey guangyey changed the title Support with statement on torch.Stream [WIP] Support with statement on torch.Stream Nov 8, 2024
@guangyey guangyey marked this pull request as draft November 8, 2024 14:00
[ghstack-poisoned]
@guangyey guangyey added ciflow/xpu Run XPU CI tasks release notes: python_frontend python frontend release notes category labels Nov 28, 2024
guangyey added a commit that referenced this pull request Nov 28, 2024
ghstack-source-id: 4c39e55
Pull Request resolved: #140138
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@guangyey guangyey changed the title [WIP] Support with statement on torch.Stream Support with statement on torch.Stream Nov 29, 2024
@guangyey guangyey marked this pull request as ready for review November 29, 2024 07:55
[ghstack-poisoned]
[ghstack-poisoned]
@guangyey guangyey requested a review from albanD January 6, 2025 10:51
guangyey added a commit that referenced this pull request Jan 6, 2025
ghstack-source-id: c51a38d
Pull Request resolved: #140138

g#	modified:   test/test_accelerator.py
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would still leak if the second raises an error, here is an updated version.

Comment on lines 313 to 326
PyObject* ctx_stream = nullptr;
if (PyDict_GetItemStringRef(self->context, "_ctx_stream", &ctx_stream) < 0) {
throw python_error();
}
TORCH_CHECK(ctx_stream, "ctx_stream should be initialized.");
PyObject* ctx_device_index = nullptr;
if (PyDict_GetItemStringRef(
self->context, "_ctx_device_index", &ctx_device_index) < 0) {
throw python_error();
}
TORCH_CHECK(ctx_device_index, "ctx_device_index should be initialized.");
auto prev_stream = (THPStream*)(THPObjectPtr(ctx_stream).get());
auto prev_device_index =
THPUtils_unpackDeviceIndex(THPObjectPtr(ctx_device_index).get());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
PyObject* ctx_stream = nullptr;
if (PyDict_GetItemStringRef(self->context, "_ctx_stream", &ctx_stream) < 0) {
throw python_error();
}
TORCH_CHECK(ctx_stream, "ctx_stream should be initialized.");
PyObject* ctx_device_index = nullptr;
if (PyDict_GetItemStringRef(
self->context, "_ctx_device_index", &ctx_device_index) < 0) {
throw python_error();
}
TORCH_CHECK(ctx_device_index, "ctx_device_index should be initialized.");
auto prev_stream = (THPStream*)(THPObjectPtr(ctx_stream).get());
auto prev_device_index =
THPUtils_unpackDeviceIndex(THPObjectPtr(ctx_device_index).get());
THPObjectPtr ctx_stream;
if (PyDict_GetItemStringRef(self->context, "_ctx_stream", &ctx_stream.get()) < 0) {
throw python_error();
}
TORCH_INTERNAL_ERROR(ctx_stream.get(), "ctx_stream should be present on the context dict.");
THPObjectPtr ctx_device_index;
if (PyDict_GetItemStringRef(
self->context, "_ctx_device_index", &ctx_device_index.get()) < 0) {
throw python_error();
}
TORCH_CHECK(ctx_device_index.get(), "ctx_device_index should be present on the context dict.");
auto prev_stream = (THPStream*)(ctx_stream.get());
auto prev_device_index =
THPUtils_unpackDeviceIndex(ctx_device_index.get());

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Thanks~

Copy link
Collaborator Author

@guangyey guangyey Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx_stream.get() will raise an error lvalue required as unary & operand because it returns a rvalue. So I changed a minor code to use a lvalue py_stream and passed its reference to PyDict_GetItemStringRef.

guangyey added a commit that referenced this pull request Jan 6, 2025
ghstack-source-id: e1e2c55
Pull Request resolved: #140138

g#	modified:   test/test_accelerator.py
@guangyey guangyey requested a review from albanD January 6, 2025 17:51
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
guangyey added a commit that referenced this pull request Jan 7, 2025
ghstack-source-id: cfbd12a
Pull Request resolved: #140138

g#	modified:   test/test_accelerator.py
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
guangyey added a commit that referenced this pull request Jan 8, 2025
ghstack-source-id: 50ef4dc
Pull Request resolved: #140138

g#	modified:   test/test_accelerator.py
[ghstack-poisoned]
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor nit on the clearing, sounds good otherwise!
Thanks for your patience working through cpython api fun!

Comment on lines 339 to 340
Py_DECREF(self->context);
self->context = nullptr;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Py_DECREF(self->context);
self->context = nullptr;
Py_CLEAR(self->context);

Very interesting read about that in https://github.com/python/cpython/blob/ea39c8b08d8f025273bfa5b7a95f7b5984dc1e86/Include/refcount.h#L416

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions. I am now more familiar with CPython:)

guangyey added a commit that referenced this pull request Jan 9, 2025
ghstack-source-id: ac1c365
Pull Request resolved: #140138

g#	modified:   test/test_accelerator.py
[ghstack-poisoned]
@guangyey
Copy link
Collaborator Author

"So happy to try to land this PR, the failure is irrelevant."
@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 3, 4, linux.idc.xpu), xpu / linux-jammy-xpu-2025.0-py3.9 / test (default, 4, 4, linux.idc.xpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Jan 17, 2025
# Motivation
In #137678, we help use the device-agnostic APIs to generalize distributed module. As this [comment](#137678 (comment)) said, we will use the with statement of `torch.Stream` once #140138 is landed.

Pull Request resolved: #144951
Approved by: https://github.com/kwen2501, https://github.com/albanD
@github-actions github-actions bot deleted the gh/guangyey/90/head branch February 12, 2025 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/rocm Trigger "default" config CI on ROCm ciflow/trunk Trigger trunk jobs on your pull request ciflow/xpu Run XPU CI tasks Merged module: accelerator Issues related to the shared accelerator API open source release notes: python_frontend python frontend release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants