KEMBAR78
Added Diffentiable per_sample_weights Check to EmbeddingBag.cpp by EmmettBicker · Pull Request #142338 · pytorch/pytorch · GitHub
Skip to content

Conversation

@EmmettBicker
Copy link
Contributor

Added a check in aten/src/ATen/native/EmbeddingBag.cpp that checks if per_sample_weights needs a gradient in order to determine if at::_embedding_bag_forward_only or at::_embedding_bag should run.

Also, added two tests in test_embedding.py that check if the command now works.

Fixes #136457

Author: Emmett Bicker

Added a check in aten/src/ATen/native/EmbeddingBag.cpp that checks if per_sample_weights needs a gradient in order to determine if forward_only or the full embedding bag implementation should run.

Also, added two tests in test_embedding.py that check if the command now works.
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142338

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7a54d87 with merge base 524395e (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

x = torch.arange(1, 5, device=device).expand(3, -1)
w = torch.rand(3, 4, device=device, requires_grad=per_sample_weights_use_grad)
bag(x, per_sample_weights=F.softmax(w, dim=-1))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put this test under TestEmbeddingNNDeviceType? This lets us parametrize on device as well /avoid duplication between cpu/cuda.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Ill update the PR shortly

Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Small comment on testing

@soulitzer soulitzer added release notes: autograd release notes category topic: bug fixes topic category labels Dec 9, 2024
@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Dec 9, 2024
@EmmettBicker
Copy link
Contributor Author

@soulitzer Does this look better?

Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 10, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@EmmettBicker
Copy link
Contributor Author

EmmettBicker commented Dec 10, 2024

@soulitzer Hi! Sorry I think I removed some whitespace at the same time you @ merged and it might have broken it? Which would make a lot of sense.

@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@EmmettBicker
Copy link
Contributor Author

Hi @soulitzer ! I'm so sorry I'm still getting used to contributing and forgot to run the linter on the other file I edited, I just removed the two other problematic pieces of whitespace in another commit.

@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@EmmettBicker EmmettBicker deleted the embedding_bag_per_sample_weights_check branch December 11, 2024 04:04
@EmmettBicker
Copy link
Contributor Author

Woohoo!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: autograd release notes category topic: bug fixes topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

EmbeddingBag causes internal assertion error if differentiable per_sample_weights are provided but the embedding weight has gradients disabled.

5 participants