KEMBAR78
Improve error messages of `torch.testing.assert_close` in case of mismatching values by pmeier · Pull Request #60091 · pytorch/pytorch · GitHub
Skip to content

Conversation

@pmeier
Copy link
Collaborator

@pmeier pmeier commented Jun 16, 2021

Stack from ghstack:

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

  • Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
  • The reported conditions "not close" and "not equal" are now determined based on rtol and atol.
  • The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
  • The allowed rtol and atol is only reported if > 0

Example 1

torch.testing.assert_close(1, 3, rtol=0, atol=1)

Before:

AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)

After:

AssertionError: Scalars are not close!

Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816

Example 2

torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))

Before:

AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.

After:

AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)

Differential Revision: D29556357

…matching values

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
@facebook-github-bot
Copy link
Contributor

facebook-github-bot commented Jun 16, 2021

💊 CI failures summary and remediations

As of commit 978c9cb (more details on the Dr. CI page and at hud.pytorch.org/pr/60091):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


Preview docs built from this PR

This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

pmeier added a commit that referenced this pull request Jun 16, 2021
…matching values

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

ghstack-source-id: d46bc03
Pull Request resolved: #60091
@pmeier pmeier requested a review from mruberry June 16, 2021 12:40
@pmeier pmeier added the module: testing Issues related to the torch.testing module (not tests) label Jun 16, 2021
…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
@mruberry
Copy link
Collaborator

  • Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape.

Neat.

Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.

In the future we might change our complex comparisons to also just be a closeness check (and not validate the real and imaginary components separately) -- how would that future update work with this statement?

  • The reported conditions "not close" and "not equal" are now determined based on rtol and atol.

Sure.

  • The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar

This seems OK since the language is changing to indicate the comparison is between scalars.

  • The allowed rtol and atol is only reported if > 0

Won't this confused people?

AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816

When we compare two values are we imposing the stricter of atol and rtol? Isn't the close equation we're using weird and additive in both these things, so there's just an absolute difference to consider?

…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
@pmeier
Copy link
Collaborator Author

pmeier commented Jun 21, 2021

When we compare two values are we imposing the stricter of atol and rtol? Isn't the close equation we're using weird and additive in both these things, so there's just an absolute difference to consider?

Good point. We now either leave both out or report both. In the former case this should not be confusing, since the header in such a case reads "... are not equal!"

pmeier added 3 commits June 21, 2021 12:58
…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
Copy link
Collaborator

@mruberry mruberry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool

…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
…case of mismatching values"

Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!
Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
…case of mismatching values"


Closes #58383. (1) and (2) are implemented. (3) was rejected. No consensus was reached on (4) and (5).

Improvements:

- Instead of calling everything "Tensors" we now use "Scalars" and "Tensor-likes" depending on the shape. Plus, we now internally have the option to adapt this identifier for example to report "Imaginary components of complex tensor-likes", which is even more expressive.
- The reported conditions "not close" and "not equal" are now determined based on `rtol` and `atol`.
- The number of mismatched elements and the offending indices are only reported in case the inputs are not scalar
- The allowed `rtol` and `atol` is only reported if `> 0`

**Example 1**

```python
torch.testing.assert_close(1, 3, rtol=0, atol=1)
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 1 / 1 (100.0%)
Greatest absolute difference: 2 at 0 (up to 1 allowed)
Greatest relative difference: 0.6666666865348816 at 0 (up to 0 allowed)
```

After:

```
AssertionError: Scalars are not close!

Absolute difference: 2 (up to 1 allowed)
Relative difference: 0.6666666865348816
```

**Example 2**

```python
torch.manual_seed(0)
t = torch.rand((2, 2), dtype=torch.complex64)
torch.testing.assert_close(t, t + complex(0, 1))
```

Before:

```
AssertionError: Tensors are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at (0, 1) (up to 1.3e-06 allowed)

The failure occurred for the imaginary part.
```

After:

```
AssertionError: Imaginary components of tensor-likes are not close!

Mismatched elements: 4 / 4 (100.0%)
Greatest absolute difference: 1.0000000596046448 at index (0, 0) (up to 1e-05 allowed)
Greatest relative difference: 0.8833684352411922 at index (0, 1) (up to 1.3e-06 allowed)
```

[ghstack-poisoned]
@mruberry
Copy link
Collaborator

mruberry commented Jul 6, 2021

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@mruberry merged this pull request in 9979289.

@facebook-github-bot facebook-github-bot deleted the gh/pmeier/21/head branch July 11, 2021 14:16
facebook-github-bot pushed a commit that referenced this pull request Jul 30, 2021
…uts (#61583)

Summary:
This utilizes the feature introduced in #60091 to modify the header of the error message.

Before:

```python
AssertionError: Tensor-likes are not equal!

Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 1 at index 1
Greatest relative difference: 0.3333333432674408 at index 1

The failure occurred for the values.
```

After:

```python
AssertionError: Sparse COO values of tensor-likes are not equal!

Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 1 at index 1
Greatest relative difference: 0.3333333432674408 at index 1
```

Pull Request resolved: #61583

Reviewed By: malfet

Differential Revision: D30014797

Pulled By: cpuhrsch

fbshipit-source-id: 66e30645e94de5c8c96510822082ff9aabef5329
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed Merged module: testing Issues related to the torch.testing module (not tests) open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants