KEMBAR78
feat: support aten.index_select converter by chohk88 · Pull Request #2710 · pytorch/TensorRT · GitHub
Skip to content

Conversation

@chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Mar 25, 2024

Description

New feature to support aten.index_select converter. I also add test case for different dimensions.

Fixes # (#2708)

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@chohk88 chohk88 requested review from apbose and zewenli98 March 25, 2024 11:19
@chohk88 chohk88 self-assigned this Mar 25, 2024
@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 25, 2024
@github-actions github-actions bot requested a review from gs-olive March 25, 2024 11:19
@chohk88 chohk88 linked an issue Mar 25, 2024 that may be closed by this pull request
index: TRTTensor,
) -> TRTTensor:
# The axis parameter specifies the dimension along which to index.
gather_layer = ctx.net.add_gather(input, index, axis=dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dim likely needs to be corrected using get_positive_dim to ensure the value is positive for add_gather

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have modified it. Thanks!

("2d_input_dim_0", (10, 3), 0, (0, 2)),
("2d_input_dim_1", (5, 10), 1, (1, 2, 3)),
("3d_input_dim_0", (10, 5, 10), 0, (0, 5)),
("3d_input_dim_2", (10, 5, 10), 2, (3, 3, 4)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case for a negative dim input

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a test case for a negative dim input and verified a test case. Thank you!

kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.index.index_select(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the index_select function could be put into select.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved index_select inside select.py. Thank you!

Copy link
Collaborator

@zewenli98 zewenli98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

elementwise,
embedding,
grid,
index,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can likely be removed - it seems to be causing a circular import error in CI

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! It seems I overlooked removing an unnecessary import.

Copy link
Contributor

@gs-olive gs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

@narendasan narendasan merged commit cec3835 into main Apr 12, 2024
@narendasan narendasan deleted the aten_index_select_converter branch April 12, 2024 00:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

aten.index_select

5 participants