Closed
Labels
module: amp (automated mixed precision), autocast, module: cuda, module: linear algebra, triaged
Description
🚀 The feature, motivation and pitch
Currently, torch.linalg.vecdot() does not work with CUDA autocast. The documentation also does not list vecdot among the ops that autocast to float16: https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float16
import torch

x = torch.randn(2, 100, 256, device="cuda", dtype=torch.float16)
y = torch.randn(2, 100, 256, device="cuda")

with torch.autocast("cuda", torch.float16):
    out = torch.linalg.vecdot(x, y)  # fails: vecdot is not handled by autocast
    # out = (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)  # this works (two squeezes to match vecdot's output shape)
    # out = (x * y).sum(-1)  # this works, but promotes x to float32 and computes in float32
    # out = torch.einsum("...i,...i->...", x, y)  # this works too

print(out.dtype)
Alternatives
Workarounds (included in the code snippet above):
- Adjust dimensions and use matmul: looks pretty ugly, is hard to understand, and may be prone to errors.
- Point-wise multiply and sum: easy to understand, but it materializes an intermediate tensor, and the calculation is performed in float32.
- torch.einsum(): works well. I haven't benchmarked it to see whether it is slower or faster. Although it is also easy to understand, using vecdot() is still slightly clearer.

A cast-it-yourself helper is sketched below.
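Not from the original report, just a sketch under assumptions: until vecdot is covered by autocast's cast policy, one can mimic that policy by casting both inputs to the active autocast dtype before calling vecdot (assuming vecdot itself accepts matching float16 inputs on CUDA). The helper name autocast_vecdot is made up for illustration; torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() are the existing query functions used here.

import torch

def autocast_vecdot(x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper: roughly what an fp16 autocast rule for vecdot would do,
    # i.e. cast both inputs to the current autocast dtype, then call vecdot.
    if torch.is_autocast_enabled():
        dtype = torch.get_autocast_gpu_dtype()
        x, y = x.to(dtype), y.to(dtype)
    return torch.linalg.vecdot(x, y, dim=dim)

x = torch.randn(2, 100, 256, device="cuda", dtype=torch.float16)
y = torch.randn(2, 100, 256, device="cuda")
with torch.autocast("cuda", torch.float16):
    out = autocast_vecdot(x, y)
print(out.dtype)  # expected: torch.float16

This keeps the call site as readable as vecdot() while avoiding the intermediate tensor and float32 computation of the multiply-and-sum workaround.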
Additional context
No response
cc @ptrblck @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @mcarilli @leslie-fang-intel @jgong5