Add dtype checks in meta dispatch for various ordering ops by matthewhagraphcore · Pull Request #159556 · pytorch/pytorch

Conversation

@matthewhagraphcore (Collaborator) commented Jul 31, 2025

This adds data type checks for the unsupported bool and complex types for argmax/argmin, topk, sort, minimum, and maximum, as listed here:

https://github.com/pytorch/pytorch/blob/0a99b026d6bd0f67dc2c0a20fe3228ddc4144854/torch/testing/_internal/common_methods_invocations.py#L21076

Currently these ops only fail during the CPU or CUDA computation, rather than at the meta dispatch stage the way, for example, max does:

check_unsupported_complex("max()", self);

This change catches the unsupported dtype early.
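A minimal repro of the motivation, as a sketch (the tensor shape is arbitrary, and the exact exception type depends on which check macro the kernel uses; a plain TORCH_CHECK surfaces as RuntimeError in Python):

import torch

# Before this change, the call dispatched fine on the "meta" device and the
# unsupported-dtype error only appeared when the same call ran on CPU/CUDA.
# With the dtype check in the meta kernel, it fails eagerly at dispatch time.
x = torch.zeros(8, dtype=torch.complex64, device="meta")
try:
    torch.argmax(x)
except (RuntimeError, NotImplementedError) as err:
    print("caught at meta dispatch:", err)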

@pytorch-bot (bot) commented Jul 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159556

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8e848b7 with merge base ecde76c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 (Contributor) commented Aug 4, 2025

Test case?

@janeyx99 self-requested a review August 4, 2025 23:12
@janeyx99 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 4, 2025
@hameerabbasi (Collaborator)

Thanks for the PR -- I noticed that topk, sort, minimum, and maximum all have this behavior. Perhaps it makes sense to extend this PR to those functions as well?

@hameerabbasi (Collaborator) commented Aug 5, 2025

Here are some minimal tests to get one started:

import torch
from torch.testing._internal.common_methods_invocations import OpInfo, op_db
from torch.testing._internal.common_device_type import ops

ordered_op_names = {
    'sort', 'topk', 'lt', 'argmin', 'le', 'ge', 'amax', 'maximum', 'minimum',
    'clamp', 'amin', 'gt', 'ceil', 'argmax', 'floor',
}
ordered_op_db = tuple(filter(lambda op: op.name in ordered_op_names, op_db))

...
    # This method lives inside a device-type TestCase; see below for one way
    # to wire it up into a runnable file.
    @ops(ordered_op_db, dtypes=[torch.complex32, torch.complex64, torch.complex128])
    def test_ordered_raises(self, device, dtype, op: OpInfo):
        sample_inputs = op.sample_inputs(device, dtype)

        for sample_input in sample_inputs:
            self.assertRaises(
                NotImplementedError,
                op,
                sample_input.input,
                *sample_input.args,
                **sample_input.kwargs,
            )
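For completeness, here is one way the sketch above could be wired into a standalone device-type test file; the class name is hypothetical, the op set is trimmed for brevity, and the PR's actual test may be organized differently:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops
from torch.testing._internal.common_methods_invocations import OpInfo, op_db

ordered_op_names = {'sort', 'topk', 'argmin', 'argmax', 'maximum', 'minimum'}
ordered_op_db = tuple(op for op in op_db if op.name in ordered_op_names)


class TestOrderingDtypeChecks(TestCase):
    @ops(ordered_op_db, dtypes=[torch.complex32, torch.complex64, torch.complex128])
    def test_ordered_raises(self, device, dtype, op: OpInfo):
        # Every sample input for an unsupported dtype should be rejected.
        for sample_input in op.sample_inputs(device, dtype):
            self.assertRaises(
                NotImplementedError,
                op,
                sample_input.input,
                *sample_input.args,
                **sample_input.kwargs,
            )


# Generates per-device variants of the test (e.g. CPU and CUDA).
instantiate_device_type_tests(TestOrderingDtypeChecks, globals())

if __name__ == "__main__":
    run_tests()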

@matthewhagraphcore (Collaborator, Author)

I've added the other functions. Sort had a TORCH_CHECK_VALUE instead of a TORCH_CHECK, so to make it consistent with the others I've updated that; otherwise it raised a ValueError rather than a RuntimeError.
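In user-facing terms the difference is just the exception type; a rough sketch (the input is arbitrary, and the exact message depends on the backend):

import torch

z = torch.zeros(4, dtype=torch.complex64)

# With the old TORCH_CHECK_VALUE this surfaced as a ValueError; with a plain
# TORCH_CHECK it surfaces as a RuntimeError, matching argmax, topk, and the rest.
try:
    torch.sort(z)
except (RuntimeError, ValueError) as err:
    print(type(err).__name__, err)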

Not sure what that build error is about?

@matthewhagraphcore changed the title from "Add dtype checks in meta dispatch for argmin/argmax" to "Add dtype checks in meta dispatch for various ordering ops" Aug 8, 2025
@janeyx99 (Contributor) commented Aug 8, 2025

CI failures are related

@hameerabbasi (Collaborator)

It seems some of the operations are implemented for bool but not complex -- we should not error on bool for those.
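For instance, a quick probe along these lines (the op list and shapes are arbitrary; it only prints what each kernel accepts) shows why the bool check can only be added to the ops whose kernels actually reject bool:

import torch

# Probe which of these ordering ops accept bool vs. complex inputs.
probes = {
    "argmax": lambda t: torch.argmax(t),
    "topk": lambda t: torch.topk(t, k=1),
    "sort": lambda t: torch.sort(t),
    "maximum": lambda t: torch.maximum(t, t),
}
for name, fn in probes.items():
    for dtype in (torch.bool, torch.complex64):
        try:
            fn(torch.zeros(3, dtype=dtype))
            status = "ok"
        except (RuntimeError, ValueError, NotImplementedError) as err:
            status = f"rejected ({type(err).__name__})"
        print(f"{name:8s} {str(dtype):18s} {status}")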

@matthewhagraphcore (Collaborator, Author)

I've updated the operators and added a test so this should cover the correct ops now 🤞

@janeyx99 (Contributor) left a comment

thank u!

@janeyx99 (Contributor)

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Aug 13, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable)

Details for Dev Infra team · Raised by workflow job

@pytorch-bot removed the ciflow/trunk label Aug 14, 2025
@matthewhagraphcore (Collaborator, Author)

One test was failing because of the change I made to make the sort errors more consistent. I've updated that test and it now passes.

@janeyx99 (Contributor)

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Aug 14, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025

Add dtype checks in meta dispatch for various ordering ops (pytorch#159556)

Pull Request resolved: pytorch#159556
Approved by: https://github.com/janeyx99
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025

Add dtype checks in meta dispatch for various ordering ops (pytorch#159556)

Pull Request resolved: pytorch#159556
Approved by: https://github.com/janeyx99