wrap cudaStreamSynchronize calls by ngimel · Pull Request #61889 · pytorch/pytorch · GitHub

Conversation

@ngimel (Collaborator) commented Jul 20, 2021

This is a first step towards creating a context manager that errors out on synchronizing calls.
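For illustration, a hedged sketch of the kind of usage such a context manager would enable. It assumes a CUDA build of PyTorch that ships torch.cuda.set_sync_debug_mode / get_sync_debug_mode (these landed in later releases); the error_on_cuda_sync wrapper below is illustrative only, not an existing PyTorch API.

```python
import contextlib
import torch

@contextlib.contextmanager
def error_on_cuda_sync():
    # Illustrative wrapper: make synchronizing CUDA calls raise inside the block.
    prev = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        yield
    finally:
        torch.cuda.set_sync_debug_mode(prev)

x = torch.randn(1000, device="cuda")
with error_on_cuda_sync():
    y = x * 2        # fine: purely asynchronous kernel launch
    # x.nonzero()    # would raise: nonzero must synchronize to learn its output size
```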

@facebook-github-bot added the oncall: jit (add this issue/PR to JIT oncall triage queue) and cla signed labels on Jul 20, 2021
@facebook-github-bot (Contributor) commented Jul 20, 2021

💊 CI failures summary and remediations

As of commit 5febc92 (more details on the Dr. CI page and at hud.pytorch.org/pr/61889):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Jul 21 20:41:15 AssertionError: False is not tr...ot sizes torch.Size([5, 5, 5]) and torch.Size([]).
Jul 21 20:41:15   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 780, in test_wrapper
Jul 21 20:41:15     return test(*args, **kwargs)
Jul 21 20:41:15   File "/var/lib/jenkins/workspace/xla/test/test_ops.py", line 411, in test_reference_eager
Jul 21 20:41:15     self.compare_with_eager_reference(op, sample_input)
Jul 21 20:41:15   File "/var/lib/jenkins/workspace/xla/test/test_ops.py", line 402, in compare_with_eager_reference
Jul 21 20:41:15     self.assertEqual(actual, expected, exact_dtype=True, exact_device=False)
Jul 21 20:41:15   File "/var/lib/jenkins/workspace/xla/test/pytorch_test_base.py", line 608, in assertEqual
Jul 21 20:41:15     return DeviceTypeTestBase.assertEqual(self, x, y, *args, **kwargs)
Jul 21 20:41:15   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1524, in assertEqual
Jul 21 20:41:15     super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
Jul 21 20:41:15 AssertionError: False is not true : Tensors failed to compare as equal!Attempted to compare equality of tensors with different sizes. Got sizes torch.Size([5, 5, 5]) and torch.Size([]).
Jul 21 20:41:15 
Jul 21 20:41:15 ----------------------------------------------------------------------
Jul 21 20:41:15 Ran 347 tests in 326.417s
Jul 21 20:41:15 
Jul 21 20:41:15 FAILED (failures=4)
Jul 21 20:41:15 
Jul 21 20:41:15 Generating XML reports...
Jul 21 20:41:16 + cleanup
Jul 21 20:41:16 + retcode=1
Jul 21 20:41:16 + set +x

XLA failure

Job pytorch_xla_linux_bionic_py3_6_clang9_test is failing. Please create an issue with a title prefixed by [PT_BREAK] in pytorch/xla and link to this PR. If you have questions, please reach out to @ailzhang / @dlibenzi / @JackCaoG.


🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Preview docs built from this PR

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@ngimel ngimel requested a review from ezyang July 20, 2021 16:25
@ezyang (Contributor) left a comment


Handling the ifdef in one place is a nice bonus. Will you add a lint to flag bare occurrences of these calls?

@ngimel (Collaborator, Author) commented Jul 20, 2021

Yes, will add a lint!
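For context, a hypothetical sketch of such a lint (not the check that actually landed in pytorch/pytorch): scan C++/CUDA sources and flag bare cudaStreamSynchronize calls outside the file that hosts the wrapper. The ALLOWED_FILES entry is an assumption about where the wrapper lives.

```python
import pathlib
import re
import sys

PATTERN = re.compile(r"\bcudaStreamSynchronize\s*\(")
ALLOWED_FILES = {"CUDAFunctions.cpp"}  # assumed home of the wrapper; adjust as needed

def main(root: str) -> int:
    hits = []
    # Walk .c, .cc, .cpp, .cu, ... files under the given root.
    for path in pathlib.Path(root).rglob("*.c*"):
        if not path.is_file() or path.name in ALLOWED_FILES:
            continue
        for lineno, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
            if PATTERN.search(line):
                hits.append(f"{path}:{lineno}: bare cudaStreamSynchronize call; use the wrapper instead")
    print("\n".join(hits))
    return 1 if hits else 0

if __name__ == "__main__":
    sys.exit(main(sys.argv[1] if len(sys.argv) > 1 else "."))
```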

@facebook-github-bot (Contributor)

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)
@ngimel merged this pull request in 6284d2a.

@byronyi commented Jul 27, 2021

@ngimel Any chance cudaStreamSynchronize could be completely removed from at::nonzero? We found it quite difficult to support on accelerators, including NVIDIA GPUs.

Also cc @ezyang @ailzhang @asuhan @JackCaoG: we first identified this issue when supporting detection/segmentation models in PyTorch XLA, but then found that it mainly stems from a limitation in PyTorch core (tensor shapes must be concrete values).

@ezyang (Contributor) commented Jul 27, 2021

Given the existing semantics of the operation, removing the synchronization is not possible.

What we should do, however, is support JAX's extension to the nonzero API https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html, where an explicit size can be specified, giving an upper bound on the number of nonzero entries that will be returned (zero-padded if there are fewer). Can you file an issue for this?
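For reference, a minimal example of the JAX API linked above: passing size= gives nonzero a static output shape, so the caller never needs a sync to learn the count, and unused slots are padded (with 0 by default).

```python
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 7, 0])
idx, = jnp.nonzero(x, size=4)  # at most 4 indices; extra slots are zero-padded
print(idx)                     # [1 3 0 0] -- two real hits, two padded entries
```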

@ngimel (Collaborator, Author) commented Jul 27, 2021

We would need this extension not only for nonzero but also for indexing ops with a mask; the most common situation where people encounter this particular sync is out = x[mask].
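For illustration, a small sketch (hypothetical tensors) of the sync being described: boolean-mask indexing routes through nonzero, so the host has to read back how many elements are True before it can allocate the output.

```python
import torch

x = torch.randn(1000, device="cuda")
mask = x > 0   # computed asynchronously on the GPU
out = x[mask]  # output size depends on mask.sum() -> forces a device-to-host sync
```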

@byronyi commented Jul 28, 2021

> Given the existing semantics of the operation, removing the synchronization is not possible.
>
> What we should do, however, is support JAX's extension to the nonzero API https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html, where an explicit size can be specified, giving an upper bound on the number of nonzero entries that will be returned (zero-padded if there are fewer). Can you file an issue for this?

Raised in #62320

@ngimel ngimel deleted the ngimel/wrap_cuda_calls branch December 26, 2021 06:44
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Oct 29, 2022
Summary:
This is a first step towards creating a context manager that errors out on synchronizing calls.

Pull Request resolved: pytorch/pytorch#61889

Reviewed By: albanD

Differential Revision: D29805280

Pulled By: ngimel

fbshipit-source-id: b66400fbe0941b7daa51e6b30abe27b9cccd4e8a
jjsjann123 pushed a commit to jjsjann123/nvfuser that referenced this pull request Nov 10, 2022

Labels

cla signed · Merged · oncall: jit


4 participants