make torch.amp.autocast more generic by guangyey · Pull Request #125103 · pytorch/pytorch · GitHub

Conversation

guangyey (Collaborator) commented Apr 27, 2024

Stack from ghstack (oldest at bottom):

Motivation

As discussed in #124479, torch.amp.autocast cannot be completely equivalent to torch.cuda.amp.autocast and torch.cpu.amp.autocast, because torch.amp.autocast does not provide the per-backend default dtype (torch.bfloat16 for CPU and torch.float16 for CUDA). We would like torch.amp.autocast to be more generic so that developers and customers can write device-agnostic code, since there is not enough justification to add a device-specific torch.xxx.amp.autocast for every device backend.

Solution

When None is passed as dtype, we use torch.get_autocast_dtype to obtain the appropriate dtype for each backend. torch.get_autocast_dtype also needs to be supported in the JIT path for backward compatibility.

Additional Context

With this PR, torch.amp.autocast(device_type='cuda') is equivalent to torch.cuda.amp.autocast.
Two new unit tests are added to cover this change in the eager and JIT paths, respectively.
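
For illustration, a minimal sketch of the device-agnostic usage this enables (the model and shapes here are placeholders; the defaults follow the description above):

```python
import torch

# Pick whichever backend is available; the same code path works for both.
device_type = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 8).to(device_type)
x = torch.randn(4, 8, device=device_type)

# dtype is left as None, so autocast falls back to the backend's default
# autocast dtype (torch.float16 on CUDA, torch.bfloat16 on CPU).
with torch.amp.autocast(device_type=device_type):
    y = model(x)

# The default can also be queried directly.
print(torch.get_autocast_dtype(device_type))
```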

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

pytorch-bot commented Apr 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125103

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e11d24b with merge base 5007312:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@guangyey guangyey changed the title make torch.amp.autocast more generic [WIP] make torch.amp.autocast more generic Apr 27, 2024
@guangyey guangyey marked this pull request as draft April 27, 2024 15:42
@guangyey guangyey added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 27, 2024
@pytorch-bot pytorch-bot bot added the release notes: jit release notes category label Apr 28, 2024
guangyey added a commit that referenced this pull request Apr 28, 2024
ghstack-source-id: 4e58007
Pull Request resolved: #125103
@guangyey guangyey added the topic: improvements topic category label May 6, 2024
guangyey added a commit that referenced this pull request May 6, 2024
ghstack-source-id: 31ac5e5
Pull Request resolved: #125103
@guangyey guangyey changed the title [WIP] make torch.amp.autocast more generic make torch.amp.autocast more generic May 6, 2024
@guangyey guangyey marked this pull request as ready for review May 6, 2024 05:52
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

def autocast_specific_backend(
guangyey (Collaborator, Author) commented:

code improvements.

guangyey (Collaborator, Author) commented May 7, 2024

@albanD This PR intends to make torch.amp.autocast more generic so that developers can use it to write device-agnostic code instead of torch.cuda.amp.autocast or torch.cpu.amp.autocast. Is this reasonable?

Comment on lines +210 to +211
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
Collaborator:
We should update the doc to mention the new default value for this arg?
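
For illustration, a small sketch of the new default behavior for this argument (assuming the backend defaults have not been changed; the dtypes follow the PR description above):

```python
import torch

# With dtype omitted (None), the context manager resolves the backend's
# default autocast dtype via torch.get_autocast_dtype, as in the snippet above.
assert torch.get_autocast_dtype("cpu") == torch.bfloat16

with torch.amp.autocast(device_type="cpu"):  # dtype=None -> torch.bfloat16
    pass

# Equivalent to passing the default explicitly.
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    pass
```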

) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
Collaborator:

Oh, we gather and restore both the CPU context and another device's context here? This makes the code a bit weird, but that sounds fair. We definitely don't want to change the behavior here.

cc @soulitzer in case this is something you want to clean up for AC in general in a follow-up, now that we have the nice API.

guangyey (Collaborator, Author):

We don't change the behavior here; we just use torch.amp.autocast to make the code more generic and leave the logic as it is.

Collaborator:

yep, perfect!
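
For reference, a minimal sketch of the pattern discussed in this thread, using the generic API (this is not the actual checkpoint implementation; the helper name make_autocast_contexts is purely illustrative):

```python
import contextlib
import torch

# Build an autocast context for a non-CPU device via the generic API, falling
# back to a no-op context when the backend has no autocast support, and pair
# it with the CPU autocast context, mirroring the snippet above.
def make_autocast_contexts(device_type: str):
    device_ctx = (
        torch.amp.autocast(device_type=device_type)  # dtype=None -> backend default
        if torch.amp.is_autocast_available(device_type)
        else contextlib.nullcontext()
    )
    cpu_ctx = torch.amp.autocast(device_type="cpu")
    return device_ctx, cpu_ctx

# Usage: both contexts are entered together around the recomputation.
device_ctx, cpu_ctx = make_autocast_contexts("cuda" if torch.cuda.is_available() else "cpu")
with device_ctx, cpu_ctx:
    pass  # recomputation would run here
```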

albanD (Collaborator) left a comment

nit in doc, sounds good otherwise.

Default: ``True``
dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16.
dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
Collaborator:
Suggested change
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by

guangyey (Collaborator, Author):

updated.

guangyey added a commit that referenced this pull request May 8, 2024
ghstack-source-id: c05a9e0
Pull Request resolved: #125103
guangyey (Collaborator, Author) commented May 8, 2024

@pytorchbot merge

pytorchmergebot commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

guangyey added 2 commits May 8, 2024 10:19
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request May 9, 2024
X-link: pytorch/pytorch#125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui

Reviewed By: izaitsevfb

Differential Revision: D57138276

fbshipit-source-id: 17f883924e43f68dd6836d99b06fe8a47cfccbf6
guangyey added a commit that referenced this pull request May 13, 2024
# Motivation
We generalized the device-agnostic API `torch.amp.autocast` in [#125103](#125103). After that,
- `torch.cpu.amp.autocast(args...)` is completely equivalent to `torch.amp.autocast('cpu', args...)`, and
- `torch.cuda.amp.autocast(args...)` is completely equivalent to `torch.amp.autocast('cuda', args...)`

in both eager mode and JIT mode.
Based on this, we would like to deprecate `torch.cpu.amp.autocast` and `torch.cuda.amp.autocast` to **strongly recommend** that developers use the device-agnostic `torch.amp.autocast` instead.
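
For illustration, a minimal migration sketch from the deprecated device-specific context managers to the device-agnostic API (the tensors and ops are placeholders):

```python
import torch

x = torch.randn(4, 4)

# Before (deprecated):
#   with torch.cuda.amp.autocast(): ...
#   with torch.cpu.amp.autocast(): ...

# After (device-agnostic):
with torch.amp.autocast(device_type="cpu"):
    y = x @ x  # runs in torch.bfloat16 on CPU by default

if torch.cuda.is_available():
    xc = x.cuda()
    with torch.amp.autocast(device_type="cuda"):
        z = xc @ xc  # runs in torch.float16 on CUDA by default
```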

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
guangyey added a commit that referenced this pull request May 14, 2024
guangyey added a commit that referenced this pull request May 15, 2024
guangyey added a commit that referenced this pull request May 15, 2024
guangyey added a commit that referenced this pull request May 15, 2024
guangyey added a commit that referenced this pull request May 15, 2024
pytorchmergebot pushed a commit that referenced this pull request May 16, 2024
Pull Request resolved: #126062
Approved by: https://github.com/eqy, https://github.com/albanD
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
@github-actions github-actions bot deleted the gh/guangyey/26/head branch June 8, 2024 01:56