fix mean_out: op does not update parameter out for BF16/FP16 dtype on CPU by DavidGu-Datong · Pull Request #135174 · pytorch/pytorch · GitHub

Conversation

@DavidGu-Datong
Contributor

@DavidGu-Datong DavidGu-Datong commented Sep 5, 2024

Fixes #134848

For BF16/FP16, when a tensor is passed as the out parameter of mean, the mean kernel should write the result into that tensor's storage. That doesn't happen today: an at::to call in the current code allocates storage again, and the out parameter tensor's storage is never updated, so it ends up not holding the mean output.
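The failure mode can be sketched in plain Python (a hypothetical stand-in, not the actual ATen code): an out=-style API must write into the caller's buffer, and rebinding the output name to freshly allocated storage (as the at::to call effectively does) silently drops the result.

```python
def mean_out_buggy(values, out):
    # Rebinding `out` creates a brand-new list; the caller's buffer
    # is never touched, mirroring the reported bug.
    out = [sum(values) / len(values)]
    return out

def mean_out_fixed(values, out):
    # Mutating in place updates the caller's storage, as out= requires.
    out[0] = sum(values) / len(values)
    return out

buf = [0.0]
mean_out_buggy([1.0, 2.0, 3.0], buf)
print(buf)  # [0.0] -- result lost
mean_out_fixed([1.0, 2.0, 3.0], buf)
print(buf)  # [2.0]
```

The same aliasing contract is what the fix restores for the BF16/FP16 path.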

cc @albanD

@pytorch-bot

pytorch-bot bot commented Sep 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135174

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1a4ee95 with merge base cf31724 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Sep 5, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: DavidGu-Datong / name: DavidGu (1a4ee95)

@DavidGu-Datong
Contributor Author

@pytorchbot label "topic: not user facing

@pytorch-bot

pytorch-bot bot commented Sep 5, 2024

❌ 🤖 pytorchbot command failed:

Got EOF while in a quoted string
Try `@pytorchbot --help` for more info.

@DavidGu-Datong
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Sep 5, 2024
@DavidGu-Datong DavidGu-Datong reopened this Sep 5, 2024
@soulitzer soulitzer self-requested a review September 6, 2024 02:07
@soulitzer soulitzer added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 6, 2024
Contributor

@soulitzer soulitzer left a comment


Thanks, could you add a test? (see test/test_reductions.py)

@DavidGu-Datong
Contributor Author

Thanks for the review. I have added a test in test/test_reductions.py to check that the out tensor of the mean_out op is an alias of the return value.

@DavidGu-Datong
Contributor Author

@mruberry Please help review the code, thanks.

Contributor


Is it possible to no longer set `result_mut` and no longer do `auto& result_mut = const_cast<Tensor&>(result);` above?

Contributor Author


I think so; let me test it locally.

Contributor


Thanks for adding the test, but maybe also a small check for correctness by comparing with the out-of-place mean op would be good?

Contributor Author


Does it look good? I used allclose to compare the result with the target.
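For readers unfamiliar with the check, an allclose-style comparison verifies element-wise closeness within a relative and absolute tolerance. A plain-Python sketch (hypothetical values and tolerances chosen loosely for half precision, not the actual test code):

```python
# Minimal stand-in for torch.allclose-style checking: each element of `a`
# must lie within atol + rtol * |b| of the corresponding element of `b`.
def allclose(a, b, rtol=1e-3, atol=1e-5):
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

expected = [2.0, 2.0]          # out-of-place mean result
out_buffer = [1.9995, 2.0003]  # values produced through the out= path
print(allclose(out_buffer, expected))  # True
```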

Contributor

@soulitzer soulitzer left a comment


Thanks!

@soulitzer soulitzer added topic: bug fixes topic category module: python frontend For issues relating to PyTorch's Python frontend and removed topic: not user facing topic category labels Sep 12, 2024
@soulitzer
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

    at::sum_out(result_temp, self, opt_dim, keepdim, sum_out_dtype).div_(dim_prod);
    // After sum & div, cast result_temp back to BF16 or FP16, if required.
    if (is_half_type) {
      result.copy_(result_temp.to(dtype));
Collaborator


Do we need to call `result_temp.to(dtype)`? `result.copy_(result_temp)` should work and save the overhead of `result_temp.to(dtype)`.

Contributor Author


Awesome, I just learned that `copy_` does this conversion. I have changed it.

Collaborator

@sanchitintel sanchitintel Sep 13, 2024


@CaoE - for accuracy reasons, wouldn't it be better if we had an intermediate FP32 sum output as the input for the division? Thanks

Just noticed that sum_out_dtype already ensures it.
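The accuracy concern is easy to demonstrate in plain Python (a hedged illustration, not the ATen code): to_bf16 below emulates bfloat16 by truncating a float32 to its top 16 bits (real BF16 rounds to nearest even, but truncation is enough to show the effect of accumulating in reduced precision).

```python
import struct

def to_bf16(x):
    # bfloat16 keeps the top 16 bits of the float32 encoding
    # (sign, 8 exponent bits, 7 explicit mantissa bits).
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]

acc_bf16 = 0.0
acc_fp32 = 0.0
for _ in range(1000):
    acc_bf16 = to_bf16(acc_bf16 + 1.0)  # round after every addition
    acc_fp32 += 1.0

print(acc_bf16)  # 256.0 -- the sum stalls: 257 needs 9 significant bits, BF16 has 8
print(acc_fp32)  # 1000.0
```

Accumulating in FP32 (as `sum_out_dtype` arranges) and converting only the final result avoids this loss.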

Collaborator

@CaoE CaoE Sep 13, 2024


We should add other data types to the test to ensure that they also return the correct result.

    @onlyCPU
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
    def test_mean_out_float16_is_alias_of_return(self, dtype, device):
        a = torch.tensor([[[1.0, 1.0, 1.0, 1.0]], [[2.0, 2.0, 2.0, 2.0]], [[3.0, 3.0, 3.0, 3.0]]],
                         dtype=dtype, device=device)
        ...

Can we also avoid creating a new tensor when `is_half_type` is false, to skip that overhead? For example, we could just use:

    if (is_half_type) {
      auto _result_mut = result.to(sum_out_dtype);
      at::sum_out(_result_mut, self, opt_dim, keepdim, sum_out_dtype).div_(dim_prod);
      result.copy_(_result_mut);
    } else {
      at::sum_out(result, self, opt_dim, keepdim, sum_out_dtype).div_(dim_prod);
    }

Contributor Author


Thanks for your review. I've added it.

Collaborator

@sanchitintel sanchitintel Sep 13, 2024


Thanks for the fix!

Is there any way we can avoid the copy? If not, can we add a comment on why it's necessary, so that someone reading the code in the future may be able to understand it without going through the history of the file? Thanks!

Contributor Author


It cannot be avoided: the promotion to the wider accumulation dtype needs larger storage, so the kernel cannot reuse the storage of the input `out` tensor and must copy the result back into `out`.

Contributor Author


I have added some comments about it. This is my first contribution on GitHub. Thanks for all of your excellent suggestions!

@sanchitintel sanchitintel changed the title fix mean_out op does not update value of given parameter out. #134848 fix mean_out: op does not update parameter out for BF16/FP16 dtype Sep 14, 2024
@DavidGu-Datong DavidGu-Datong force-pushed the dev/datonggu/1/base branch 3 times, most recently from d40cd70 to 336eccc Compare September 20, 2024 12:32
@sanchitintel
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-cuda12.4-py3.10-gcc9-sm86 / test (default, 3, 5, linux.g5.4xlarge.nvidia.gpu)

Details for Dev Infra team Raised by workflow job

@pytorch-bot

pytorch-bot bot commented Sep 20, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'rebas' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick', 'close')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.

@sanchitintel
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased dev/datonggu/1/base onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout dev/datonggu/1/base && git pull --rebase)

@sanchitintel
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@sanchitintel sanchitintel changed the title fix mean_out: op does not update parameter out for BF16/FP16 dtype fix mean_out: op does not update parameter out for BF16/FP16 dtype on CPU Sep 20, 2024
@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job waited more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@DavidGu-Datong
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here


Labels

  • ciflow/trunk — Trigger trunk jobs on your pull request
  • Merged
  • module: python frontend — For issues relating to PyTorch's Python frontend
  • open source
  • release notes: python_frontend — python frontend release notes category
  • topic: bug fixes — topic category
  • triaged — This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[ATEN][OP]mean_out op does not update value of given parameter out.

6 participants