[ATen][CUDA] Implement 128 bit vectorization v2 by Aidyn-A · Pull Request #145746 · pytorch/pytorch · GitHub

Conversation

@Aidyn-A
Collaborator

@Aidyn-A Aidyn-A commented Jan 27, 2025

This is a rebased version of my previous PR #141959.

Description from the original PR:

This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100.
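For intuition, here is a minimal CUDA sketch of what a 128-bit vectorized elementwise access pattern looks like. This is illustrative only, not the kernel added by this PR; aligned_vector and relu_vec128 are hypothetical names, and the code assumes the pointers are 16-byte aligned.

#include <cuda_fp16.h>
#include <cstdint>

// Each thread moves 16 bytes (128 bits) per load/store; for half precision
// that is 8 elements at a time instead of 1.
template <typename T, int N>
struct alignas(sizeof(T) * N) aligned_vector {
  T val[N];
};

__global__ void relu_vec128(const __half* __restrict__ in,
                            __half* __restrict__ out,
                            int64_t n) {
  using vec_t = aligned_vector<__half, 8>;  // 8 * 2 bytes = 128 bits
  int64_t i = (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * 8;
  if (i + 7 < n) {
    // One 128-bit load, 8 scalar ops in registers, one 128-bit store.
    vec_t v = *reinterpret_cast<const vec_t*>(in + i);
    #pragma unroll
    for (int k = 0; k < 8; ++k) {
      float x = __half2float(v.val[k]);
      v.val[k] = __float2half(x > 0.f ? x : 0.f);
    }
    *reinterpret_cast<vec_t*>(out + i) = v;
  } else {
    // Tail: fall back to scalar accesses for the last partial vector.
    for (int64_t k = i; k < n; ++k) {
      float x = __half2float(in[k]);
      out[k] = __float2half(x > 0.f ? x : 0.f);
    }
  }
}

In the real kernels the vector width is chosen at runtime from the dtype size and the iterator's maximum vector size, roughly as in the vec_size snippet discussed in the review comments below.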

The benchmark code used:
import time
import torch
from torch.profiler import profile, ProfilerActivity


def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    device = torch.device("cuda")

    # Sizes 2**24 .. 2**29 elements.
    shapes = []
    for p in range(24, 30):
        shape = 1 << p
        shapes.append(shape)

    for shape in shapes:
        # Warm-up runs so allocation and kernel caching do not skew the timing.
        for _ in range(6):
            x = torch.randn(shape, device=device, dtype=dtype)
            y = function(x)

        if print_profile:
            # Optional: capture a profiler trace of a single launch.
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                y = function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

        # Timed region: average wall-clock time over 6 launches.
        x = torch.randn(shape, device=device, dtype=dtype)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(6):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / 6

        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")
        if check_numerics:
            # Compare against a CPU reference of the same op.
            # (torch.testing.assert_allclose is deprecated in favor of
            # torch.testing.assert_close in newer PyTorch releases.)
            x_cpu = x.cpu()
            y_cpu = function(x_cpu).cuda()
            try:
                torch.testing.assert_allclose(y_cpu, y)
            except AssertionError as error:
                print("An exception occurred:", error)


def main():
    ops = [
            torch.relu,
            torch.sigmoid,
            torch.tanh,
            torch.nn.functional.gelu,
            torch.sin,
            torch.exp,
    ]

    dtypes = [
            torch.float16,
            torch.bfloat16,
            torch.float32,
    ]

    for op in ops:
        for dtype in dtypes:
            benchmark(op, dtype=dtype)
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
Results
op  dtype  size (elements)  time after (s)  time before (s)  % improvement
relu torch.float16 33554432 4.84E-05 5.06E-05 4.66296539127052
relu torch.float16 67108864 9.22E-05 9.64E-05 4.56491432752297
relu torch.float16 134217728 0.000180343495837102 0.000187981834945579 4.23543919508829
relu torch.float16 268435456 0.000355071155354381 0.000370856161074092 4.44558942107169
relu torch.float16 536870912 0.000704489842367669 0.000736006341564159 4.47366268483987
relu torch.bfloat16 16777216 3.03E-05 3.04E-05 0.166504085842689
relu torch.bfloat16 33554432 4.89E-05 5.06E-05 3.45848238875716
relu torch.bfloat16 67108864 9.32E-05 9.65E-05 3.56122651631445
relu torch.bfloat16 134217728 0.000180805509444326 0.000187998676362137 3.97840029317567
relu torch.bfloat16 268435456 0.000356242332297067 0.000371279485989362 4.22104627356745
relu torch.bfloat16 536870912 0.000708114336399982 0.000736773828975856 4.04729732229083
relu torch.float32 16777216 5.61E-05 5.61E-05 0.0442587268354941
relu torch.float32 33554432 9.33E-05 9.30E-05 -0.259070913799022
relu torch.float32 67108864 0.000181321326332788 0.000181289506144822 -0.0175490597877115
relu torch.float32 134217728 0.000356896334172537 0.000356570177245885 -0.0913870206618981
relu torch.float32 268435456 0.000709421835684528 0.000707465515006334 -0.275762681635911
relu torch.float32 536870912 0.00141372415237129 0.00141036518228551 -0.237597276678471
sigmoid torch.float16 16777216 3.10E-05 3.16E-05 2.10012593866895
sigmoid torch.float16 33554432 4.91E-05 5.23E-05 6.37710600666122
sigmoid torch.float16 67108864 9.30E-05 0.000100057009452333 7.61866144555331
sigmoid torch.float16 134217728 0.000180928347011407 0.000194982004662355 7.76752669390248
sigmoid torch.float16 268435456 0.000355658994521946 0.00038468533117945 8.16128288742412
sigmoid torch.float16 536870912 0.000705982849467546 0.000764021339515845 8.22094900634937
sigmoid torch.bfloat16 16777216 3.08E-05 3.17E-05 2.90965915673149
sigmoid torch.bfloat16 33554432 4.87E-05 5.24E-05 7.63503884668234
sigmoid torch.bfloat16 67108864 9.33E-05 0.000100019678939134 7.21238137428013
sigmoid torch.bfloat16 134217728 0.000180786165098349 0.000194868014659733 7.78922964250206
sigmoid torch.bfloat16 268435456 0.000355564659306159 0.000384909333661199 8.25297835063321
sigmoid torch.bfloat16 536870912 0.000705831005082776 0.000764102345177283 8.2557070566308
sigmoid torch.float32 16777216 4.93E-05 5.65E-05 14.5314136197766
sigmoid torch.float32 33554432 9.32E-05 9.31E-05 -0.120169865610833
sigmoid torch.float32 67108864 0.000181328505277634 0.000180455681402236 -0.481349512069855
sigmoid torch.float32 134217728 0.000357362829769651 0.000356093340087682 -0.35523831137877
sigmoid torch.float32 268435456 0.000708921831877281 0.000707052337626616 -0.263709504574663
sigmoid torch.float32 536870912 0.00141358317341656 0.0014090768333214 -0.318788464654745
tanh torch.float16 16777216 3.03E-05 3.03E-05 -0.0912564658661808
tanh torch.float16 33554432 4.90E-05 5.07E-05 3.46644442974484
tanh torch.float16 67108864 9.30E-05 9.68E-05 3.99871369815531
tanh torch.float16 134217728 0.00018052199933057 0.000188717152923346 4.53969799978138
tanh torch.float16 268435456 0.000355684508879979 0.000373026006855071 4.8755280430115
tanh torch.float16 536870912 0.000706660988119741 0.000740105014604827 4.73268328765002
tanh torch.bfloat16 16777216 2.99E-05 3.03E-05 1.21049563135981
tanh torch.bfloat16 33554432 4.89E-05 5.06E-05 3.48836101041744
tanh torch.bfloat16 67108864 9.28E-05 9.69E-05 4.39944918036626
tanh torch.bfloat16 134217728 0.000180710999605556 0.000189167990659674 4.67984299382829
tanh torch.bfloat16 268435456 0.000356062994493792 0.000372666652159144 4.66312363882606
tanh torch.bfloat16 536870912 0.000707100164921333 0.000740134331863374 4.67178040408393
tanh torch.float32 16777216 5.61E-05 5.64E-05 0.439595755746353
tanh torch.float32 33554432 9.31E-05 9.31E-05 0.00287633090228212
tanh torch.float32 67108864 0.000181465332085888 0.000180895323865116 -0.31411411437098
tanh torch.float32 134217728 0.000356963835656643 0.000356073161431899 -0.249513854283251
tanh torch.float32 268435456 0.000709201170442005 0.00070707315656667 -0.300057862849997
tanh torch.float32 536870912 0.00141367283261692 0.00141030051357423 -0.238550176877922
gelu torch.float16 16777216 2.73E-05 3.17E-05 15.921079070745
gelu torch.float16 33554432 5.06E-05 5.55E-05 9.76345374333098
gelu torch.float16 67108864 9.65E-05 0.000106600326641152 10.4308039074712
gelu torch.float16 134217728 0.000187776672343413 0.000208565829476962 11.0712139447915
gelu torch.float16 268435456 0.000370216167842348 0.000412251994324227 11.3544005187205
gelu torch.float16 536870912 0.000737301345604161 0.000819394170927505 11.1342296895002
gelu torch.bfloat16 16777216 3.02E-05 3.08E-05 1.78405479367653
gelu torch.bfloat16 33554432 5.13E-05 5.69E-05 10.9929393318302
gelu torch.bfloat16 67108864 9.76E-05 0.00010968199543034 12.3420807512356
gelu torch.bfloat16 134217728 0.000189661824454864 0.000214487663470209 13.0895287371091
gelu torch.bfloat16 268435456 0.000374197009174774 0.000423670164309442 13.2211519391275
gelu torch.bfloat16 536870912 0.000743675006863972 0.000842577001700799 13.299088166737
gelu torch.float32 16777216 5.06E-05 5.04E-05 -0.413385894716413
gelu torch.float32 33554432 9.31E-05 9.32E-05 0.134157041722546
gelu torch.float32 67108864 0.000181480175039421 0.000180836669945469 -0.354586992112075
gelu torch.float32 134217728 0.000356874331676712 0.000356305002545317 -0.159532104402047
gelu torch.float32 268435456 0.000708909006789327 0.000706991491218408 -0.270488250615287
gelu torch.float32 536870912 0.00141321367118508 0.00140937082081412 -0.271922813181618
sin torch.float16 16777216 3.04E-05 3.11E-05 2.21834939018859
sin torch.float16 33554432 4.85E-05 5.23E-05 7.72165512511596
sin torch.float16 67108864 9.31E-05 9.98E-05 7.24947099480072
sin torch.float16 134217728 0.000180371008658161 0.000194791161144773 7.99471744039613
sin torch.float16 268435456 0.000355454161763191 0.000384903668115536 8.28503630574026
sin torch.float16 536870912 0.000705183832906187 0.000764360166310022 8.39161799270973
sin torch.bfloat16 16777216 3.11E-05 3.10E-05 -0.257677954940036
sin torch.bfloat16 33554432 4.89E-05 5.24E-05 7.34808420323539
sin torch.bfloat16 67108864 9.26E-05 0.000100248667877167 8.22347488801205
sin torch.bfloat16 134217728 0.000180674154156198 0.00019567032965521 8.30012215584937
sin torch.bfloat16 268435456 0.000355360486234228 0.000386023331278314 8.62865913118873
sin torch.bfloat16 536870912 0.00070483615854755 0.000766805159704139 8.79197248964745
sin torch.float32 16777216 5.67E-05 5.64E-05 -0.441348534920039
sin torch.float32 33554432 9.34E-05 9.30E-05 -0.496458540364117
sin torch.float32 67108864 0.000181706990891447 0.000180556671693921 -0.633062708199702
sin torch.float32 134217728 0.000356894995396336 0.000356046327700218 -0.237791985616354
sin torch.float32 268435456 0.000708777321657787 0.000707602652255446 -0.165731798471427
sin torch.float32 536870912 0.00141263716310884 0.00140912582476934 -0.248566187496451
exp torch.float16 16777216 3.00E-05 3.04E-05 1.40099098901014
exp torch.float16 33554432 4.86E-05 5.03E-05 3.44611943643906
exp torch.float16 67108864 9.37E-05 9.55E-05 1.96412400380129
exp torch.float16 134217728 0.000180913504057874 0.000187193179347863 3.47109262113439
exp torch.float16 268435456 0.00035607748820136 0.000369079003576189 3.65131630210701
exp torch.float16 536870912 0.000707551507124056 0.000732363162872692 3.50669251620789
exp torch.bfloat16 16777216 2.98E-05 3.04E-05 1.74345594341654
exp torch.bfloat16 33554432 4.88E-05 5.04E-05 3.40217856534821
exp torch.bfloat16 67108864 9.32E-05 9.62E-05 3.29219958210226
exp torch.bfloat16 134217728 0.000180999826019009 0.000187239318620414 3.44723679499521
exp torch.bfloat16 268435456 0.000355944503098726 0.000369370992605885 3.77207384585864
exp torch.bfloat16 536870912 0.000707135167128096 0.000733066000975668 3.66702648277075
exp torch.float32 16777216 4.89E-05 5.63E-05 15.1245314346532
exp torch.float32 33554432 9.34E-05 9.31E-05 -0.259945454477446
exp torch.float32 67108864 0.000181152504713585 0.000180474346658836 -0.374357536939058
exp torch.float32 134217728 0.000356771342922002 0.000355627329554409 -0.3206573034212
exp torch.float32 268435456 0.000708404501589636 0.00070713268360123 -0.179532736671163
exp torch.float32 536870912 0.00141283582585553 0.00140944866385932 -0.23974208002295

cc @msaroufim @ptrblck @eqy @manuelcandales @SherlockNoMad @angelayi

@Aidyn-A Aidyn-A added the module: cuda (Related to torch.cuda, and CUDA support in general), topic: not user facing (topic category), and module: core aten (Related to change to the Core ATen opset) labels Jan 27, 2025
@Aidyn-A Aidyn-A requested review from eqy and ngimel January 27, 2025 18:40
@Aidyn-A Aidyn-A self-assigned this Jan 27, 2025
@Aidyn-A Aidyn-A requested a review from syed-ahmed as a code owner January 27, 2025 18:40
@pytorch-bot

pytorch-bot bot commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145746

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 575b250 with merge base 0674ab7:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Skylion007
Collaborator

Did we confirm no regression on A100, per the comments on the previous PR?

@Aidyn-A
Collaborator Author

Aidyn-A commented Jan 27, 2025

Did we confirm no regression on A100, per the comments on the previous PR?

Yes, the regression is due to a compiler (NVCC) bug. Moreover, I discovered that the bug is present on H100 as well. I have omitted vec8/vec16 for 1-byte data on all archs to work around the bug.

@Skylion007
Collaborator

Skylion007 commented Jan 27, 2025

Did we confirm no regression on A100, per the comments on the previous PR?

Yes, the regression is due to a compiler (NVCC) bug. Moreover, I discovered that the bug is present on H100 as well. I have omitted vec8/vec16 for 1-byte data on all archs to work around the bug.

What versions of NVCC are affected? We are potentially planning to drop support for older NVCC versions (11.8/12.4).

@Aidyn-A
Collaborator Author

Aidyn-A commented Jan 27, 2025

What versions of NVCC are affected? We are potentially planning to drop support for older NVCC versions (11.8/12.4).

I have not checked 11.8, but 12.4-12.6 are affected. 12.8 is fine if the nvvm-latest flag is enforced, though I am not certain whether it is generally safe to force nvvm-latest on sm_80-90.
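For reference, the follow-up in #150705 gates this on the toolkit version at compile time. A minimal sketch of that kind of guard, assuming the usual CUDA_VERSION encoding (12.8 is 12080) and a hypothetical helper name; it is not the exact code that landed:

#include <cuda.h>  // defines CUDA_VERSION, e.g. 12080 for CUDA 12.8

// Sketch: only opt in to 16-byte (128-bit) vectorization when the toolkit
// is new enough to avoid the NVCC issue discussed above (assumed threshold).
constexpr int max_vectorized_bytes() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  return 16;
#else
  return 8;
#endif
}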

if constexpr (io_sizes == 1) {
  return 16;
} else {
  return 8;
Collaborator

Is elems_per_thread = 8 better all around than the 4 we mostly used previously?

Collaborator Author

I observed little to no difference. The biggest improvement comes from vec8.

uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
vec_size = std::min<uint16_t>(vec_size, max_vec_size);
if (sizeof(cpp_type) < 2) {
  vec_size = std::min<uint16_t>(vec_size, 4);
Collaborator

Why are you setting the max vec size to 4 here for 1-byte data types? Is it to work around that bug? Can you leave a comment then?

Collaborator Author

Correct, this is a workaround for that bug. I have left a comment that explains it.
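For readers following along, the workaround amounts to something like the following sketch. The variables mirror the snippet above, but pick_vec_size is a hypothetical wrapper and the comment wording is mine, not the exact comment that landed:

#include <algorithm>
#include <cstdint>

template <typename cpp_type>
uint16_t pick_vec_size(uint16_t max_vec_size) {
  // Prefer 16-byte (128-bit) accesses: 16 / sizeof(T) elements per access,
  // clamped to what the tensor iterator allows.
  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
  vec_size = std::min<uint16_t>(vec_size, max_vec_size);
  // Workaround for an NVCC bug: vec8/vec16 for 1-byte types regresses
  // performance (observed with CUDA 12.4-12.6), so cap 1-byte dtypes at vec4.
  if (sizeof(cpp_type) < 2) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  return vec_size;
}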

@Aidyn-A Aidyn-A added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jan 29, 2025
@Aidyn-A
Collaborator Author

Aidyn-A commented Jan 29, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@Aidyn-A
Collaborator Author

Aidyn-A commented Jan 31, 2025

The lint failures are unrelated

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: Lint / lintrunner-noclang / linux-job, Lint / Test tools / linux-job

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Contributor

@malfet malfet left a comment


IMO it should be guarded with __CUDA_ARCH__ >= 90 to avoid compile time increases for older architectures

@ngimel
Collaborator

ngimel commented Feb 25, 2025

I think it helps perf even on H100

@malfet
Contributor

malfet commented Feb 25, 2025

I think it helps perf even on H100

I've made an off-by-one error in my calculations anyway; CUDA_ARCH > 10 is Fermi and above :)

@Aidyn-A
Collaborator Author

Aidyn-A commented Feb 28, 2025

IMO it should be guarded with __CUDA_ARCH__ >= 90 to avoid compile time increases for older architectures

I apologize for the delay, but it seems I will need to do some refactoring to achieve it, because __CUDA_ARCH__ is available only in kernels and device functions.
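To illustrate the constraint: __CUDA_ARCH__ is defined only while compiling device code, so an arch guard like the one suggested has to live inside a __device__/__global__ function. A sketch under that assumption (preferred_vec_bytes is a hypothetical name; 900 corresponds to sm_90):

// Host code (e.g. the launcher picking a kernel instantiation) never sees
// __CUDA_ARCH__, so this check cannot be hoisted out of device code as-is.
__device__ constexpr int preferred_vec_bytes() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  return 16;  // allow 128-bit accesses on sm_90 and newer
#else
  return 8;   // keep older archs on narrower accesses
#endif
}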

atalman added a commit to atalman/pytorch that referenced this pull request Apr 4, 2025
pytorchmergebot pushed a commit that referenced this pull request Apr 8, 2025
By addressing feedback requested at #145746
Pull Request resolved: #150705
Approved by: https://github.com/atalman
pytorchbot pushed a commit that referenced this pull request Apr 8, 2025
By addressing feedback requested at #145746
Pull Request resolved: #150705
Approved by: https://github.com/atalman

(cherry picked from commit 5228986)
atalman pushed a commit that referenced this pull request Apr 8, 2025
[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)

By addressing feedback requested at #145746
Pull Request resolved: #150705
Approved by: https://github.com/atalman

(cherry picked from commit 5228986)

Co-authored-by: Nikita Shulga <nshulga@meta.com>
malfet added a commit that referenced this pull request Apr 9, 2025
timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
pytorchmergebot pushed a commit that referenced this pull request May 2, 2025
Fixes #147376.
As per request: #145746 (review)
This PR stops sm80 and older architectures from using vec8 kernels due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman
pytorchbot pushed a commit that referenced this pull request May 6, 2025
Fixes #147376.
As per request: #145746 (review)
This PR stops sm80 and older architectures from using vec8 kernels due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman

(cherry picked from commit 72337bd)
malfet pushed a commit that referenced this pull request May 7, 2025
[ATen][CUDA] Optimize 128 bit vectorization (#148320)

Fixes #147376.
As per request: #145746 (review)
This PR stops sm80 and older architectures from using vec8 kernels due to long compilation times and large binary size.

Pull Request resolved: #148320
Approved by: https://github.com/eqy, https://github.com/malfet, https://github.com/atalman

(cherry picked from commit 72337bd)

Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com>

Labels

ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: core aten (Related to change to the Core ATen opset), module: cuda (Related to torch.cuda, and CUDA support in general), open source, Reverted, topic: not user facing (topic category)


8 participants