Improve error message for torch.binomial enforcing float inputs by michellemadubuike · Pull Request #157658 · pytorch/pytorch · GitHub

Conversation

@michellemadubuike
Contributor

@michellemadubuike michellemadubuike commented Jul 5, 2025

Fixes #157195

Summary:

Fixed issue #157195 by adding a new error message for torch.binomial in aten/src/ATen/native/Distributions.cpp.

Explanation

According to the issue,

import torch
torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]))

RuntimeError: Found dtype Float but expected Long

It looks like we are getting a tensor error rather than a binomial-specific error. Since the error comes from pytorch/aten/src/ATen/TensorIterator.cpp, TensorIterator appears to be trying to align the tensor arguments to a common datatype for the computation, and it fails before any binomial-specific check can run.
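A quick way to see the alignment step is torch.result_type, which applies PyTorch's type-promotion rules to a pair of tensors (an illustration only; for binomial the iterator is stricter and requires matching dtypes rather than promoting, hence the RuntimeError above):

import torch

# TensorIterator computes a common dtype for its inputs before any kernel runs.
# torch.result_type exposes the same promotion rules:
count = torch.tensor([10]).long()
prob = torch.tensor([0.5])
print(torch.result_type(count, prob))  # torch.float32
# torch.binomial's iterator instead demands that both inputs already share a
# dtype, so the mismatch is reported before any binomial check is reached.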

I tried passing both arguments as longs, and then both as ints, and got the expected binomial-specific error:

torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]).long())
NotImplementedError: "binomial_cpu" not implemented for 'Long'
torch.binomial(torch.tensor([10.0]).int(), torch.tensor([0.5]).int())
NotImplementedError: "binomial_cpu" not implemented for 'Int'

But when the two arguments have different datatypes, the TensorIterator.cpp error comes back as it tries to align them:

RuntimeError: Found dtype Float but expected Long

I then tried to find where the NotImplementedError is raised and found it in pytorch/aten/src/ATen/Dispatch.h (lines 193-211):

#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
  [&] {                                                                     \
    const auto& the_type = TYPE;                                            \
    constexpr const char* at_dispatch_name = NAME;                          \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
    switch (_st) {                                                          \
      __VA_ARGS__                                                           \
      default:                                                              \
        TORCH_CHECK_NOT_IMPLEMENTED(                                        \
            false,                                                          \
            '"',                                                            \
            at_dispatch_name,                                               \
            "\" not implemented for '",                                     \
            toString(_st),                                                  \
            "'");                                                           \
    }                                                                       \
  }()

The AT_DISPATCH_SWITCH macro takes a tensor's datatype and checks whether it matches one of the supported datatypes; if not, we get the NotImplementedError. Unfortunately, AT_DISPATCH_SWITCH runs on the common_dtype computed by TensorIterator, so TensorIterator.cpp has to run before AT_DISPATCH_SWITCH does.

Summary: we are getting the wrong error message because TensorIterator.cpp runs first and errors out on the tensor datatype mismatch before Dispatch.h can raise the right error about torch.binomial not supporting that datatype.
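To make the ordering concrete, here is a minimal Python model of the two checks (purely illustrative; the function names and the supported set are hypothetical stand-ins for the C++ internals):

SUPPORTED = {"Float", "Double"}  # hypothetical stand-in for binomial_cpu's dispatch cases

def tensor_iterator_common_dtype(count_dtype, prob_dtype):
    # Models the TensorIterator.cpp check: all inputs must share one dtype.
    if count_dtype != prob_dtype:
        raise RuntimeError(f"Found dtype {prob_dtype} but expected {count_dtype}")
    return count_dtype

def at_dispatch_switch(common_dtype):
    # Models AT_DISPATCH_SWITCH: the common dtype must be a supported case.
    if common_dtype not in SUPPORTED:
        raise NotImplementedError(
            f'"binomial_cpu" not implemented for \'{common_dtype}\''
        )

def binomial(count_dtype, prob_dtype):
    common = tensor_iterator_common_dtype(count_dtype, prob_dtype)  # runs first
    at_dispatch_switch(common)  # never reached when the dtypes mismatch

# binomial("Long", "Long")   # NotImplementedError: "binomial_cpu" not implemented for 'Long'
# binomial("Long", "Float")  # RuntimeError: Found dtype Float but expected Long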

Options for the Fix

Option 1: Make the error message in TensorIterator.cpp more general so it also applies to torch.binomial, along the lines of:

RuntimeError: "Tensor datatypes ", op.target_dtype, " and ", common_dtype_, " are different"

Option 2: Add an error message for the binomial function datatype mismatch before the TensorIterator.cpp error gets raised.

Although Option 1 seemed easier, I think Option 2 is better: it is specific to the binomial function, whereas Option 1 would affect every op with a tensor datatype mismatch.

This PR applies the fix for Option 2

After the fix:

torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]))
RuntimeError: Binomial function arguments count and prob must have same datatype of type Float, got: count = Long, prob = Float
torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]).long())
NotImplementedError: "binomial_cpu" not implemented for 'Long'


cc @malfet

@pytorch-bot

pytorch-bot bot commented Jul 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/157658

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6d48465 with merge base 524e827:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@michellemadubuike
Contributor Author

@pytorchbot label "release notes: cpp"

@pytorch-bot pytorch-bot bot added the release notes: cpp release notes category label Jul 5, 2025
@michellemadubuike
Contributor Author

@pytorchbot label "module: error checking"

@pytorch-bot pytorch-bot bot added the module: error checking Bugs related to incorrect/lacking error checking label Jul 7, 2025
@soulitzer soulitzer self-requested a review July 7, 2025 21:26
TORCH_CHECK(
    count.scalar_type() == prob.scalar_type(),
    "Binomial function arguments count and prob must have same datatype of type Float, got: count = ",
    count.scalar_type(), ", prob = ", prob.scalar_type());
Contributor


nit: how about something like the following, to be consistent with multinomial's error checking:

TORCH_CHECK(
    at::isFloatingType(count.scalar_type()),
    "binomial only supports floating-point dtypes for count, got: ",
    count.scalar_type());
TORCH_CHECK(
    at::isFloatingType(prob.scalar_type()),
    "binomial only supports floating-point dtypes for prob, got: ",
    prob.scalar_type());
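
For comparison, multinomial's existing check already produces this kind of error today (a quick illustration; the exact wording may vary across versions):

import torch

torch.multinomial(torch.tensor([1, 2, 3]), 1)
# RuntimeError: multinomial only supports floating-point dtypes for input, got: Long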

Contributor

@soulitzer soulitzer left a comment


Thanks for the fix! Option 2 sounds good, added a small nit

@soulitzer
Contributor

Could we also add a test?

@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 7, 2025
@michellemadubuike
Contributor Author

michellemadubuike commented Jul 7, 2025

Hi @soulitzer, sure, no problem. I'll also work on making the error checking more consistent with multinomial's and make another commit.

@michellemadubuike
Contributor Author

Hi @soulitzer, I have added a test and made the error more consistent with multinomial's error checking.
The change to the error message also seems more efficient: instead of checking for a datatype mismatch, we check whether each individual argument is floating point.
That way we get this error message consistently, rather than bouncing between the mismatch error and the NotImplementedError. The only case I can think of where we would still get the NotImplementedError with the new message is torch.float16 (torch.half). Thanks!
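A sketch of the kind of test this adds (hypothetical; the class name and assertion pattern are illustrative, not the PR's literal test):

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

class TestBinomialDtypeCheck(TestCase):
    def test_binomial_rejects_non_float_count(self):
        count = torch.tensor([10]).long()
        prob = torch.tensor([0.5])
        # TORCH_CHECK failures surface in Python as RuntimeError
        with self.assertRaisesRegex(
            RuntimeError, "binomial only supports floating-point dtypes for count"
        ):
            torch.binomial(count, prob)

if __name__ == "__main__":
    run_tests()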

malfet
malfet previously approved these changes Jul 8, 2025
Contributor

@malfet malfet left a comment


This is probably fine, but it would be nice to add the same error handling for CUDA as well (not sure if test_distributions supports dtypes).

    at::isFloatingType(count.scalar_type()),
    "binomial only supports floating-point dtypes for count, got: ",
    count.scalar_type());
TORCH_CHECK(
Contributor


If some values are unsupported, then ValueError is more appropriate than RuntimeError, don't you think?

Suggested change:
- TORCH_CHECK(
+ TORCH_CHECK_VALUE(
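If that suggestion is adopted, the misuse would surface in Python as a ValueError instead of a RuntimeError, since TORCH_CHECK_VALUE maps to ValueError on the Python side (illustrative; the exact message depends on the final check):

import torch

try:
    torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]))
except ValueError as e:
    print(e)  # e.g. binomial only supports floating-point dtypes for count, got: Long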

@malfet malfet dismissed their stale review July 8, 2025 02:06

Hmm, I hope CI will fail on this new check, because integer count seems to be fine

@michellemadubuike
Contributor Author

Hi @malfet, I realized I used the wrong binomial function in my previous test. I have fixed this and created a new test. I wasn't sure where to put the test for torch.binomial, since I couldn't find existing tests covering this function, so I put it in the test directory.

@soulitzer
Contributor

@michellemadubuike Maybe test/distributions/test_distributions.py would be a good place for this test

@michellemadubuike
Contributor Author

@soulitzer Thanks, I moved the test to test/distributions/test_distributions.py

Contributor

@soulitzer soulitzer left a comment


Thanks!

@michellemadubuike
Contributor Author

@spzala

@soulitzer
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 8, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Jul 9, 2025
@michellemadubuike
Contributor Author

@soulitzer, I realized torch.binomial is not supported for CUDA tensors. I added a test skip if CUDA is used in my last commit. This should hopefully resolve the CUDA failure in the CI builds.

assert (vals == 0.0).sum() > 4000
assert (vals == 1.0).sum() > 4000

@unittest.skipIf(TEST_CUDA, "CUDA found")
Contributor


Should be fine without this skip, since none of the tensors in the test are created on CUDA.
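A small illustration of why the skip is unnecessary: tensors land on CPU unless a device is requested explicitly, so the test never touches CUDA even on a CUDA machine:

import torch

t = torch.tensor([10.0])  # no device argument given
print(t.device)  # cpu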

Contributor


Can you revert this change, @michellemadubuike? The failing test should be unrelated.

Contributor Author


@soulitzer Sure, no problem

@soulitzer soulitzer changed the title Fixes issue 157195 by adding error message Improve error message for torch.binomial enforcing float inputs Jul 9, 2025
@soulitzer
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
module: error checking (Bugs related to incorrect/lacking error checking)
open source
release notes: cpp (release notes category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues:

Wrong error message for wrong dtypes in torch.binomial

5 participants