make torch.amp.autocast more generic #125103
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125103
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e11d24b with merge base 5007312.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
    )
    global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

    def autocast_specific_backend(
code improvements.
@albanD This PR intends to make torch.amp.autocast more generic. Developers can use it to write device-agnostic code instead of using device-specific variants such as `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`.
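For illustration, a minimal sketch of the device-agnostic pattern this enables (the toy model and the device-selection line are assumptions for the example, not code from the PR):

```python
import torch

# Illustrative device selection: use an available accelerator, else CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 8).to(device_type)
x = torch.randn(4, 8, device=device_type)

# One context manager for every backend; with dtype left unset, autocast
# falls back to the backend default (float16 on CUDA, bfloat16 on CPU).
with torch.amp.autocast(device_type=device_type):
    y = model(x)
```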
    if dtype is None:
        dtype = torch.get_autocast_dtype(device_type)
We should update the doc to mention the new default value for this arg?
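A small illustration of the new default resolution, assuming a build that includes this change (not the PR's test code):

```python
import torch

# With dtype=None, torch.amp.autocast falls back to the per-backend value
# returned by torch.get_autocast_dtype.
print(torch.get_autocast_dtype("cpu"))       # torch.bfloat16 by default
if torch.cuda.is_available():
    print(torch.get_autocast_dtype("cuda"))  # torch.float16 by default

with torch.amp.autocast(device_type="cpu"):  # dtype omitted -> backend default
    pass
```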
        ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-       with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-               recompute_context:
+       with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Oh, we gather and restore both the CPU context and another device's context here?
This makes this code a bit weird. But sounds fair. We definitely don't want to change the behavior here.
cc @soulitzer in case this is something you want to clean up for AC in general in a follow-up, now that we have the nice API.
We don't change the behavior here; we just use `torch.amp.autocast` to make the code more generic and leave the logic as it is.
yep perfect!
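For reference, a condensed sketch of the gather-and-restore pattern discussed in this thread, not the checkpoint implementation itself; the `*_kwargs` dicts are placeholders for the autocast state that checkpoint captures during the forward pass:

```python
import contextlib
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder state; checkpoint records the real values at forward time.
device_autocast_kwargs = {"device_type": device, "enabled": True}
cpu_autocast_kwargs = {"dtype": torch.bfloat16, "enabled": True}

# Enter the saved device autocast state only if autocast is available for
# that backend, then restore the CPU autocast state alongside it.
device_autocast_ctx = (
    torch.amp.autocast(**device_autocast_kwargs)
    if torch.amp.is_autocast_available(device)
    else contextlib.nullcontext()
)
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs):
    pass  # recomputation would run here
```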
nit in doc, sounds good otherwise.
torch/amp/autocast_mode.py
            Default: ``True``
-       dtype(torch_dtype, optional):  Whether to use torch.float16 or torch.bfloat16.
+       dtype(torch_dtype, optional):  Data type for ops run in autocast. It uses the default value
+           (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
Suggested change:
-           (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
+           (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
updated.
# Motivation As discussed in [#124479](#124479), `torch.amp.autocast` can NOT be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast` since `torch.amp.autocast` has NOT the default `dtype` for CPU (`torch.bfloat16` by default) and CUDA (`torch.float16` by default) respectively. We would like `torch.amp.autocast` to be more generic to help the developer/customer write the device-agnostic code. Because there are not enough reasons to add device-specific autocast `torch.xxx.amp.autocast` for each device backend. # Solution When `None` is passed to `dtype`, we should use `torch.get_autocast_dtype` to get the related dtype for each backend. Meanwhile, `torch.get_autocast_dtype` is necessary to be supported in JIT path for BC. # Additional Context With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`. Add two new UTs to cover this change in eager and jit path respectively. cc mcarilli ptrblck leslie-fang-intel jgong5 ezyang msaroufim bdhirsh anijain2305 chauhang voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary: see the PR description below.
X-link: pytorch/pytorch#125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
Reviewed By: izaitsevfb
Differential Revision: D57138276
fbshipit-source-id: 17f883924e43f68dd6836d99b06fe8a47cfccbf6
# Motivation
We generalized the device-agnostic API `torch.amp.autocast` in #125103. After that:
- `torch.cpu.amp.autocast(args...)` is completely equivalent to `torch.amp.autocast('cpu', args...)`, and
- `torch.cuda.amp.autocast(args...)` is completely equivalent to `torch.amp.autocast('cuda', args...)`,
in both eager mode and JIT mode. Based on this, we would like to deprecate `torch.cpu.amp.autocast` and `torch.cuda.amp.autocast` to strongly recommend that developers use the device-agnostic `torch.amp.autocast`.

Pull Request resolved: #126062
Approved by: https://github.com/eqy, https://github.com/albanD
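An illustrative before/after of the migration this deprecation encourages (the toy model is an assumption for the example, not code from the PR):

```python
import torch

# Before/after for the CUDA case; torch.cpu.amp.autocast maps to
# torch.amp.autocast('cpu', ...) in the same way.
if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(4, 8, device="cuda")

    # Deprecated device-specific spelling ...
    with torch.cuda.amp.autocast():
        y_old = model(x)

    # ... and its device-agnostic replacement.
    with torch.amp.autocast("cuda"):
        y_new = model(x)
```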
Stack from ghstack (oldest at bottom):
# Motivation
As discussed in #124479, `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, because `torch.amp.autocast` does not carry the per-backend default `dtype` (`torch.bfloat16` for CPU and `torch.float16` for CUDA). We would like `torch.amp.autocast` to be more generic so that developers and customers can write device-agnostic code, since there are not enough reasons to add a device-specific `torch.xxx.amp.autocast` for every device backend.

# Solution
When `None` is passed as `dtype`, we use `torch.get_autocast_dtype` to get the appropriate dtype for each backend. Meanwhile, `torch.get_autocast_dtype` also needs to be supported in the JIT path for BC.

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`. Two new UTs cover this change in the eager and JIT paths respectively.
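A rough eager-mode check in the spirit of the two new UTs (not the actual test code added by the PR):

```python
import torch

# With dtype left unset, the generic API should pick the same low-precision
# dtype as the device-specific one.
a = torch.randn(4, 4)
with torch.cpu.amp.autocast():
    ref = torch.mm(a, a).dtype               # torch.bfloat16
with torch.amp.autocast(device_type="cpu"):
    res = torch.mm(a, a).dtype               # torch.bfloat16
assert ref == res

if torch.cuda.is_available():
    b = torch.randn(4, 4, device="cuda")
    with torch.cuda.amp.autocast():
        ref = torch.mm(b, b).dtype           # torch.float16
    with torch.amp.autocast(device_type="cuda"):
        res = torch.mm(b, b).dtype           # torch.float16
    assert ref == res
```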
cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng