Faster Faster BatchSampler by xangma · Pull Request #137423 · pytorch/pytorch · GitHub

Conversation

@xangma
Contributor

@xangma xangma commented Oct 7, 2024

Builds upon #76951.

Benchmarking code is the same as in #76950.

AMD Ryzen Threadripper PRO 3995WX:

  batch_size  drop_last      origin     new  speedup
------------  -----------  --------  ------  ---------
           4  True           0.94    0.5706  64.74%
           4  False          0.9745  0.9468  2.93%
           8  True           0.7423  0.3715  99.82%
           8  False          0.7974  0.5666  40.73%
          64  True           0.5394  0.2085  158.76%
          64  False          0.6083  0.2697  125.51%
         640  True           0.5448  0.1985  174.41%
         640  False          0.7085  0.2308  206.91%
        6400  True           0.5554  0.2028  173.88%
        6400  False          0.7711  0.2109  265.60%
       64000  True           0.556   0.2091  165.82%
       64000  False          0.7803  0.2078  275.58%

When drop_last == True, it uses zip to speed things up.
When drop_last == False, it uses itertools.islice to speed things up.

itertools was the fastest approach I could find that handles the last batch when it is smaller than batch_size. I have a pure-Python method too, but it is slower when batch_size is 4 or 8, so I have committed the itertools version for now.
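
Roughly, the two idioms look like this (a minimal sketch of the approach, not the exact code in the diff):

```python
import itertools
from typing import Iterator, List

def _batches_drop_last(sampler, batch_size: int) -> Iterator[List[int]]:
    # zip over batch_size references to the *same* iterator: each zip
    # step pulls batch_size consecutive indices, and any short final
    # batch is silently dropped (drop_last=True semantics).
    sampler_iter = iter(sampler)
    for batch in zip(*[sampler_iter] * batch_size):
        yield list(batch)

def _batches_keep_last(sampler, batch_size: int) -> Iterator[List[int]]:
    # islice pulls up to batch_size items per call in C, keeping the
    # short final batch (drop_last=False semantics).
    sampler_iter = iter(sampler)
    while batch := list(itertools.islice(sampler_iter, batch_size)):
        yield batch
```

For example, `list(_batches_keep_last(range(10), 4))` gives `[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]`, while the drop_last variant stops after `[4, 5, 6, 7]`.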

Happy to chat further about this change :-) I understand you may not want to introduce the itertools package into sampler.py.

@pytorch-bot

pytorch-bot bot commented Oct 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137423

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c16622a with merge base a063a82:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: dataloader release notes category label Oct 7, 2024
@colesbury colesbury added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 7, 2024
@xangma xangma force-pushed the batchsampler-update branch 2 times, most recently from cc3ee08 to b14772e on October 8, 2024 at 08:21
@divyanshk
Contributor

Thanks for the PR, super curious how the results stand for even smaller batch sizes.

(We probably should have a uniform script to benchmark and check correctness; I can maybe look into that!)

Comment on lines +339 to +340
Contributor

It's unclear to me that this is safe in all scenarios, although on the surface it seems OK. What benefit does this give us, though?

Contributor Author

@xangma xangma Oct 9, 2024

It's my understanding that it creates the required references first and then passes them all to zip at once, which is more vectorised? I found this SO post quite informative on it.

I don't have the level of understanding of the PyTorch codebase that either of you have, but I'm happy to work with you to test more things if that would be helpful. I found that the original list-comprehension method was taking time when looping over large batches of small data while benchmarking some code with a colleague (@SusannaGreen).

I should mention that the zip call now includes a strict=True argument, which will force zip to raise an error if the lengths don't match up (i.e. the last batch is smaller than batch_size).
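
For illustration, here's how strict=True behaves on its own (requires Python 3.10+; a standalone demo, not the PR code):

```python
it = iter(range(10))
batches = zip(*[it] * 4, strict=True)
print(next(batches))  # (0, 1, 2, 3)
print(next(batches))  # (4, 5, 6, 7)
# The final call raises ValueError because the underlying iterator runs
# out mid-batch (only 8 and 9 remain), instead of silently dropping the
# short batch as plain zip would.
print(next(batches))
```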

Collaborator

@andrewkho Should be marginally faster by dropping down into the C implementations of these functions as much as possible.

@xangma
Contributor Author

xangma commented Oct 9, 2024

> Thanks for the PR, super curious how the results stand for even smaller batch sizes.
>
> (We probably should have a uniform script to benchmark and check correctness; I can maybe look into that!)

I have some (quite janky) batch-verification code I used with the aforementioned benchmarking code. I've put it in a gist here. It's janky because, to test the drop_last == False functionality for the itertools change, you need to manually set DATA_SIZE to be smaller than VERIFICATION_BATCHES * BATCH_SIZE. I didn't have much time to spend on it, sorry 😅 There's an example output as a comment on the gist :-)

@xangma xangma force-pushed the batchsampler-update branch from b14772e to 4f3e092 on October 9, 2024 at 12:04
@xangma
Contributor Author

xangma commented Oct 9, 2024

Fixed for lintrunner requirements.

New benchmarks:
AMD Ryzen Threadripper PRO 3995WX:

  batch_size  drop_last      origin     new  speedup
------------  -----------  --------  ------  ---------
           4  True           0.8761  0.5565  57.44%
           4  False          0.978   1.1248  -13.05%
           8  True           0.7094  0.3651  94.31%
           8  False          0.7892  0.6649  18.69%
          64  True           0.535   0.2006  166.77%
          64  False          0.5954  0.3011  97.77%
         640  True           0.5205  0.1853  180.88%
         640  False          0.6293  0.2435  158.43%
        6400  True           0.5219  0.1898  175.03%
        6400  False          0.6602  0.2167  204.61%
       64000  True           0.5214  0.197   164.66%
       64000  False          0.6677  0.2679  149.26%

Apple M1 Max:

  batch_size  drop_last      origin     new  speedup
------------  -----------  --------  ------  ---------
           4  True           0.6173  0.4025  53.38%
           4  False          0.8404  0.8706  -3.46%
           8  True           0.5214  0.2833  84.04%
           8  False          0.6657  0.5511  20.78%
          64  True           0.358   0.1763  103.01%
          64  False          0.4884  0.2625  86.10%
         640  True           0.3749  0.1581  137.04%
         640  False          0.5404  0.2072  160.80%
        6400  True           0.3879  0.1683  130.45%
        6400  False          0.5752  0.1923  199.05%
       64000  True           0.3884  0.1767  119.83%
       64000  False          0.5906  0.1938  204.68%

Now it's slower for batch_size == 4. Silly linting :'(

@Skylion007
Collaborator

> Fixed for lintrunner requirements. New benchmarks: […]
>
> Now it's slower for batch_size == 4. Silly linting :'(

I mean you can ignore the linting with special comments if it affects speed...

@Skylion007
Collaborator

I buy that itertools.islice is faster, for sure. It's a shame we can't use itertools.batched, both because of our minimum Python version and because it returns tuples instead of lists lol.
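
For reference, itertools.batched (Python 3.12+) does exactly this shape of work but yields tuples:

```python
# Requires Python 3.12+; not usable here given the minimum supported version.
from itertools import batched

print(list(batched(range(10), 4)))  # [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9)]
```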

Collaborator

@Skylion007 Skylion007 left a comment

This looks good to me, and the speedup makes sense to me based on how itertools.islice and zip are implemented.

Collaborator

Ugh, why is this converted to a list twice?

Contributor Author

I think it was for linting 😬

Collaborator

@Skylion007 Skylion007 Oct 10, 2024

We can just call it something else to remove the redef typing error? No reason to reuse the name if we don't need to. It's complaining because it was previously a Tuple[int, ...] in the other branch of the if statement and now it's an Iterator[int] cast to a List.
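
Roughly the shape of the fix (a hypothetical reduction with illustrative names, not the actual diff):

```python
from itertools import islice
from typing import Iterator

def batches(sampler_iter: Iterator[int], batch_size: int, drop_last: bool):
    if drop_last:
        for batch in zip(*[sampler_iter] * batch_size):
            yield list(batch)  # mypy infers `batch` as Tuple[int, ...]
    else:
        # Reusing the name `batch` here with a different type would trip
        # mypy's no-redefinition check; a distinct name keeps both
        # branches precisely typed.
        while batch_list := list(islice(sampler_iter, batch_size)):
            yield batch_list
```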

Contributor Author

That makes a lot of sense. Thanks!

@Skylion007
Collaborator

Minus those nits, it looks good to me.

@xangma xangma force-pushed the batchsampler-update branch from 4f3e092 to 1c30f7f on October 12, 2024 at 10:09
@xangma
Contributor Author

xangma commented Oct 12, 2024

Annoyingly, for the itertools part, if you don't convert to a list before yielding, it's seemingly really slow.

The zip part is slower now; I just need a bit of time to fix that, then I think I'm good.
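
My current reading of why the list() call matters (a sketch of my understanding, not verified beyond the benchmarks):

```python
import itertools

def batches(sampler, batch_size):
    sampler_iter = iter(sampler)
    while True:
        # Materializing with list() here is load-bearing: each islice
        # shares sampler_iter, so yielding the lazy islice object
        # instead would shift batch boundaries whenever a batch isn't
        # fully consumed, and there is no way to detect exhaustion
        # without consuming the slice anyway.
        batch = list(itertools.islice(sampler_iter, batch_size))
        if not batch:
            return
        yield batch
```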

@xangma xangma force-pushed the batchsampler-update branch from 1c30f7f to d8aed7b on October 12, 2024 at 20:33
@xangma xangma force-pushed the batchsampler-update branch from d8aed7b to c16622a on October 12, 2024 at 21:05
@xangma
Contributor Author

xangma commented Oct 12, 2024

Ok. I think I've tidied it up now.

Here are the benchmarks. Seems to pass lintrunner fine :-)

AMD Ryzen Threadripper PRO 3995WX:

  batch_size  drop_last      origin     new  speedup
------------  -----------  --------  ------  ---------
           4  True           0.8708  0.4741  83.69%
           4  False          0.9435  0.8576  10.01%
           8  True           0.7023  0.3281  114.02%
           8  False          0.7701  0.507   51.88%
          64  True           0.5286  0.1925  174.60%
          64  False          0.5895  0.2584  128.11%
         640  True           0.5206  0.1856  180.52%
         640  False          0.629   0.2208  184.92%
        6400  True           0.5194  0.1905  172.68%
        6400  False          0.6601  0.1933  241.47%
       64000  True           0.519   0.1965  164.04%
       64000  False          0.6632  0.1883  252.13%

Apple M1 Max:

  batch_size  drop_last      origin     new  speedup
------------  -----------  --------  ------  ---------
           4  True           0.6297  0.3192  97.27%
           4  False          0.849   0.6475  31.10%
           8  True           0.4987  0.2378  109.75%
           8  False          0.6554  0.4149  57.96%
          64  True           0.3536  0.1767  100.15%
          64  False          0.4896  0.2304  112.52%
         640  True           0.3723  0.1552  139.86%
         640  False          0.5375  0.1879  186.06%
        6400  True           0.3797  0.169   124.62%
        6400  False          0.5773  0.1691  241.34%
       64000  True           0.3815  0.18    111.93%
       64000  False          0.5846  0.1674  249.19%

@cyyever
Collaborator

cyyever commented Oct 13, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 13, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@xangma
Contributor Author

xangma commented Oct 15, 2024

Hmmm. A bunch of the trunk tests failed after merging (same with some other PRs unrelated to this one). I think it's to do with the changes in this PR: #136519?

@Skylion007 @divyanshk

pytorchmergebot pushed a commit that referenced this pull request Jun 27, 2025
## Motivation
Many PRs optimizing samplers (e.g. #147706, #137423) are leveraging an ad-hoc script for benchmarking samplers. The script and outputs are often copied over in PRs. We want to begin centralizing benchmarks for torch.utils.data components.

## What ?
* This PR adds a new sub-folder in `benchmarks` for `data`. It is aimed at covering benchmarking scripts for torch.utils.data components like the dataloader and samplers.
* Specifically, this PR includes a simple script to time samplers. This is often copy-pasted in PRs optimizing samplers. Having it in a centralized location should prevent that and allow a common standard (see the sketch below).
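
A rough sketch of the kind of harness this adds (illustrative only; the function and parameter names here are hypothetical, not the script's actual interface):

```python
import timeit

from torch.utils.data import BatchSampler, SequentialSampler

def time_batch_sampler(data_size=100_000, batch_size=64, drop_last=True, repeat=10):
    sampler = BatchSampler(SequentialSampler(range(data_size)), batch_size, drop_last)
    # Drain the sampler completely each run so the whole __iter__ path is timed.
    return min(timeit.repeat(lambda: list(sampler), number=1, repeat=repeat))

print(f"batch_size=64 drop_last=True: {time_batch_sampler():.4f}s")
```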

## Output
```
Benchmark Results:
+--------------+-------------+----------------+-----------+-----------+
|   Batch Size | Drop Last   |   Original (s) |   New (s) | Speedup   |
+==============+=============+================+===========+===========+
|            4 | True        |         0.004  |    0.0088 | -119.62%  |
+--------------+-------------+----------------+-----------+-----------+
|            4 | False       |         0.0083 |    0.009  | -9.23%    |
+--------------+-------------+----------------+-----------+-----------+
|            8 | True        |         0.003  |    0.0074 | -147.64%  |
+--------------+-------------+----------------+-----------+-----------+
|            8 | False       |         0.0054 |    0.0075 | -38.72%   |
+--------------+-------------+----------------+-----------+-----------+
|           64 | True        |         0.0021 |    0.0056 | -161.92%  |
+--------------+-------------+----------------+-----------+-----------+
|           64 | False       |         0.0029 |    0.0055 | -92.50%   |
+--------------+-------------+----------------+-----------+-----------+
|          640 | True        |         0.002  |    0.0055 | -168.75%  |
+--------------+-------------+----------------+-----------+-----------+
|          640 | False       |         0.0024 |    0.0062 | -161.35%  |
+--------------+-------------+----------------+-----------+-----------+
|         6400 | True        |         0.0021 |    0.0055 | -160.13%  |
+--------------+-------------+----------------+-----------+-----------+
|         6400 | False       |         0.0021 |    0.0068 | -215.46%  |
+--------------+-------------+----------------+-----------+-----------+
|        64000 | True        |         0.0042 |    0.0065 | -55.29%   |
+--------------+-------------+----------------+-----------+-----------+
|        64000 | False       |         0.0029 |    0.0077 | -169.56%  |
+--------------+-------------+----------------+-----------+-----------+
```
Pull Request resolved: #156974
Approved by: https://github.com/ramanishsingh

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
open source
release notes: dataloader (release notes category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


8 participants