feat: support aten index_put converter for accumulate=False #2880
Conversation
```python
) -> TRTTensor:
    # Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
    reshaped_indices = []
    for i, each_input in enumerate(indices):
```
Since indices can be an ITensor per the schema, you may not be able to iterate over it.
In the test case, you can try changing line 173 to inputs=[source_tensor, indices_tensor, value_tensor],.
It's kind of similar to the offsets in the annoying embedding_bag. You can think about how to use native TRT layers to do this, like ILoop.
Besides, what blocks you when accumulate=True?
Thank you very much for your review. When indices is a torch.Tensor, an error occurs in PyTorch, as shown in the example below. This situation is somewhat different from embedding_bag; it is the case where the input is a tuple of tensors, which we discussed earlier.
If you look at the example, the index_put_ function throws an error when indices is of torch.Tensor type and only works correctly when indices is a tuple or list.
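A minimal sketch of that behavior (the shapes and values here are just illustrative):

```python
import torch

tensor = torch.zeros(4, dtype=torch.int32)
values = torch.tensor([10, 20], dtype=torch.int32)

# Passing a plain tensor as indices raises a TypeError
# (indices must be a tuple or list of Tensors):
# tensor.index_put_(torch.tensor([0, 1]), values)

# Passing a tuple of index tensors works as expected:
tensor.index_put_((torch.tensor([0, 1]),), values)
print(tensor)  # tensor([10, 20, 0, 0], dtype=torch.int32)
```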
Therefore, indices can be iterated over with a for loop, and I did not iterate over each_input since it is an ITensor. If I am mistaken, your comments would be very helpful.
One more question I have is about the type definition of indices when it is a tuple of tensors. Is it correct to define indices as Union[TRTTensor, Tuple[TRTTensor, ...]]?
When accumulate=True, if there are duplicate index pairs in indices, the corresponding values should be summed and the duplicates collapsed into a single entry. Therefore, I aimed to obtain indices without duplicated pairs, together with the correspondingly merged values, and then feed these into the scatter layer. However, I ran into difficulties implementing the loop that checks for duplicate index pairs in indices.
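For reference, a rough PyTorch-level sketch of the dedup-and-sum idea (not TRT code; the values are just illustrative):

```python
import torch

# (0, 2) appears twice; with accumulate=True those values must be summed
indices = (torch.tensor([0, 1, 0]), torch.tensor([2, 3, 2]))
values = torch.tensor([10.0, 20.0, 30.0])

stacked = torch.stack(indices, dim=1)  # [N, rank] index pairs
unique_pairs, inverse = torch.unique(stacked, dim=0, return_inverse=True)
merged = torch.zeros(unique_pairs.shape[0]).scatter_add_(0, inverse, values)
# unique_pairs / merged would then be the deduplicated index pairs and summed
# values to feed into the scatter layer
```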
Thanks for the detailed explanations! Yes, you are right: indices should be a list or tuple, and thus it can be iterated over. Then your current implementation LGTM.
> One more question I have is about the type definition of indices when it is a tuple of tensors. Is it correct to define indices as Union[TRTTensor, Tuple[TRTTensor, ...]]?
I think it could be Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], since a single TRTTensor cannot be iterated over and that matches the schema, right?
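For illustration, the signature could then look roughly like this (the function and parameter names below are placeholders, not the actual converter signature):

```python
from typing import Sequence, Union

import numpy as np
import torch
import tensorrt as trt

# rough alias for the ITensor type used throughout the converters
TRTTensor = trt.ITensor

def index_put_impl(
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
    values: TRTTensor,
    accumulate: bool = False,
) -> TRTTensor:
    ...
```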
When you say accumulate=True is causing issues, I believe the duplicate indices are what causes problems. I faced the same in scatter_reduce, and I believe advanced indexing would be the way to deal with it (lengthy code that would be, I believe :( ). Do you have any other ideas?
I have written a validator to handle the accumulate=True case, and I have created a separate issue for implementing the converter for accumulate=True. It would be great to share ideas and work together on this.
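The validator is roughly along the following lines (a sketch; the decorator/helper names in the actual codebase may differ):

```python
from torch.fx.node import Node

def index_put_validator(node: Node) -> bool:
    # accumulate is the fourth positional arg of aten.index_put (default False);
    # only accumulate=False is converted, accumulate=True falls back to PyTorch
    accumulate = node.args[3] if len(node.args) > 3 else node.kwargs.get("accumulate", False)
    return not accumulate
```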
The implementation looks good to me. Can you add a test case like this -
```python
tensor = torch.zeros([4, 4, 4, 4], dtype=torch.int32)
indices = (torch.tensor([0, 1, 2, 3]), torch.tensor([2, 3, 1, 0]))
values = torch.tensor([10, 20, 30, 40], dtype=torch.int32)
out = torch.index_put_(tensor, indices, values)
```
Let's write a validator for this case and resolve it in a new PR.
I have written a validator for broadcasting.
Description
I have implemented the aten::index_put operation using the add_scatter layer with trt.ScatterMode.ND. However, I was unable to implement the accumulate=True case, which is currently handled by the validator.

Fixes # (issue)
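At a high level, the accumulate=False path boils down to something like the following sketch (written against the raw TensorRT Python API for illustration; the actual converter goes through the Torch-TensorRT conversion context and helper utilities, and names here are not the real ones):

```python
import tensorrt as trt

def index_put_scatter_nd(network, input_itensor, index_itensors, values_itensor):
    # Reshape each 1-D index tensor to [N, 1] and concatenate them into an
    # [N, rank] index map, one row per element to update
    reshaped = []
    for idx in index_itensors:
        shuffle = network.add_shuffle(idx)
        shuffle.reshape_dims = (-1, 1)
        reshaped.append(shuffle.get_output(0))
    cat = network.add_concatenation(reshaped)
    cat.axis = 1
    # Scatter the N values into the input at those index pairs
    # (overwrite semantics, i.e. accumulate=False)
    scatter = network.add_scatter(
        input_itensor, cat.get_output(0), values_itensor, trt.ScatterMode.ND
    )
    return scatter.get_output(0)
```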
Type of change
Checklist: