[foreach] Fix 0-size handling for real for real by janeyx99 · Pull Request #109402 · pytorch/pytorch · GitHub

Conversation

@janeyx99
Contributor

@janeyx99 janeyx99 commented Sep 15, 2023

@crcrpar's last attempt to fix the 0-size problem unfortunately did not pass all cases. See my comment in #100701. When we have a tail tensor of size 0, the old code would mess with the chunk logic to check the previous tensor's length. This is flawed because:

  1. if the previous tensor was also 0-sized (so a tensor list like [tensor, tensor, tensor, ..., 0-sized tensor, 0-sized tensor]), chunks would still be 0 and the nested for loop would be skipped.
  2. the nested for loop produces side effects on tensorListMeta that shouldn't be there! This can mess up the compute in unexpected ways that I haven't fully reasoned through.

We noticed that the problem had not been fixed due to an internal report. This PR solves the issue by:

  • removing the finagling of chunks when the tail tensor is 0-sized
  • adding a surefire way for the kernel to be launched when the last tensor is 0-sized AND there's content in the metadata, signifying there is still work to compute.
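To make the failure concrete, the batching loop can be modeled in a few lines of plain Python. This is a hypothetical, heavily simplified sketch of multi_tensor_apply's grouping logic, not the real CUDA code; `batch_for_launch`, `chunk_size`, and `max_meta` are illustrative names:

```python
def batch_for_launch(sizes, chunk_size=4, max_meta=3):
    """Model of the batching loop: returns the groups of tensor indices
    each kernel launch would cover, given per-tensor element counts."""
    launches = []
    meta = []  # tensor indices accumulated into the launch metadata
    for i, n in enumerate(sizes):
        meta.append(i)
        chunks = -(-n // chunk_size)  # ceil division; 0 when n == 0
        for c in range(chunks):
            tensors_full = len(meta) == max_meta
            last = i == len(sizes) - 1 and c == chunks - 1
            if tensors_full or last:
                launches.append(list(meta))
                # re-add the current tensor if it still has chunks left
                meta = [i] if c + 1 < chunks else []
    # The fix: when the tail tensor is 0-sized, the inner loop above never
    # runs for it, so any accumulated metadata must still be launched here.
    if meta and any(sizes[j] > 0 for j in meta):
        launches.append(list(meta))
    return launches
```

With a 0-sized tail (e.g. sizes `[8, 8, 0]`), the inner loop never executes for the last tensor, so without the post-loop launch the metadata accumulated for the first two tensors would be silently dropped.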

Test plan

As I went through the code, I also added some comments explaining what's up and modified our tensor inputs to ensure that this case is tested in the test_parity test in test_foreach.py. Yes, I do realize there is quite a bit of duplication and that this file could be due for a refactor. That said, the primary goal of this PR is to fix the pretty egregious bug and refactoring can be a followup.

cc @awgu @crcrpar

Stack from ghstack (oldest at bottom):

@pytorch-bot

pytorch-bot bot commented Sep 15, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109402

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 26a342a with merge base a565f1b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: foreach_frontend release notes category label Sep 15, 2023
@janeyx99 janeyx99 added the topic: bug fixes topic category label Sep 15, 2023
janeyx99 added a commit that referenced this pull request Sep 15, 2023
ghstack-source-id: 701970c
Pull Request resolved: #109402
@crcrpar
Collaborator

crcrpar commented Sep 16, 2023

I vaguely remember @ngimel suggesting it should be possible to filter out zero-size tensors and pass only non-zero-size tensors to the multi_tensor_apply kernel. I failed to do so at the time, but would it be worth a try now?

@janeyx99
Contributor Author

haha i remember us saying we should refactor. i've taken a look and the refactoring wouldn't get more efficient than now, and could get a bit more complicated due to the fact that blocks OR tensors could fill up. could be worth a shot rewriting with filtering, but running through the tensors once seems better than having two passes.

@janeyx99
Contributor Author

janeyx99 commented Sep 16, 2023

actually we may be able to do so as part of the check-fast-path API... that may be a lot easier... especially because currently this PR still wouldn't fix foreach_norm.

it looks like I'd be changing the signature of the check_fast_path API to take std::vectors instead of ArrayRefs so we could easily drop size-0 tensors. I will explore the viability of this later, but here's the plan:

  1. if just doing a filtering is simplest and works, we can just go with that on Monday
  2. if the filtering is nontrivial and will take longer than Monday, I will land this change with XFAILs for things that should work but never worked to unblock people + then keep iterating on idea 1.
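For reference, the filtering idea in option 1 could look something like the following in spirit. This is a hypothetical Python sketch; the real change would live in the C++ fast-path check (std::vector instead of ArrayRef), and `drop_empty` and its parameters are made-up names:

```python
def drop_empty(tensor_lists, numels):
    """Keep only the positions whose tensors are non-empty, so the
    batching kernel never sees 0-sized entries.

    `tensor_lists` holds parallel lists (e.g. self, other, out);
    `numels[i]` is the element count shared by position i across all
    lists, since foreach ops require matching shapes per position."""
    keep = [i for i, n in enumerate(numels) if n > 0]
    return [[lst[i] for i in keep] for lst in tensor_lists]
```

The tradeoff noted above applies: this is a second pass over the tensors before the single batching pass, in exchange for never having to special-case 0-sized entries inside the kernel-launch loop.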

@janeyx99 janeyx99 marked this pull request as draft September 16, 2023 04:16
janeyx99 added a commit that referenced this pull request Sep 18, 2023
ghstack-source-id: bd07c18
Pull Request resolved: #109402
@janeyx99 janeyx99 added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 21, 2023
@janeyx99
Contributor Author

Update: I've gone on a grand endeavor to do filtering first, but it is a much more involved change that seems to affect other parts of the stack. (See my attempt at #109550).

The safest path forward is:

  • tidy up and land this change
  • land the big change cautiously as a refactor

here's to hoping for green CI

janeyx99 added a commit that referenced this pull request Sep 21, 2023
ghstack-source-id: 1a5add0
Pull Request resolved: #109402
@janeyx99 janeyx99 marked this pull request as ready for review September 21, 2023 20:15
```diff
 toleranceOverride(
     {
-        torch.complex64: tol(atol=1e-05, rtol=1e-05)
+        torch.complex64: tol(atol=3e-04, rtol=2e-05)
```
Contributor Author

@janeyx99 janeyx99 Sep 21, 2023

@parth-desai @peterbell10 hi, 3e-04 seems like a decently large atol to me, and I did confirm locally that the reason for this disparity is the jiterator change. One can repro: the following provides different results before and after #102427.

```python
import torch
x = torch.tensor(-7.8167-0.0451j, device='cuda:0')
torch.set_printoptions(precision=10)
print(torch.tan(x))
print(torch._foreach_tan([x])[0])
print(torch._foreach_tan([x.to(device="cpu")])[0])
```

Before: (screenshot of the printed values omitted)

After: (screenshot of the printed values omitted)

This PR just happened to catch this since I added more sample inputs to test the empty tensor case so the seed changed. I'm wondering if this is acceptable or whether an issue should be raised to call this out.

Contributor Author
Or one can run python test/test_foreach.py -k test_parity__foreach_tan_slowpath_outplace_cuda_complex64 without the tolerance changes to repro as well

Collaborator

Agreed that does look quite bad. I think it's okay to revert the changes in UnaryGeometricTanKernel.cu.

Contributor

Hi Jane, please raise an issue. I will try to fix it in a separate PR.

Collaborator

jiterator uses different complex math implementations (from LLVM) than thrust (which is used throughout the eager codebase). I think we already had similar discrepancies with sigmoid? Worth checking.

Contributor Author

Opened an issue #110014

@janeyx99 janeyx99 requested a review from albanD September 21, 2023 20:22
janeyx99 added a commit that referenced this pull request Sep 25, 2023
ghstack-source-id: c6ff3f3
Pull Request resolved: #109402
Collaborator

@albanD albanD left a comment

Sounds good!
Let's follow up on the precision issue in detail.

```python
for num_tensors, rightmost_arg_type, intersperse_empty_tensors in itertools.product(
        num_input_tensors, self._rightmost_arg_types, (True, False)):
    if intersperse_empty_tensors and (num_tensors != max(num_input_tensors) or str(device) == 'cpu'):
        # generate interspersed empty tensors for only 1 N on non-cpu device to lessen redundancy
```
Collaborator
only 1 N ?

Contributor Author

Yes, like the largest N.

@janeyx99
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
release notes: foreach_frontend (release notes category)
topic: bug fixes (topic category)


7 participants