Fix dot reference checks #138596
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138596
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit f38dd0f with merge base 10a34dc.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
For some more context, see the Slack discussion: https://pytorch.slack.com/archives/C3PDTEV8E/p1728898894696069
lgtm
Force-pushed 1ab84ec to b848e92.
@bdhirsh I've updated the PR to only apply type promotion after the checks. This should hopefully fix the failing tests.
torch/_refs/__init__.py (Outdated)
Isn't this going to skip the type promotion in the self.is_complex() path?
It might be cleaner just to take the existing decomps, and wrap them in a new function that does the dtype checks first:
```python
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def dot_helper(self, other):
    ...  # existing fn


@register_decomposition(aten.dot)
@out_wrapper()
def dot(self, other):
    # dtype check runs on the raw inputs, before the helper's type promotion
    torch._check(
        self.dtype == other.dtype,
        lambda: "dot : expected both vectors to have same dtype, but found "
        f"{self.dtype} and {other.dtype}",
    )
    return dot_helper(self, other)
```
wdyt?
I did it this way because the complex branch is effectively the same implementation as the C++ implementation (e.g. for dot CPU:
pytorch/aten/src/ATen/native/Blas.cpp, line 159 at 392221b: `Tensor dot(const Tensor &self, const Tensor &other){`)
The actual decomposition into an elementwise product followed by a reduction is only used when the inputs are real, which is why type promotion is applied there.
That being said, if the computation and result types for dot, vdot, and the elementwise product + reduction should be the same, then it would be cleaner to do what you suggest; I just didn't want to assume that they are.
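For readers without the file open, a rough sketch of the shape being described (a simplification for illustration, not the actual torch/_refs/__init__.py code):

```python
import torch

# Rough sketch only: the complex path delegates (mirroring the C++
# implementation) while the real path decomposes into an elementwise
# product followed by a reduction, which is where type promotion applies.
def dot_sketch(self, other):
    if self.is_complex():
        # vdot(conj(a), b) == sum(a * b), so this matches dot for complex inputs
        return torch.vdot(self.conj(), other)
    return (self * other).sum()
```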
Looks like dot_naive uses the same type promotion rules anyway:
pytorch/aten/src/ATen/native/BlasKernel.cpp, line 1030 at 392221b: `scalar_t dot_naive(`
I'll change the code to your suggestion
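As an informal illustration of that point (a quick sanity check, not a test from this PR): for same-dtype real inputs the decomposition matches torch.dot in both value and result dtype, which is what makes wrapping the existing decomp in a dtype check safe.

```python
import torch

# Informal sanity check: the (x * y).sum() decomposition agrees with
# torch.dot for same-dtype real inputs in both value and dtype.
for dtype in (torch.float32, torch.float64):
    x = torch.tensor([1, 2, 3], dtype=dtype)
    y = torch.tensor([4, 5, 6], dtype=dtype)
    expected = torch.dot(x, y)
    decomposed = (x * y).sum()
    assert expected.dtype == decomposed.dtype == dtype
    assert torch.equal(expected, decomposed)
```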
Force-pushed 7ef5d6e to f38dd0f (dot reference implementation should be consistent with the cpu / cuda implementations since it may be used for meta dispatch).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Thanks!
The dot reference implementation should be consistent with the CPU / CUDA implementations since it may be used for meta dispatch. For example:
```python
import torch
x = torch.tensor([1,2,3], dtype=torch.float32)
y = torch.tensor([4,5,6], dtype=torch.float16)
x.dot(y)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Half
```
However, the following does not raise an exception:
```python
x.to("meta").dot(y.to("meta"))
```
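The fix, roughly (a sketch of the approach settled on in the review discussion, simplified from the actual wrapper in torch/_refs/__init__.py): check the raw input dtypes in the reference before any type promotion, so the meta path raises the same error as CPU / CUDA.

```python
import torch

# Simplified sketch of the check added to the reference implementation;
# the real change wraps the existing decomposition in a checking function.
def dot_ref_sketch(self, other):
    # The dtype check sees the original dtypes, before any promotion.
    torch._check(
        self.dtype == other.dtype,
        lambda: "dot : expected both vectors to have same dtype, but found "
        f"{self.dtype} and {other.dtype}",
    )
    return (self * other).sum()

x = torch.tensor([1, 2, 3], dtype=torch.float32).to("meta")
y = torch.tensor([4, 5, 6], dtype=torch.float16).to("meta")
# dot_ref_sketch(x, y)  # now raises: expected both vectors to have same dtype
```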
Fixes #ISSUE_NUMBER
Pull Request resolved: pytorch#138596
Approved by: https://github.com/bdhirsh