Improve error messages of `torch.testing.assert_close` in case of mismatching values · Issue #58383 · pytorch/pytorch · GitHub

Improve error messages of torch.testing.assert_close in case of mismatching values #58383

@pmeier

Description


The current implementation of torch.testing.assert_close already does a good job of delivering information to the user in case of mismatching values. Still, I think the error messages can be improved in multiple ways:

  1. If a scalar (tensor) is compared, the error message currently looks like this:

    AssertionError: Tensors are not close!
    
    Mismatched elements: 1 / 1 (100.0%)
    Greatest absolute difference: 1.0 at 0 (up to 1e-05 allowed)
    Greatest relative difference: 0.5 at 0 (up to 1.3e-06 allowed)
    

    This can be improved in multiple ways:

    • The header should read "Scalars are not close!" to indicate that only a single value was involved in the comparison
    • The number of mismatched elements should be removed since it carries no information and is always Mismatched elements: 1 / 1 (100.0%)
    • The word "Greatest" should be removed from the reported absolute and relative differences, since there is only one difference each
    • The index should be removed since it carries no information and scalar tensors cannot be accessed by index

    With these changes the error message could look like:

    AssertionError: Scalars are not close!
    
    Absolute difference: 1.0 (up to 1e-05 allowed)
    Relative difference: 0.5 (up to 1.3e-06 allowed)
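
To make the proposal for scalars concrete, here is a minimal sketch of how the message could branch on whether a scalar was compared. The helper `format_mismatch` and all of its parameters are hypothetical names chosen for illustration; they are not part of `torch.testing`.

```python
def format_mismatch(abs_diff, rel_diff, *, atol, rtol, is_scalar,
                    mismatched=1, total=1, abs_idx=0, rel_idx=0):
    """Hypothetical helper: build the mismatch message for scalars or tensors."""
    if is_scalar:
        # Scalars: no element count, no "Greatest", no index.
        lines = [
            "Scalars are not close!",
            "",
            f"Absolute difference: {abs_diff} (up to {atol} allowed)",
            f"Relative difference: {rel_diff} (up to {rtol} allowed)",
        ]
    else:
        # Tensors: keep the layout of the current message.
        lines = [
            "Tensors are not close!",
            "",
            f"Mismatched elements: {mismatched} / {total} ({mismatched / total:.1%})",
            f"Greatest absolute difference: {abs_diff} at {abs_idx} (up to {atol} allowed)",
            f"Greatest relative difference: {rel_diff} at {rel_idx} (up to {rtol} allowed)",
        ]
    return "\n".join(lines)


scalar_msg = format_mismatch(1.0, 0.5, atol=1e-05, rtol=1.3e-06, is_scalar=True)
tensor_msg = format_mismatch(1.0, 0.5, atol=1e-05, rtol=1.3e-06, is_scalar=False)
print(scalar_msg)
```

The `is_scalar` flag is passed explicitly here because a tensor of shape `(1,)` also has a single element but can still be indexed.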
    
  2. The only differences between the messages for closeness and equality are that

    • the header reads "not close" instead of "not equal" and
    • the allowed tolerances are appended to the differences.

    We could combine the two by making the message dependent on the value of the tolerances:

    • The header should read "equal" if rtol == 0 and atol == 0 and
    • the allowed tolerance should only be added if > 0

    This would also allow us to make assert_equal a special case of assert_close (Resolved: add torch.testing.assert_close() #56544 (comment)) without further modification.
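
A rough sketch of the tolerance-dependent wording, assuming two hypothetical helpers `mismatch_header` and `diff_line` (neither exists in `torch.testing`):

```python
def mismatch_header(rtol, atol, kind="Tensors"):
    """Hypothetical: pick "equal" only when both tolerances are zero."""
    relation = "equal" if rtol == 0 and atol == 0 else "close"
    return f"{kind} are not {relation}!"


def diff_line(name, diff, tol):
    """Hypothetical: append the allowed tolerance only if it is positive."""
    line = f"{name} difference: {diff}"
    if tol > 0:
        line += f" (up to {tol} allowed)"
    return line


print(mismatch_header(0, 0))
print(diff_line("Absolute", 1.0, 0))
print(mismatch_header(1.3e-06, 1e-05))
print(diff_line("Absolute", 1.0, 1e-05))
```

With this split, an equality check would just be a closeness check with both tolerances set to zero, and the message would follow automatically.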

  3. By default the numbers in the error message are formatted with Python's default float formatting, which behaves similarly to the "General format" (:g). This might lead to diverging formats of two related numbers if they have different magnitudes:

    abs_diff = 1e-4
    atol = 1e-5
    print(f"Absolute difference: {abs_diff} (up to {atol} allowed)")
    Absolute difference: 0.0001 (up to 1e-05 allowed)
    

    We should use a fixed number format, e.g., :.2e for floats and :d for integers.
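
For comparison, a small sketch of what a fixed format would produce for the same two values. The helper `format_number` is a hypothetical name used only for this illustration:

```python
def format_number(x):
    """Hypothetical: fixed format, :d for integers and :.2e for floats."""
    if isinstance(x, int):
        return f"{x:d}"
    return f"{x:.2e}"


abs_diff = 1e-4
atol = 1e-5
line = (
    f"Absolute difference: {format_number(abs_diff)} "
    f"(up to {format_number(atol)} allowed)"
)
print(line)
# Absolute difference: 1.00e-04 (up to 1.00e-05 allowed)
```

Both numbers now share the same notation, so their magnitudes can be compared at a glance.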

  4. Non-matching NaN's currently eliminate the difference information in the error message:

    import torch

    a = torch.tensor([1.0, float("NaN")])
    b = torch.tensor([2.0, float("NaN")])
    torch.testing.assert_close(a, b)
    
    AssertionError: Tensors are not close!
    Mismatched elements: 2 / 2 (100.0%)
    Greatest absolute difference: nan at 1 (up to 1e-05 allowed)
    Greatest relative difference: nan at 1 (up to 1.3e-06 allowed)
    

    We should exclude NaN's from the difference calculation and for example add a line Mismatched NaN's: ... in case there are any.
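
One way the NaN handling could work is sketched below. To keep the example self-contained it operates on flat Python lists of floats as a stand-in for tensors; the function name `nan_aware_diffs` and its return layout are assumptions for illustration only.

```python
import math


def nan_aware_diffs(actual, expected, *, equal_nan=False):
    """Hypothetical sketch: skip NaNs when computing the greatest differences
    and count mismatching NaNs separately."""
    nan_mismatches = 0
    max_abs = max_rel = 0.0
    max_abs_idx = max_rel_idx = None
    for idx, (a, e) in enumerate(zip(actual, expected)):
        a_nan, e_nan = math.isnan(a), math.isnan(e)
        if a_nan or e_nan:
            # Two NaNs only match if equal_nan is set; otherwise count them.
            if not (equal_nan and a_nan and e_nan):
                nan_mismatches += 1
            continue
        abs_diff = abs(a - e)
        if e != 0:
            rel_diff = abs_diff / abs(e)
        else:
            rel_diff = math.inf if abs_diff > 0 else 0.0
        if abs_diff > max_abs:
            max_abs, max_abs_idx = abs_diff, idx
        if rel_diff > max_rel:
            max_rel, max_rel_idx = rel_diff, idx
    return max_abs, max_abs_idx, max_rel, max_rel_idx, nan_mismatches


nan = float("nan")
result = nan_aware_diffs([1.0, nan], [2.0, nan])
print(result)
```

For the inputs from the example above, the differences are taken from index 0 only, and the NaN pair at index 1 is reported through the separate counter instead of turning both differences into nan.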

  5. assert_close accepts a callable as msg to construct an expressive error message. If a user just wants to add some information, for example the tested operator, this is currently impossible without losing all the benefits of our message generation. We should make our default message generation public so others can build on top of it:

    import torch.testing
    
    def my_msg(actual, expected, info):
        default_msg = torch.testing.default_mismatch_msg(actual, expected, info)
        return f"My important information: ... \n\n{default_msg}"
    
    torch.testing.assert_close(..., msg=my_msg)

Metadata

    Labels

    module: testing (issues related to the torch.testing module, not tests); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
