Closeness checks in our test suite, i.e. `TestCase.assertEqual`, are performed componentwise, meaning closeness is computed separately for the real and imaginary components.
pytorch/torch/testing/_core.py, lines 137 to 164 in 39ce29e:

```python
# Compares complex tensors' real and imaginary parts separately.
# (see NOTE Test Framework Tensor "Equality")
if a.is_complex():
    if equal_nan == "relaxed":
        a = a.clone()
        b = b.clone()
        a.real[a.imag.isnan()] = math.nan
        a.imag[a.real.isnan()] = math.nan
        b.real[b.imag.isnan()] = math.nan
        b.imag[b.real.isnan()] = math.nan

    real_result, debug_msg = _compare_tensors_internal(a.real, b.real,
                                                       rtol=rtol, atol=atol,
                                                       equal_nan=equal_nan)

    if not real_result:
        debug_msg = "Real parts failed to compare as equal! " + cast(str, debug_msg)
        return (real_result, debug_msg)

    imag_result, debug_msg = _compare_tensors_internal(a.imag, b.imag,
                                                       rtol=rtol, atol=atol,
                                                       equal_nan=equal_nan)

    if not imag_result:
        debug_msg = "Imaginary parts failed to compare as equal! " + cast(str, debug_msg)
        return (imag_result, debug_msg)

    return (True, None)
```
The recently introduced `torch.testing.assert_close` inherited this behavior:

pytorch/torch/testing/_asserts.py, lines 72 to 88 in 39ce29e:

```python
if relaxed_complex_nan:
    actual, expected = [
        t.clone().masked_fill(
            t.real.isnan() | t.imag.isnan(), complex(float("NaN"), float("NaN"))  # type: ignore[call-overload]
        )
        for t in (actual, expected)
    ]

error_meta = check_tensors(actual.real, expected.real, equal_nan=equal_nan, **kwargs)
if error_meta:
    return error_meta

error_meta = check_tensors(actual.imag, expected.imag, equal_nan=equal_nan, **kwargs)
if error_meta:
    return error_meta

return None
```
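In both excerpts, the relaxed NaN mode replaces every complex value that has at least one NaN component with `complex(nan, nan)` before the componentwise comparison. The following is a minimal sketch on plain Python `complex` scalars, not PyTorch code: the helper names `relax_nan`, `_close_equal_nan`, and `componentwise_equal_nan` are hypothetical, and it assumes that `equal_nan=True` means two NaN components count as equal, as it does for real inputs. It shows why `complex(0, NaN)` and `complex(NaN, 0)` only compare as equal after this masking.

```python
import math


def relax_nan(z: complex) -> complex:
    # Scalar analogue of the masked_fill: any NaN component turns the
    # whole value into nan+nanj.
    if math.isnan(z.real) or math.isnan(z.imag):
        return complex(math.nan, math.nan)
    return z


def _close_equal_nan(a: float, b: float, rtol: float, atol: float) -> bool:
    # Real-valued closeness with equal_nan=True: two NaNs count as equal,
    # a NaN next to a number does not.
    if math.isnan(a) or math.isnan(b):
        return math.isnan(a) and math.isnan(b)
    return abs(a - b) <= atol + rtol * abs(b)


def componentwise_equal_nan(a: complex, b: complex, rtol: float = 0.0, atol: float = 0.0) -> bool:
    # Compare real and imaginary parts separately and require both to pass.
    return _close_equal_nan(a.real, b.real, rtol, atol) and _close_equal_nan(a.imag, b.imag, rtol, atol)


a, b = complex(0, math.nan), complex(math.nan, 0)
print(componentwise_equal_nan(a, b))                        # False: real parts are 0 vs nan
print(componentwise_equal_nan(relax_nan(a), relax_nan(b)))  # True: both become nan+nanj
```

Without the masking step, the two values fail on their real parts even though both are "NaN" in the complex sense, which is exactly the gap `equal_nan="relaxed"` papers over.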
The componentwise checking was introduced in #34258, but no public explanation of the reasoning behind it was given.
Pros:
- Since `complex64` is composed of two `float32` values, we can reuse the testing tolerances of the real dtype. Of course, the same is true for `complex128` (two `float64` values) and `complex32` (two `float16` values).
Cons:
- `torch.testing.assert_close` and `torch.isclose` diverge for complex tensors:

  ```python
  >>> rtol, atol = 0.1, 0.3
  >>> expected = torch.tensor(1 + 1j)
  >>> actual1 = torch.tensor(1 + 1.43j)
  >>> torch.testing.assert_close(actual1, expected, rtol=rtol, atol=atol)
  AssertionError: Scalars are not close!
  >>> torch.isclose(actual1, expected, rtol=rtol, atol=atol)
  tensor(True)
  >>> actual2 = torch.tensor(0.6 + 1.3j)
  >>> torch.testing.assert_close(actual2, expected, rtol=rtol, atol=atol)
  >>> torch.isclose(actual2, expected, rtol=rtol, atol=atol)
  tensor(False)
  ```

  `torch.isclose` follows the implementation of `numpy` and, partially, `cmath` by comparing the absolute value of the difference of the two values (`cmath.isclose` also uses `abs(a - b)`, but with a symmetric tolerance based on the larger of `abs(a)` and `abs(b)`). `torch.testing.assert_close`, on the other hand, compares the absolute values of the differences of the individual components. To be exact, `torch.isclose` as well as `numpy` evaluate the closeness inequality `abs(a - b) <= atol + rtol * abs(b)` the same way for complex as for real numbers. `torch.testing.assert_close` evaluates it twice, once for the real and once for the imaginary component, and takes the logical AND of the results.
- The implementation is messier and introduces more edge cases:
  - Is `complex(inf, 0)` close to `complex(inf, 1)`?
  - Is `complex(0, NaN)` close to `complex(NaN, 0)`?

  The second case led to `equal_nan="relaxed"` in `TestCase.assertEqual`, which accepts these comparisons, whereas `equal_nan=True` rejects them. `torch.isclose` currently bails out for complex inputs with `equal_nan=True`; see pytorch/aten/src/ATen/native/TensorCompare.cpp, lines 101 to 102 in a26a9f8:

  ```cpp
  TORCH_CHECK(!(self.is_complex() && equal_nan),
    "isclose with equal_nan=True is not supported for complex inputs.");
  ```
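The divergence in the first bullet, and the infinity edge case, can be reproduced with plain Python `complex` scalars. This is a sketch, not PyTorch code: the helper names are hypothetical, and the special-value handling (exact equality, including equal infinities, counts as close; otherwise the error must be finite) is an assumption modeled loosely on how `torch.isclose` treats its inputs.

```python
import math


def real_close(x: float, y: float, rtol: float, atol: float) -> bool:
    # Equal values (including equal infinities) are close; otherwise the
    # difference must be finite and within atol + rtol * |y|.
    err = abs(x - y)
    return x == y or (math.isfinite(err) and err <= atol + rtol * abs(y))


def magnitude_close(a: complex, b: complex, rtol: float, atol: float) -> bool:
    # One check on the modulus of the complex difference: a disc around b.
    err = abs(a - b)
    return a == b or (math.isfinite(err) and err <= atol + rtol * abs(b))


def componentwise_close(a: complex, b: complex, rtol: float, atol: float) -> bool:
    # The real inequality applied to each component: a rectangle around b.
    return real_close(a.real, b.real, rtol, atol) and real_close(a.imag, b.imag, rtol, atol)


rtol, atol = 0.1, 0.3
expected = 1 + 1j
for actual in (1 + 1.43j, 0.6 + 1.3j):
    print(actual, magnitude_close(actual, expected, rtol, atol),
          componentwise_close(actual, expected, rtol, atol))
# (1+1.43j) True False
# (0.6+1.3j) False True

# complex(inf, 0) vs complex(inf, 1) with a looser atol: the answers diverge again.
a, b = complex(math.inf, 0), complex(math.inf, 1)
print(magnitude_close(a, b, rtol, atol=1.0))      # False: inf - inf is nan
print(componentwise_close(a, b, rtol, atol=1.0))  # True: inf == inf and |0 - 1| <= 1.1
```

With `rtol=0.1` and `atol=0.3`, the magnitude check accepts differences up to about 0.441 in modulus, so it takes `1 + 1.43j` (difference 0.43) and rejects `0.6 + 1.3j` (difference 0.5), while the componentwise check does the opposite.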
Regions of closeness
By checking the components individually, we form a rectangular region of closeness around the expected value. Despite their differing tolerance definitions, `torch.isclose`, `numpy`, and `cmath` all form a circle around the expected value. Thus, if we go ahead with this change, it is impossible to convert the old tolerances into new ones that reproduce the same behavior.
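The claim that no tolerance conversion exists can be checked numerically. The sketch below assumes `rtol=0` and an expected value of 0, so the componentwise check accepts a square of half-width `t` and the magnitude check a disc of radius `r`. With `rtol > 0`, the componentwise half-widths become `atol + rtol * abs(b.real)` and `atol + rtol * abs(b.imag)`, so the region is in general not even a square. The helpers and probe points are hypothetical; for every radius tried, either a probe near the corner of the square or one just outside its edge is classified differently by the two checks.

```python
import math


def in_square(d: complex, t: float) -> bool:
    # Componentwise check with rtol=0: |Re d| <= t and |Im d| <= t.
    return abs(d.real) <= t and abs(d.imag) <= t


def in_disc(d: complex, r: float) -> bool:
    # Magnitude check with rtol=0: |d| <= r.
    return abs(d) <= r


t = 1.0
corner = complex(0.99 * t, 0.99 * t)  # inside the square, but |corner| is about 1.40 * t
edge = complex(1.2 * t, 0.0)          # outside the square, |edge| = 1.2 * t


def disagreement(r: float) -> bool:
    # True if at least one probe is classified differently by the two checks.
    return any(in_disc(d, r) != in_square(d, t) for d in (corner, edge))


# Sweep radii from 0.5 * t to 2.49 * t: no radius reproduces the square.
print(all(disagreement(0.5 * t + 0.01 * k * t) for k in range(200)))  # True

# The enclosing circle r = sqrt(2) * t accepts the whole square, but also `edge`.
r_outer = math.sqrt(2) * t
print(in_disc(corner, r_outer), in_disc(edge, r_outer))  # True True
```

At best, a new circle can be chosen to lie inside the old square (`r = t`, strictly tighter) or to enclose it (`r = sqrt(2) * t`, strictly looser).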
RFC
Since in most cases the specific values of the tolerances hardly matter, IMHO the cons severely outweigh the pros. Thus, the current plan is to retire the componentwise checking and align the complex behavior of our internal and public testing functions with `torch.isclose`, and in turn with `numpy`. If you disagree with this plan, please let us know in this issue.
cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @VitalyFedyunin @walterddr