Add 0dim Tensor overload for _foreach_div by janeyx99 · Pull Request #113688 · pytorch/pytorch

Conversation

janeyx99
Contributor

@janeyx99 janeyx99 commented Nov 14, 2023

This PR mostly just follows the steps from #106677, except that we add one feature: similar to `fused_adam(w)`, for the CUDA dispatches, when the scalar tensor is on CPU we call `.item()` and redispatch to the normal scalar overload. Otherwise, the CUDA kernel would complain about a device mismatch between the scalar tensor and the tensors.

Why add this feature? Our optimizers want to allow `lr` as a tensor, and `lr` could be a CPU tensor. `lr` is used with `foreach_div_` in Adam, so our CI would break otherwise.

After this PR, `_foreach_mul` and `_foreach_div` will accept either a CPU or a GPU tensor for the scalar tensor (vs. only a GPU tensor), joining the ranks of `fused_adam(w)` in this respect. I did not yet do the same for `foreach_add` (the only other foreach op with a `.Tensor` overload) because there is no use case and it would be more involved.
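A rough usage sketch of the new behavior (illustrative only; assumes a CUDA device is available):

import torch

tensors = [torch.ones(2, 2, device="cuda") for _ in range(2)]
lr = torch.tensor(0.1)  # 0-dim scalar tensor living on CPU, e.g. a tensor lr

# The CPU scalar tensor is now .item()'d and redispatched to the Scalar
# overload rather than erroring with a device mismatch in the CUDA kernel.
out = torch._foreach_div(tensors, lr)
torch._foreach_mul_(tensors, lr)  # the in-place variant behaves the same way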

cc @crcrpar

Stack from ghstack (oldest at bottom):

@pytorch-bot pytorch-bot bot added the release notes: foreach_frontend label Nov 14, 2023
@pytorch-bot

pytorch-bot bot commented Nov 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113688

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ed39357 with merge base afef32b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added a commit that referenced this pull request Nov 14, 2023
ghstack-source-id: c32ca37
Pull Request resolved: #113688
@janeyx99 janeyx99 marked this pull request as ready for review November 14, 2023 22:06
     self.is_inplace = False if func is None else func.__name__.endswith('_')

-    def __call__(self, inputs, is_cuda, is_fastpath, **kwargs):
+    def __call__(self, inputs, is_cuda, expect_fastpath, **kwargs):
Contributor Author


rename to expect_fastpath, which is what it really is and is less confusing

@janeyx99 janeyx99 requested review from albanD and mlazos November 14, 2023 22:07
janeyx99 added a commit that referenced this pull request Nov 15, 2023
ghstack-source-id: f4da19a
Pull Request resolved: #113688
janeyx99 added a commit that referenced this pull request Nov 15, 2023
ghstack-source-id: 2dd88dd
Pull Request resolved: #113688
@janeyx99 janeyx99 added the topic: new features and ciflow/trunk labels Nov 15, 2023
[ghstack-poisoned]
janeyx99 added a commit that referenced this pull request Nov 15, 2023
ghstack-source-id: d7ffa5d
Pull Request resolved: #113688
     tensors = [make_tensor((2, 2), dtype=torch.float, device="cuda") for _ in range(2)]
     with self.assertRaisesRegex(RuntimeError, "scalar tensor expected to be on"):
-        torch._foreach_mul(tensors, torch.tensor(1.0, device="cpu"))
+        torch._foreach_add(tensors, torch.tensor(1.0, device="cpu"), alpha=1.0)
Contributor Author


This works now, but made me realize I didn't add a case for _foreach_add when I added the overload. Adding that now.

Collaborator

@albanD albanD left a comment


Please confirm there is a test case that also ensures that a list of CPU Tensors and a list of CUDA Tensors properly raises.

@janeyx99
Contributor Author

janeyx99 commented Nov 15, 2023

> Please confirm there is a test case that also ensures that a list of CPU Tensors and a list of CUDA Tensors properly raises.

test_parity checks that a for-loop over the individual ops and the foreach op either return the same result or raise the same error:

try:
    with ctxmgr:
        actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **kwargs)
except Exception as e:
    with (
        self.assertRaisesRegex(type(e), re.escape(str(e)))
        if not (op.has_no_in_place or op.has_no_out_of_place)
        else self.assertRaises(type(e))
    ):
        ref([ref_input, *sample.ref_args], **ref_kwargs)
else:
    expected = ref([ref_input, *sample.ref_args], **ref_kwargs)
    self.assertEqual(expected, actual)

Note that if the reference code errors in line 174, the test will fail.
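For concreteness, a hedged sketch of the mixed-device case being asked about (shapes and values are made up):

cpu_tensors = [torch.ones(2, 2) for _ in range(2)]
cuda_tensors = [torch.ones(2, 2, device="cuda") for _ in range(2)]

# Both the foreach op and the reference for-loop raise a device-mismatch
# RuntimeError, so test_parity treats them as erroring the same and passes.
torch._foreach_div(cuda_tensors, cpu_tensors)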

@janeyx99
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@mlazos
Contributor

mlazos commented Nov 15, 2023

> Please confirm there is a test case that also ensures that a list of CPU Tensors and a list of CUDA Tensors properly raises.

@janeyx99 hmm, at what level will this error? I probably need to add the grouping at the dynamo level if this occurs during FakeTensor tracing, as we discussed.

@janeyx99
Contributor Author

> @janeyx99 hmm, at what level will this error? I probably need to add the grouping at the dynamo level if this occurs during FakeTensor tracing, as we discussed.

Ah, what happens is that we'll dispatch into the CUDA impl foreach_tensor_div_tensor_kernel_cuda, which is a C++ function that checks device types. Since the device types do not match, we then dispatch to the slow impl, which is just another C++ function that for-loops over the normal div. Once it calls into div, it will error. It's worth checking what torch.compile will do (hopefully it's not too smart; we don't want it converting the second list to GPU tensors).
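Conceptually, the slow impl looks roughly like this (a hedged Python sketch with made-up names; the real fallback is C++):

def slow_foreach_div(tensors, others):
    # the slow path just for-loops and calls the regular div, which performs
    # its own device check and raises on a CPU/CUDA mismatch
    return [t.div(o) for t, o in zip(tensors, others)]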

@mlazos
Contributor

mlazos commented Nov 16, 2023

> @janeyx99 hmm, at what level will this error? I probably need to add the grouping at the dynamo level if this occurs during FakeTensor tracing, as we discussed.

> Ah, what happens is that we'll dispatch into the CUDA impl foreach_tensor_div_tensor_kernel_cuda, which is a C++ function that checks device types. Since the device types do not match, we then dispatch to the slow impl, which is just another C++ function that for-loops over the normal div. Once it calls into div, it will error. It's worth checking what torch.compile will do (hopefully it's not too smart; we don't want it converting the second list to GPU tensors).

Nah, this is good; this is exactly what I want. Also, I will only apply the special handling we discussed when there is a single tensor in the second arg.

@facebook-github-bot facebook-github-bot deleted the gh/janeyx99/104/head branch November 19, 2023 15:27