KEMBAR78
Support fp8 in AOTInductor + support optional<> in C ABI by int3 · Pull Request #112527 · pytorch/pytorch · GitHub
Skip to content

Conversation

int3
Copy link
Contributor

@int3 int3 commented Oct 31, 2023

Stack from ghstack (oldest at bottom):

This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.

AtenTensorHandles are already pointers, so nothing needs to change
there. Only value types need to change.

We decided on this approach instead of adding an extra bool param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a nullopt / None at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

Differential Revision: D51084289

This was originally @ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I have created a ShimOptional struct
instead.

This ShimOptional is used only for non-pointer optional types; pointer
optionals can have their nullopt value represented by `nullptr`.

I decided to create a ShimOptional instead of adding an extra `bool`
param to the callee because this simplifies things. Having the same
number of arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed parameters. Previously, they just assumed they would
never be passed a `nullopt` / `None` at runtime. Changing them to use
ShimOptional now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, pulling in
argument type info from a variety of places, and possibly missing some
edge cases with const arg codegen. I've left a bunch of FIXME comments;
would appreciate feedback on whether I could improve things.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112527

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2b1de23 with merge base 78b8465 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

int3 added a commit that referenced this pull request Oct 31, 2023
This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I have created a ShimOptional struct
instead.

This ShimOptional is used only for non-pointer optional types; pointer
optionals can have their nullopt value represented by `nullptr`.

I decided to create a ShimOptional instead of adding an extra `bool`
param to the callee because this simplifies things. Having the same
number of arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed parameters. Previously, they just assumed they would
never be passed a `nullopt` / `None` at runtime. Changing them to use
ShimOptional now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, pulling in
argument type info from a variety of places, and possibly missing some
edge cases with const arg codegen. I've left a bunch of FIXME comments;
would appreciate feedback on whether I could improve things.

ghstack-source-id: ae2aed5
Pull Request resolved: #112527

def codegen_const_args(self):
return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args)
# FIXME: separate invocation for cpp arg strs?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this is fine since looking at its caller, codegen_args, it looks like the codegen is meant to target Python only

I'm overall confused by the difference between ExternKernels and FallbackKernels though, and how codegen logic is split between the two. Both of them define codegen_args, but they share a single implementation of codegen_kwargs...

@int3 int3 requested review from SherlockNoMad, chenyang78, desertfire, ipiszy, jansel and pipibjc and removed request for pipibjc October 31, 2023 20:48
@ipiszy
Copy link
Contributor

ipiszy commented Nov 1, 2023

Thanks @int3 !

I wonder have we considered using a pointer to represent optional instead of introducing a boolean value? Code would look cleaner in this case.

wrt existing APIs which are affected by this change: can we add new APIs and deprecate old APIs in parallel? e.g. We could add a new set of APIs in this PR, and then after the PR is released in prod, we remove the legacy branching logics and hard code kernel names of these operators with new API names.

@int3
Copy link
Contributor Author

int3 commented Nov 1, 2023

I wonder have we considered using a pointer to represent optional instead of introducing a boolean value? Code would look cleaner in this case.

Not sure what you have in mind here. You mean something like foo(int *opt_int) that gets invoked like foo(new int(123))? Not sure how that's cleaner...

wrt existing APIs which are affected by this change: can we add new APIs and deprecate old APIs in parallel? e.g. We could add a new set of APIs in this PR, and then after the PR is released in prod, we remove the legacy branching logics and hard code kernel names of these operators with new API names.

Yeah we should do that. I think it might be easier to do it in a follow-up PR though

@chenyang78
Copy link
Contributor

I wonder have we considered using a pointer to represent optional instead of introducing a boolean value? Code would look cleaner in this case.

Not sure what you have in mind here. You mean something like foo(int *opt_int) that gets invoked like foo(new int(123))? Not sure how that's cleaner...

I also think using pointers might be better. The idea would be that we always pass a pointer of the element type of an optional argument, where

(1) we use NULL to present c10::nullopt; and
(2) make a non-NULL pointer point to any non-nullopt value

For example, for c10::optional<int> case, we would have something like

foo(0 /*c10::nullopt*/);

or

int arg = 10;
foo(&arg);

Because we always know the element type of the optional argument in shim, we wouldn't have ambiguity. Moreover, we wouldn't have to implement our own Optional struct to handle various types.

if hasattr(self, "kwargs_default_value"):
type_ = self.kwargs_default_value.get(arg_name).get("type")
else:
type_ = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to just throw an exception if we couldn't get a valid type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That creates a lot of runtime errors... the issue is that kwargs_default_value is defined only on FallbackKernel, but this method is defined on ExternKernel. I was hoping @desertfire or @jansel might be able to suggest a better way of getting the type here, and/or if this is a thing that I need to concern myself with for non-fallback ExternKernels

x.name for x in kernel._schema.arguments if x.kwarg_only
]

def is_legacy_abi_kernel(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we really need this. For these two kernels, all optional arguments have real values, so we wouldn't hit the c10::nullopt path. Moreover, we could add another version of interface, e.g. aoti_torch__scaled_dot_product_flash_attention_nullopt to handle the missing nullopt cases. Relying on this is_legacy_abi_kernel looks very hacky to me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to distinguish these two models, but I agree with @chenyang78 that generating these two fallback functions names differently in the wrapper codegen is a cleaner solution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these two kernels, all optional arguments have real values, so we wouldn't hit the c10::nullopt path.

We would be changing the way the real values get passed in though (regardless of whether we are doing the ShimOptional or pointer approach).

But okay I am happy to do the other less hacky approach once we figure out the issues with getting the valid type above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if we add new versions of these two interfaces (actually we only need to care about the flash_attention one and the other is not used in prod), to make sure the newly published snapshots work with old predictor binaries, we still need to keep the branch until the new API is released in predictor, correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If repeat_interleave_Tensor is not in production, let's fix it by changing its shim API. cc @adnanaziz

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you meant to tag @aakhundov :)

Just checked with him, he says it's fine to change

AtenTensorHandle scale_a,
AtenTensorHandle scale_b,
AtenTensorHandle scale_result,
bool use_fast_accum,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not use bool in the C interface.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why not? Is it because C bools and C++ bools are subtly different?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is sufficient. What if the call passes in a float?

The idea is to use reinterpret_cast so the type doesn't matter. Although I see that I forgot to add that into the codegen. But I see what @chenyang78 meant by using pointers instead... that might be nicer, I'll give it a shot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like aoti_torch__scaled_dot_product_flash_attention is also using bools

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As described in that stackoverflow thread, "C's and C++'s bool type are different, but, as long as you stick to the same compiler (in your case, gcc), it should be safe, as this is a reasonable common scenario.", but we can't make that assumption here, as we don't know what the model.so was compiled with.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I thought we were only concerned with models compiled under fbcode. Okay, I'll change it.

x.name for x in kernel._schema.arguments if x.kwarg_only
]

def is_legacy_abi_kernel(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to distinguish these two models, but I agree with @chenyang78 that generating these two fallback functions names differently in the wrapper codegen is a cleaner solution.

This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I have created a ShimOptional struct
instead.

This ShimOptional is used only for non-pointer optional types; pointer
optionals can have their nullopt value represented by `nullptr`.

I decided to create a ShimOptional instead of adding an extra `bool`
param to the callee because this simplifies things. Having the same
number of arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed parameters. Previously, they just assumed they would
never be passed a `nullopt` / `None` at runtime. Changing them to use
ShimOptional now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, pulling in
argument type info from a variety of places, and possibly missing some
edge cases with const arg codegen. I've left a bunch of FIXME comments;
would appreciate feedback on whether I could improve things.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 2, 2023
This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I have created a ShimOptional struct
instead.

This ShimOptional is used only for non-pointer optional types; pointer
optionals can have their nullopt value represented by `nullptr`.

I decided to create a ShimOptional instead of adding an extra `bool`
param to the callee because this simplifies things. Having the same
number of arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed parameters. Previously, they just assumed they would
never be passed a `nullopt` / `None` at runtime. Changing them to use
ShimOptional now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, pulling in
argument type info from a variety of places, and possibly missing some
edge cases with const arg codegen. I've left a bunch of FIXME comments;
would appreciate feedback on whether I could improve things.

ghstack-source-id: 4d47aec
Pull Request resolved: #112527
This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.

`AtenTensorHandle`s are already pointers, so nothing needs to change
there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a `nullopt` / `None` at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.

[ghstack-poisoned]
@int3
Copy link
Contributor Author

int3 commented Nov 3, 2023

  • Now using @chenyang78's suggestion re passing all optionals by pointer
  • Ultimately I'd like ExternKernel to have access to the cpp schema, which would make passing the type to val_to_arg_str much simpler, but for now I'm using a hacky workaround to unblock things.
  • Not doing the legacy ABI versioning in this diff; @chenyang78 says he will tackle it as he ran into something else that needs this too
  • Still using bools in the interface for now, pending feedback. Initial research suggests that this is safe

make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(aten.topk)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten._scaled_mm.default)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we probably need to keep this since we haven't added lowering for _scaled_mm? (I'm a bit confused about this as well). Have you tried that test/inductor/test_fp8.py can run successfully if we remove this fallback?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, I've checked and the test passes

not entirely sure either about how the fallback mechanism works though

This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.

`AtenTensorHandle`s are already pointers, so nothing needs to change
there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a `nullopt` / `None` at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.

[ghstack-poisoned]
int3 added a commit that referenced this pull request Nov 3, 2023
This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.

`AtenTensorHandle`s are already pointers, so nothing needs to change
there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a `nullopt` / `None` at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.

ghstack-source-id: 0bcf2f1
Pull Request resolved: #112527
@int3
Copy link
Contributor Author

int3 commented Nov 8, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team Raised by workflow job

@int3
Copy link
Contributor Author

int3 commented Nov 8, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator

Details for Dev Infra team Raised by workflow job

@int3
Copy link
Contributor Author

int3 commented Nov 8, 2023

Let's see if unlinking works...

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@swolchok
Copy link
Contributor

Changing them to use pointer types now would break ABI stability

are we stuck with ABI stability already?

@int3
Copy link
Contributor Author

int3 commented Nov 10, 2023

We do need to maintain backwards compatibility for _scaled_dot_product_flash_attention which is being used. But I believe the plan is to make a temporary v2 version of the function, deprecate the old one, then move all callers to using the new ABI format. @chenyang78 is looking into it.

@chenyang78
Copy link
Contributor

We do need to maintain backwards compatibility for _scaled_dot_product_flash_attention which is being used. But I believe the plan is to make a temporary v2 version of the function, deprecate the old one, then move all callers to using the new ABI format. @chenyang78 is looking into it.

@swolchok #113487

Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
)

This was originally ipiszy's PR: pytorch#112358

It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.

`AtenTensorHandle`s are already pointers, so nothing needs to change
there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to
the callee because this simplifies things. Having the same number of
arguments regardless of whether we are emitting Python / C++ /
ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have
optional-typed value parameters. Previously, they just assumed they
would never be passed a `nullopt` / `None` at runtime. Changing them to
use pointer types now would break ABI stability, so I have created an
exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.

Differential Revision: [D51084289](https://our.internmc.facebook.com/intern/diff/D51084289)
Pull Request resolved: pytorch#112527
Approved by: https://github.com/chenyang78, https://github.com/desertfire
pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2023
…112527)" (#113747)

Test Plan: sandcastle

Differential Revision: D51330618

Pull Request resolved: #113747
Approved by: https://github.com/chenyang78, https://github.com/khabinov
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

[ghstack-poisoned]
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

ghstack-source-id: f5a794f
Pull Request resolved: #114974
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

[ghstack-poisoned]
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

ghstack-source-id: fecfab4
Pull Request resolved: #114990
int3 added a commit that referenced this pull request Dec 2, 2023
…pport)"


This is a backout of #113747 which reverted the above two commits. Now that
#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
int3 added a commit that referenced this pull request Dec 2, 2023
This is a backout of #113747 which reverted the above two commits.

Pull Request resolved: #114974
ghstack-source-id: fecfab4
pytorchmergebot pushed a commit that referenced this pull request Dec 2, 2023
…4974)

This is a backout of #113747 which reverted the above two commits. Now that
#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

Pull Request resolved: #114974
Approved by: https://github.com/chenyang78
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
… support) (pytorch#114974)

This is a backout of pytorch#113747 which reverted the above two commits. Now that
pytorch#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

Pull Request resolved: pytorch#114974
Approved by: https://github.com/chenyang78
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants