Description
The current implementation of `torch.testing.assert_close` already does a good job of reporting mismatching values to the user. Still, I think its error messages can be improved in multiple ways:
If a scalar (tensor) is compared, the error message currently looks like this:
```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 1.0 at 0 (up to 1e-05 allowed)
Greatest relative difference: 0.5 at 0 (up to 1.3e-06 allowed)
```
This can be improved in multiple ways:
- The header should read "Scalars are not close!" to indicate that only a single value was involved in the comparison
- The number of mismatched elements should be removed, since it carries no information and always reads `Mismatched elements: 1 / 1 (100.0%)`
- The word "Greatest" should be removed from the reported absolute and relative differences, since there is only one difference each
- The index should be removed, since it carries no information and scalar tensors cannot be accessed by index
With these changes the error message could look like:
```
AssertionError: Scalars are not close!

Absolute difference: 1.0 (up to 1e-05 allowed)
Relative difference: 0.5 (up to 1.3e-06 allowed)
```
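To make the proposal concrete, here is a minimal pure-Python sketch of a message builder that drops the redundant parts for the single-element case. The function `mismatch_msg` is hypothetical and works on plain lists of floats rather than tensors; the default `rtol`/`atol` values simply mirror the ones shown in the messages above.

```python
import math
from typing import Sequence


def mismatch_msg(actual: Sequence[float], expected: Sequence[float],
                 rtol: float = 1.3e-6, atol: float = 1e-5) -> str:
    """Hypothetical builder that simplifies the message for scalars."""
    abs_diffs = [abs(a - e) for a, e in zip(actual, expected)]
    # Relative difference w.r.t. expected; inf if expected is 0 but values differ
    rel_diffs = [d / abs(e) if e != 0 else (math.inf if d else 0.0)
                 for d, e in zip(abs_diffs, expected)]
    if len(abs_diffs) == 1:
        # Scalar: no element count, no "Greatest", no index
        return ("Scalars are not close!\n\n"
                f"Absolute difference: {abs_diffs[0]} (up to {atol} allowed)\n"
                f"Relative difference: {rel_diffs[0]} (up to {rtol} allowed)")
    # Same criterion as torch.isclose: |a - e| <= atol + rtol * |e|
    mismatched = sum(d > atol + rtol * abs(e) for d, e in zip(abs_diffs, expected))
    numel = len(abs_diffs)
    abs_idx = max(range(numel), key=abs_diffs.__getitem__)
    rel_idx = max(range(numel), key=rel_diffs.__getitem__)
    return ("Tensors are not close!\n\n"
            f"Mismatched elements: {mismatched} / {numel} ({mismatched / numel:.1%})\n"
            f"Greatest absolute difference: {abs_diffs[abs_idx]} at {abs_idx} (up to {atol} allowed)\n"
            f"Greatest relative difference: {rel_diffs[rel_idx]} at {rel_idx} (up to {rtol} allowed)")


print(mismatch_msg([1.0], [2.0]))
```

The multi-element branch keeps today's layout, so only the scalar case changes.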
The only differences between the messages for closeness and equality are that
- the header reads "not close" instead of "not equal" and
- the allowed tolerances are appended to the differences.
We could combine the two by making the message dependent on the value of the tolerances:
- the header should read "not equal" if `rtol == 0 and atol == 0`, and
- the allowed tolerance should only be appended if it is `> 0`.
This would also allow us to make `assert_equal` a special case of `assert_close` (Resolved: add torch.testing.assert_close() #56544 (comment)) without further modification.
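A minimal sketch of how both messages could be derived from the tolerances alone, assuming hypothetical helpers `header` and `diff_line` (not part of `torch.testing`):

```python
def header(rtol: float, atol: float, noun: str = "Tensors") -> str:
    # Hypothetical: zero tolerances turn the closeness check into an equality check
    kind = "equal" if rtol == 0 and atol == 0 else "close"
    return f"{noun} are not {kind}!"


def diff_line(name: str, diff: float, tol: float) -> str:
    # Only mention the allowed tolerance when it is > 0
    line = f"{name} difference: {diff}"
    return f"{line} (up to {tol} allowed)" if tol > 0 else line


print(header(0, 0))                        # Tensors are not equal!
print(diff_line("Absolute", 1.0, 0.0))     # Absolute difference: 1.0
print(header(1.3e-6, 1e-5))                # Tensors are not close!
print(diff_line("Absolute", 1.0, 1e-5))    # Absolute difference: 1.0 (up to 1e-05 allowed)
```

With this, an equality check is just a closeness check with `rtol=0` and `atol=0`, and the header and difference lines adapt without a separate code path.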
By default the numbers in the error message are formatted with the "General format" (`:g`). This might lead to diverging formats of two related numbers if they have different magnitudes:

```python
abs_diff = 1e-4
atol = 1e-5
print(f"Absolute difference: {abs_diff} (up to {atol} allowed)")
```

```
Absolute difference: 0.0001 (up to 1e-05 allowed)
```
We should use a fixed number format, e.g. `:.2e` for floats and `:d` for integers.
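For illustration, a pair of hypothetical formatters applying the proposed fixed formats; the function names are assumptions, only the format specs come from the proposal:

```python
def abs_diff_line(abs_diff: float, atol: float) -> str:
    # :.2e renders both numbers in scientific notation with 3 significant digits
    return f"Absolute difference: {abs_diff:.2e} (up to {atol:.2e} allowed)"


def mismatch_count_line(mismatched: int, numel: int) -> str:
    # :d only accepts integers and raises ValueError for floats
    return f"Mismatched elements: {mismatched:d} / {numel:d}"


print(abs_diff_line(1e-4, 1e-5))   # Absolute difference: 1.00e-04 (up to 1.00e-05 allowed)
print(mismatch_count_line(3, 10))  # Mismatched elements: 3 / 10
```

One trade-off: three significant digits can hide small differences, e.g. `1.0001e-4` and `1e-4` both render as `1.00e-04`.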
Non-matching NaNs currently eliminate the difference information in the error message:

```python
import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([2.0, float("nan")])
torch.testing.assert_close(a, b)
```

```
AssertionError: Tensors are not close!

Mismatched elements: 2 / 2 (100.0%)
Greatest absolute difference: nan at 1 (up to 1e-05 allowed)
Greatest relative difference: nan at 1 (up to 1.3e-06 allowed)
```
We should exclude NaNs from the difference calculation and, for example, add a line `Mismatched NaN's: ...` in case there are any.
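A minimal pure-Python sketch of NaN-aware difference reporting, assuming hypothetical helpers `nan_aware_diffs` and `nan_msg` that operate on lists of floats; the `equal_nan` flag mirrors the parameter of the same name in `assert_close`:

```python
import math
from typing import List, Tuple


def nan_aware_diffs(actual: List[float], expected: List[float],
                    equal_nan: bool = False) -> Tuple[float, int, int]:
    """Return (greatest finite abs diff, its index, number of NaN mismatches)."""
    greatest, greatest_idx, nan_mismatches = 0.0, -1, 0
    for idx, (a, e) in enumerate(zip(actual, expected)):
        a_nan, e_nan = math.isnan(a), math.isnan(e)
        if a_nan or e_nan:
            # Count NaN positions separately instead of letting them
            # turn the greatest difference into nan
            if not (equal_nan and a_nan and e_nan):
                nan_mismatches += 1
            continue
        diff = abs(a - e)
        if greatest_idx == -1 or diff > greatest:
            greatest, greatest_idx = diff, idx
    return greatest, greatest_idx, nan_mismatches


def nan_msg(actual: List[float], expected: List[float],
            equal_nan: bool = False) -> str:
    greatest, idx, nan_mismatches = nan_aware_diffs(actual, expected, equal_nan)
    lines = []
    if idx >= 0:  # at least one pair of non-NaN values was compared
        lines.append(f"Greatest absolute difference: {greatest} at {idx}")
    if nan_mismatches:
        lines.append(f"Mismatched NaN's: {nan_mismatches}")
    return "\n".join(lines)


print(nan_msg([1.0, float("nan")], [2.0, float("nan")]))
# Greatest absolute difference: 1.0 at 0
# Mismatched NaN's: 1
```

For the example above, the finite mismatch at index 0 stays visible, and the NaN at index 1 is reported on its own line.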
`assert_close` allows passing a callable as `msg` to construct an expressive error message. If a user just wants to add some information, for example the tested operator, this is currently impossible without losing all the benefits of our message generation. We should make our version public so others can build on top of it:

```python
import torch.testing


def my_msg(actual, expected, info):
    # default_mismatch_msg is the public helper proposed here, not an existing API
    default_msg = torch.testing.default_mismatch_msg(actual, expected, info)
    return f"My important information: ... \n\n{default_msg}"


torch.testing.assert_close(..., msg=my_msg)
```