[MPS] Add searchsorted op by qqaatw · Pull Request #112829 · pytorch/pytorch · GitHub


qqaatw
Collaborator

@qqaatw qqaatw commented Nov 3, 2023

Stack from ghstack (oldest at bottom):

The Metal kernels closely follow `Bucketization.cu`.
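For readers unfamiliar with the reference implementation, the per-element work in a bucketization kernel is a binary search over the (optionally indirectly sorted) boundary sequence. The following is a minimal pure-Python sketch of that idea, not the kernel code from this PR; the function name `search_sorted_1d` and its parameters are hypothetical and chosen only for illustration.

```python
from typing import Optional, Sequence


def search_sorted_1d(
    data: Sequence[float],
    val: float,
    right: bool = False,
    sorter: Optional[Sequence[int]] = None,
) -> int:
    """Return the insertion index of `val` in a sorted view of `data`.

    Hypothetical helper for illustration only. If `sorter` is given, `data`
    may be unsorted and `sorter[i]` is the index in `data` of the i-th
    smallest element (an argsort of `data`).
    """
    start, end = 0, len(data)
    while start < end:
        mid = start + ((end - start) >> 1)
        # Indirect lookup through the sorter when one is supplied.
        idx = sorter[mid] if sorter is not None else mid
        mid_val = data[idx]
        # Lower bound: first position with element >= val.
        # Upper bound (right=True): first position with element > val.
        go_right = (mid_val <= val) if right else (mid_val < val)
        if go_right:
            start = mid + 1
        else:
            end = mid
    return start


print(search_sorted_1d([1, 3, 5, 7], 5))                          # lower bound
print(search_sorted_1d([1, 3, 5, 7], 5, right=True))              # upper bound
print(search_sorted_1d([5, 1, 7, 3], 5, sorter=[1, 3, 0, 2]))     # via sorter
```

A GPU kernel would typically run one such search per output element in parallel; the sketch only shows the search itself.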

Benchmark:
```
[----------------------------- searchsorted ----------------------------]
                                                         |  cpu   |  mps
1 threads: --------------------------------------------------------------
      Batch size: 8; In features: 64; Sorter: True       |    44  |   530
      Batch size: 8; In features: 64; Sorter: False      |    31  |    12
      Batch size: 8; In features: 256; Sorter: True      |   131  |   520
      Batch size: 8; In features: 256; Sorter: False     |   107  |    12
      Batch size: 8; In features: 1024; Sorter: True     |   499  |   590
      Batch size: 8; In features: 1024; Sorter: False    |   398  |    12
      Batch size: 16; In features: 64; Sorter: True      |    71  |   540
      Batch size: 16; In features: 64; Sorter: False     |    57  |    12
      Batch size: 16; In features: 256; Sorter: True     |   242  |   610
      Batch size: 16; In features: 256; Sorter: False    |   200  |    12
      Batch size: 16; In features: 1024; Sorter: True    |   999  |   720
      Batch size: 16; In features: 1024; Sorter: False   |   842  |    12
      Batch size: 32; In features: 64; Sorter: True      |   124  |   509
      Batch size: 32; In features: 64; Sorter: False     |   103  |    12
      Batch size: 32; In features: 256; Sorter: True     |   477  |   650
      Batch size: 32; In features: 256; Sorter: False    |   407  |    12
      Batch size: 32; In features: 1024; Sorter: True    |  1940  |   833
      Batch size: 32; In features: 1024; Sorter: False   |  1710  |    12
      Batch size: 64; In features: 64; Sorter: True      |   231  |   590
      Batch size: 64; In features: 64; Sorter: False     |   194  |    12
      Batch size: 64; In features: 256; Sorter: True     |   937  |   710
      Batch size: 64; In features: 256; Sorter: False    |   800  |    13
      Batch size: 64; In features: 1024; Sorter: True    |  3980  |  1290
      Batch size: 64; In features: 1024; Sorter: False   |  3330  |    12
      Batch size: 128; In features: 64; Sorter: True     |   448  |   650
      Batch size: 128; In features: 64; Sorter: False    |   390  |    13
      Batch size: 128; In features: 256; Sorter: True    |  1830  |   850
      Batch size: 128; In features: 256; Sorter: False   |  1590  |    12
      Batch size: 128; In features: 1024; Sorter: True   |  7790  |  2850
      Batch size: 128; In features: 1024; Sorter: False  |  6670  |    13
```
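
The script that produced this table is not included in the PR. As a rough sketch of how a comparison with the same row labels could be set up using `torch.utils.benchmark`, the code below assumes that "Batch size" and "In features" are the two dimensions of both the boundary tensor and the values tensor, and that "Sorter: True" means passing an unsorted sequence together with its `argsort`. The helper names (`benchmark_configs`, `sub_label`, `run_benchmark`) are hypothetical, and the timings it produces need not match the numbers above.

```python
import itertools

BATCH_SIZES = (8, 16, 32, 64, 128)
IN_FEATURES = (64, 256, 1024)


def benchmark_configs():
    """All (batch_size, in_features, use_sorter) combinations, in table order."""
    return [
        (b, n, s)
        for b, n in itertools.product(BATCH_SIZES, IN_FEATURES)
        for s in (True, False)
    ]


def sub_label(batch_size, in_features, use_sorter):
    """Row label in the same style as the table."""
    return f"Batch size: {batch_size}; In features: {in_features}; Sorter: {use_sorter}"


def run_benchmark(min_run_time=0.2):
    # Imported lazily so the helpers above can be used without PyTorch installed.
    import torch
    from torch.utils import benchmark

    devices = ["cpu"] + (["mps"] if torch.backends.mps.is_available() else [])

    def call(seq, values, sorter, device):
        out = torch.searchsorted(seq, values, sorter=sorter)
        if device == "mps":
            # Wait for queued GPU work so the timing covers the kernel itself.
            torch.mps.synchronize()
        return out

    results = []
    for b, n, use_sorter in benchmark_configs():
        for device in devices:
            unsorted = torch.rand(b, n, device=device)
            values = torch.rand(b, n, device=device)
            if use_sorter:
                # Unsorted boundaries plus the indices that would sort them.
                seq, sorter = unsorted, torch.argsort(unsorted, dim=-1)
            else:
                seq, sorter = torch.sort(unsorted, dim=-1).values, None
            timer = benchmark.Timer(
                stmt="call(seq, values, sorter, device)",
                globals={"call": call, "seq": seq, "values": values,
                         "sorter": sorter, "device": device},
                label="searchsorted",
                sub_label=sub_label(b, n, use_sorter),
                description=device,
            )
            results.append(timer.blocked_autorange(min_run_time=min_run_time))
    benchmark.Compare(results).print()
```

Calling `run_benchmark()` on a machine with an MPS device prints a `cpu` and an `mps` column; without one, only the `cpu` column appears.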

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Nov 3, 2023
@pytorch-bot

pytorch-bot bot commented Nov 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112829

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 96a4912 with merge base 3a284da (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@qqaatw
Collaborator Author

qqaatw commented Nov 6, 2023

Hi @albanD @malfet, can you take a look at this stack please?

Contributor

@malfet malfet left a comment


Overall LGTM, but please add a description (saying that it implements the operator as a Metal kernel, closely following Bucketization.cu).
Also, it would be good to add some perf numbers (to show that it's faster than CPU for large enough tensors).

pytorchmergebot pushed a commit that referenced this pull request Nov 7, 2023
@facebook-github-bot facebook-github-bot deleted the gh/qqaatw/26/head branch November 11, 2023 15:24
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Pull Request resolved: pytorch#112829
Approved by: https://github.com/malfet
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged open source release notes: mps Release notes category
