Faster Faster BatchSampler #137423
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137423
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit c16622a with merge base a063a82.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from cc3ee08 to b14772e.
Thanks for the PR, super curious how the results stand for even smaller batch sizes. (We probably should have a uniform script to benchmark and check correctness; I can maybe look into that!)
torch/utils/data/sampler.py (Outdated)
It's unclear to me that this is safe in all scenarios although on the surface it seems OK. What benefit does this give us though?
It's my understanding that it creates the required references first and then passes them all to zip at once - which is more vectorised? I found this SO post quite informative on it.
I don't have the same level of understanding of the PyTorch codebase that either of you have, but I'm happy to work with you to test more things if that would be helpful. I found that the original list comprehension method was taking time when looping over large batches of small data while benchmarking some code with a colleague (@SusannaGreen).
I should mention that the zip call now includes a strict=True argument, which forces zip to error if the lengths don't match up (i.e. the last batch is smaller than batch_size).
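For readers following along, here is a minimal sketch of the grouping idiom being described (illustrative only, not necessarily the exact PR diff): passing `batch_size` references to one shared iterator makes `zip` consume `batch_size` consecutive indices per output tuple.

```python
def batches_droplast(sampler, batch_size):
    # All zip arguments are the *same* iterator, so each output tuple
    # pulls batch_size consecutive indices from the sampler.
    it = iter(sampler)
    for group in zip(*[it] * batch_size):
        yield list(group)

# 10 indices, batch_size=4: the 2-element tail is silently dropped.
print(list(batches_droplast(range(10), 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With `strict=True` (Python 3.10+) the same call would instead raise `ValueError` when the tail is shorter than `batch_size`, which is the behaviour mentioned above.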
@andrewkho Should be marginally faster by dropping down into the C implementation for these functions as much as possible
I have some (quite janky) batch verification code I used with the aforementioned benchmarking code. I've put it in a gist here. It's janky because to test the …
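The gist itself isn't reproduced here, but a rough sketch of what a batch-verification check along those lines might look like (purely illustrative; `reference_batches` and the parameter grid below are my own, not the gist's):

```python
from torch.utils.data import BatchSampler, SequentialSampler

def reference_batches(indices, batch_size, drop_last):
    # Plain list-comprehension batching to compare against.
    batches = [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches

def check(n=1000):
    indices = list(range(n))
    for batch_size in (1, 4, 8, 64, 640):
        for drop_last in (True, False):
            got = list(BatchSampler(SequentialSampler(indices), batch_size, drop_last))
            want = reference_batches(indices, batch_size, drop_last)
            assert got == want, (batch_size, drop_last)

check()
```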
Force-pushed from b14772e to 4f3e092.
Fixed for lintrunner requirements. New benchmarks (Apple M1 Max): … Now it's slower for …
I mean you can ignore the linting with special comments if it affects speed...
I buy that itertools.islice is faster, for sure. It's a shame we can't use itertools.batched, both because of our minimum Python version and because it returns tuples instead of lists lol.
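For comparison, a small sketch of an `itertools.islice`-based loop for the case where the last (possibly shorter) batch is kept (again illustrative, not the exact diff):

```python
import itertools

def batches_keep_last(sampler, batch_size):
    it = iter(sampler)
    # islice pulls up to batch_size indices per call (in C); an empty
    # slice means the sampler is exhausted and the loop ends.
    batch = list(itertools.islice(it, batch_size))
    while batch:
        yield batch
        batch = list(itertools.islice(it, batch_size))

print(list(batches_keep_last(range(10), 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

`itertools.batched` (Python 3.12+) would collapse this into a single call, but as noted above it yields tuples rather than lists.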
This looks good to me, and the speed-up makes sense to me based on how itertools.islice is implemented and how zip is implemented.
torch/utils/data/sampler.py (Outdated)
Ugh why is this converted to a list twice?
I think it was for linting 😬
torch/utils/data/sampler.py (Outdated)
We can just call it something else to remove the redefinition typing error? No reason to reuse the name if we don't need to. It's complaining because it was previously a Tuple[int, ...] in the other branch of the if/else and now it's an Iterator[int] cast to a List.
That makes a lot of sense. Thanks!
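To make the suggestion concrete, here is a hedged sketch of what the rename might look like (the helper and variable names are illustrative, not necessarily those in the final diff):

```python
import itertools
from typing import Iterator, List

def _batched_indices(sampler, batch_size: int, drop_last: bool) -> Iterator[List[int]]:
    it = iter(sampler)
    if drop_last:
        # A distinct name for the tuple produced by zip avoids mypy's
        # "redefinition with a different type" complaint.
        for batch_droplast in zip(*[it] * batch_size):
            yield list(batch_droplast)
    else:
        # `batch` is consistently a List[int] in this branch.
        batch = list(itertools.islice(it, batch_size))
        while batch:
            yield batch
            batch = list(itertools.islice(it, batch_size))
```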
Minus those nits, it looks good to me.
Force-pushed from 4f3e092 to 1c30f7f.
Annoyingly, for the … case, the zip part is slower now. I just need a bit of time to fix that, then I think I'm good.
Force-pushed from 1c30f7f to d8aed7b.
Force-pushed from d8aed7b to c16622a.
OK, I think I've tidied it up now. Here are the benchmarks; it seems to pass lintrunner fine :-)
AMD Ryzen Threadripper PRO 3995WX: …
Apple M1 Max: …
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hmmm. A bunch of the trunk tests failed after merging (same with some other PRs unrelated to this one). I think it's to do with changes in this PR: #136519?
## Motivation

Many PRs optimizing samplers (e.g. #147706, #137423) are leveraging an ad hoc script for benchmarking samplers. The script and outputs are often copied over in PRs. We want to begin centralizing benchmarks for torch.utils.data components.

## What?

* This PR adds a new sub-folder in `benchmarks` for `data`. This is aimed to cover benchmarking scripts for torch.utils.data components like dataloader and sampler.
* Specifically, this PR includes a simple script to time samplers. This is often "copy-pasted" in PRs optimizing samplers. Having it in a centralized location should prevent that, and allow a common standard.

## Output

```
Benchmark Results:
+--------------+-------------+----------------+-----------+-----------+
| Batch Size   | Drop Last   | Original (s)   | New (s)   | Speedup   |
+==============+=============+================+===========+===========+
| 4            | True        | 0.004          | 0.0088    | -119.62%  |
+--------------+-------------+----------------+-----------+-----------+
| 4            | False       | 0.0083         | 0.009     | -9.23%    |
+--------------+-------------+----------------+-----------+-----------+
| 8            | True        | 0.003          | 0.0074    | -147.64%  |
+--------------+-------------+----------------+-----------+-----------+
| 8            | False       | 0.0054         | 0.0075    | -38.72%   |
+--------------+-------------+----------------+-----------+-----------+
| 64           | True        | 0.0021         | 0.0056    | -161.92%  |
+--------------+-------------+----------------+-----------+-----------+
| 64           | False       | 0.0029         | 0.0055    | -92.50%   |
+--------------+-------------+----------------+-----------+-----------+
| 640          | True        | 0.002          | 0.0055    | -168.75%  |
+--------------+-------------+----------------+-----------+-----------+
| 640          | False       | 0.0024         | 0.0062    | -161.35%  |
+--------------+-------------+----------------+-----------+-----------+
| 6400         | True        | 0.0021         | 0.0055    | -160.13%  |
+--------------+-------------+----------------+-----------+-----------+
| 6400         | False       | 0.0021         | 0.0068    | -215.46%  |
+--------------+-------------+----------------+-----------+-----------+
| 64000        | True        | 0.0042         | 0.0065    | -55.29%   |
+--------------+-------------+----------------+-----------+-----------+
| 64000        | False       | 0.0029         | 0.0077    | -169.56%  |
+--------------+-------------+----------------+-----------+-----------+
```

Pull Request resolved: #156974
Approved by: https://github.com/ramanishsingh
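For a sense of what such a centralized timing script does, here is a hedged sketch of the kind of loop it might contain (the actual `benchmarks/data` file may differ; `time_batch_sampler` is an illustrative name): it times a full pass over a `BatchSampler` across batch sizes and `drop_last` settings.

```python
import timeit
from torch.utils.data import BatchSampler, SequentialSampler

def time_batch_sampler(n=100_000, batch_size=64, drop_last=True, repeats=5):
    # Time one full pass over all batches; keep the best of several repeats.
    sampler = BatchSampler(SequentialSampler(range(n)), batch_size, drop_last)
    return min(timeit.repeat(lambda: list(sampler), number=1, repeat=repeats))

if __name__ == "__main__":
    for batch_size in (4, 8, 64, 640, 6400, 64000):
        for drop_last in (True, False):
            t = time_batch_sampler(batch_size=batch_size, drop_last=drop_last)
            print(f"batch_size={batch_size:<6} drop_last={str(drop_last):<5} {t:.4f}s")
```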
Builds upon #76951.
Benchmarking code is the same as in #76950.
AMD Ryzen Threadripper PRO 3995WX: …

When `drop_last == True`, it uses `zip` to speed things up. When `drop_last == False`, it uses `itertools` to speed things up. `itertools` was the fastest way I could find that deals with the last batch if it is smaller than `batch_size`. I have a pure Python method too, but it is slower when `batch_size` is 4 or 8, so I have committed the `itertools` version for now.

Happy to chat further about this change :-) I understand you may not want to introduce the `itertools` package into `sampler.py`.