Fix dot reference checks by kundaMwiza · Pull Request #138596 · pytorch/pytorch · GitHub

Conversation

@kundaMwiza
Copy link
Collaborator

@kundaMwiza kundaMwiza commented Oct 22, 2024

The `dot` reference implementation should be consistent with the CPU / CUDA implementations, since it may be used for meta dispatch.

For example:

```python
import torch

x = torch.tensor([1, 2, 3], dtype=torch.float32)
y = torch.tensor([4, 5, 6], dtype=torch.float16)
x.dot(y)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: dot : expected both vectors to have same dtype, but found Float and Half
```

However, the same call on meta tensors does not raise an exception:

```python
x.to("meta").dot(y.to("meta"))
```
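As a minimal pure-Python sketch of the check the reference implementation needs (names here are hypothetical; `_check` only mimics `torch._check` semantics, raising `RuntimeError` when the condition is false):

```python
# Hypothetical sketch, not PyTorch's actual code: `_check` mimics
# torch._check, which raises RuntimeError with the message from
# `msg_fn` when `cond` is False.
def _check(cond, msg_fn):
    if not cond:
        raise RuntimeError(msg_fn())


def dot_dtype_check(self_dtype, other_dtype):
    # Mirror the CPU / CUDA error so meta dispatch fails the same way.
    _check(
        self_dtype == other_dtype,
        lambda: "dot : expected both vectors to have same dtype, but found "
        f"{self_dtype} and {other_dtype}",
    )
```

With this in place, mismatched dtypes raise the same error on every backend instead of silently succeeding on meta.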

Fixes #ISSUE_NUMBER

@kundaMwiza kundaMwiza requested a review from mruberry as a code owner October 22, 2024 15:51
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138596

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f38dd0f with merge base 10a34dc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kundaMwiza
Copy link
Collaborator Author

For some more context, see slack discussion: https://pytorch.slack.com/archives/C3PDTEV8E/p1728898894696069

@bdhirsh bdhirsh added the release notes: composability release notes category label Oct 23, 2024
Copy link
Contributor

@bdhirsh bdhirsh left a comment


lgtm

@kundaMwiza kundaMwiza force-pushed the mwizak/fix-dot-meta-py-decomp branch 2 times, most recently from 1ab84ec to b848e92 Compare October 25, 2024 11:20
@kundaMwiza
Copy link
Collaborator Author

kundaMwiza commented Oct 25, 2024

@bdhirsh I've updated the PR to only apply type promotion after the checks. This should hopefully fix the failing tests

Copy link
Contributor


Isn't this going to skip the type promotion in the `self.is_complex()` path?

It might be cleaner just to take the existing decomps, and wrap them in a new function that does the dtype checks first:

```python
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def dot_helper(self, other):
    ...  # existing fn


@register_decomposition(aten.dot)
@out_wrapper()
def dot(self, other):
    torch._check(
        self.dtype == other.dtype,
        lambda: "dot : expected both vectors to have same dtype, but found "
        f"{self.dtype} and {other.dtype}",
    )
    return dot_helper(self, other)
```

wdyt?

Copy link
Collaborator Author


I did it this way because the complex branch is effectively the same implementation as the C++ one, e.g. for dot CPU:

```cpp
Tensor dot(const Tensor &self, const Tensor &other){
```

So the computation and result dtype in this branch should be correct, as `torch.dot` / `vdot` would do this if necessary.

The actual decomposition into an elementwise product followed by a reduction is only for the case when the inputs are real, so type promotion is applied.

That being said, if the computation and result dtypes for `dot`, `vdot` and the product + reduction path should be the same, then it would be cleaner to do what you suggest; I just didn't want to make the assumption that they are.
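For illustration, a toy sketch (not PyTorch's actual code) of the real-input path discussed here: promote both inputs to a common dtype, then decompose into an elementwise product followed by a sum reduction. The promotion lattice below is a hypothetical subset of the real rules.

```python
# Toy model of the real-input dot decomposition: default elementwise
# type promotion, then multiply + sum. The dtype lattice is a small
# illustrative subset of PyTorch's actual promotion rules.
_ORDER = ["float16", "float32", "float64"]


def promote(a, b):
    # Higher-precision dtype wins within this toy lattice.
    return max(a, b, key=_ORDER.index)


def dot_decomp(xs, x_dtype, ys, y_dtype):
    out_dtype = promote(x_dtype, y_dtype)
    # Elementwise product followed by a sum reduction.
    return sum(x * y for x, y in zip(xs, ys)), out_dtype
```

For `[1, 2, 3]` (float32) and `[4, 5, 6]` (float16) this yields the value `32` at the promoted dtype `float32`, matching the promotion the eager kernels apply after the dtype check passes.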

Copy link
Collaborator Author

@kundaMwiza kundaMwiza Oct 25, 2024


Looks like `dot_naive` uses the same type promotion rules anyway:

```cpp
scalar_t dot_naive(
```

I'll change the code to your suggestion

@kundaMwiza kundaMwiza force-pushed the mwizak/fix-dot-meta-py-decomp branch from 7ef5d6e to f38dd0f Compare October 25, 2024 19:43
@kundaMwiza
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 28, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@bdhirsh
Copy link
Contributor

bdhirsh commented Oct 28, 2024

Thanks!

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Oct 29, 2024
rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024

Labels

ciflow/trunk · Merged · open source · release notes: composability

4 participants