Support CUDA autocast for torch.linalg.vecdot() #108127

@gau-nernst

🚀 The feature, motivation and pitch

Currently, torch.linalg.vecdot() does not work with CUDA autocast. The documentation also does not list vecdot among the ops that autocast to float16: https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float16

import torch

x = torch.randn(2, 100, 256, device="cuda", dtype=torch.float16)
y = torch.randn(2, 100, 256, device="cuda")

with torch.autocast("cuda", torch.float16):
    out = torch.linalg.vecdot(x, y)
    # out = (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)  # this works
    # out = (x * y).sum(-1)  # this works, but promotes x to float32 and calculates in float32
    # out = torch.einsum("...i,...i->...", x, y)  # this works too

print(out.dtype)
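
To see how each formulation behaves on a given install, a small probe along the following lines may help. This is only a sketch: `probe_autocast` is a hypothetical helper name, the CPU/bfloat16 fallback is an assumption added so the script also runs without a GPU, and the results will depend on the PyTorch version. The matmul variant squeezes both trailing singleton dimensions so that its shape matches vecdot's.

```python
import torch


def probe_autocast(device_type=None):
    """Report the output dtype (or the error type) of each dot-product variant under autocast."""
    if device_type is None:
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # Assumption: float16 on CUDA (as in the issue), bfloat16 as the CPU fallback.
    low = torch.float16 if device_type == "cuda" else torch.bfloat16

    # Mixed input dtypes, mirroring the snippet above.
    x = torch.randn(2, 100, 256, device=device_type).to(low)
    y = torch.randn(2, 100, 256, device=device_type)

    candidates = {
        "vecdot": lambda: torch.linalg.vecdot(x, y),
        "matmul": lambda: (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1),
        "mul_sum": lambda: (x * y).sum(-1),
        "einsum": lambda: torch.einsum("...i,...i->...", x, y),
    }

    results = {}
    with torch.autocast(device_type, dtype=low):
        for name, fn in candidates.items():
            try:
                results[name] = str(fn().dtype)
            except Exception as exc:  # record the failure instead of stopping the probe
                results[name] = f"error: {type(exc).__name__}"
    return results


for name, outcome in probe_autocast().items():
    print(f"{name}: {outcome}")
```

The probe deliberately records failures rather than raising, so a single run shows which variants autocast handles and which do not.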

Alternatives

Workarounds (included in code snippet above):

  • Adjust dimensions and use matmul: looks pretty ugly, is hard to understand, and may be prone to errors.
  • Point-wise multiply and sum: easy to understand, but materializes an intermediate tensor. Also, the calculation is performed in float32.
  • torch.einsum(): works well. I haven't benchmarked it to see whether it is slower or faster. Although it is also easy to understand, using vecdot() is still slightly clearer.
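
As a sanity check on the workarounds listed above, the sketch below compares each one against torch.linalg.vecdot() for real-valued float32 inputs on the default device, outside autocast. The helper names (`vecdot_matmul`, `vecdot_mul_sum`, `vecdot_einsum`) are hypothetical and only wrap the expressions from the snippet; the tolerances are an arbitrary choice. Note that the matmul form needs both trailing singleton dimensions removed to return the same shape as vecdot.

```python
import torch


def vecdot_matmul(x, y):
    # Row vector times column vector, then drop the two singleton dims left by matmul.
    return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)


def vecdot_mul_sum(x, y):
    # Element-wise product followed by a reduction; materializes x * y.
    return (x * y).sum(-1)


def vecdot_einsum(x, y):
    # Contract the last dimension, keeping all leading (batch) dimensions.
    return torch.einsum("...i,...i->...", x, y)


torch.manual_seed(0)
x = torch.randn(2, 100, 256)
y = torch.randn(2, 100, 256)
reference = torch.linalg.vecdot(x, y)

for fn in (vecdot_matmul, vecdot_mul_sum, vecdot_einsum):
    out = fn(x, y)
    assert out.shape == reference.shape
    torch.testing.assert_close(out, reference, rtol=1e-4, atol=1e-4)
```

This only checks numerical agreement for real inputs; it says nothing about speed or about which dtype each variant uses under autocast.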

Additional context

No response

cc @ptrblck @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @mcarilli @leslie-fang-intel @jgong5

Metadata

    Labels

      • module: amp (automated mixed precision): autocast
      • module: cuda: Related to torch.cuda, and CUDA support in general
      • module: linear algebra: Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul
      • triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module
