Add ScalarType -> shim conversion, add stable::Tensor.scalar_type #160557
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160557
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit f12f9d1 with merge base a44a0d3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack[0] = from(t);
- stack[1] = from(std::optional(t_dtype)); // dtype
+ stack[1] = from(std::optional(t.scalar_type())); // dtype
For testing
The to/from logic of this file got moved to utils.h
This file got split into tensor-struct.h, which has everything before the PR, and tensor-inl.h, which implements scalar_type since it relies on from/to. The reason for the split is to allow the code to build without circular dependencies. Without it, tensor.h would depend on library.h (for to/from) and library.h would depend on tensor.h (because to/from for Tensor needs a Tensor definition).
Now, utils.h (which has to/from) depends on tensor-struct.h, tensor-inl.h depends on both utils.h and tensor-struct.h, and users still depend on tensor.h, which depends on all of the above.
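A minimal single-file sketch of that layering (all names here are simplified stand-ins for the real headers; `dtype_code` is a hypothetical field standing in for the opaque handle) shows why the cycle disappears: the struct, the conversions, and the inline method bodies appear in the same order the includes resolve.

```cpp
#include <cstdint>

// "tensor-struct.h": only the Tensor definition; no to/from needed here.
struct Tensor {
  int32_t dtype_code;           // hypothetical stand-in for the opaque handle
  int32_t scalar_type() const;  // declared here, defined below ("tensor-inl.h")
};

// "stableivalue_conversions.h": to/from depends only on the struct above.
inline int32_t from(const Tensor& t) { return t.dtype_code; }

// "tensor-inl.h": the method body can now use from/to without a cycle.
inline int32_t Tensor::scalar_type() const { return from(*this); }
```

Each "file" depends only on what came before it, so no header ever needs to include one that includes it back.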
case ScalarType::UInt64:
  return from(aoti_torch_dtype_uint64());
default:
  throw std::runtime_error(
NOTE!! THIS IS WHERE I WANT REVIEW!! cc @albanD
Prior to this, if we had an IValue dtype that was qint8, from_ivalue would call from(ScalarType::QInt8), and the code would just reinterpret the enum and spit out the corresponding int32_t. This was okay because ScalarType wasn't exposed to the end user; all they had to work with was an abstracted int32_t that they would get from the C shim.
However, with this change, from(ScalarType::QInt8) would error! Now that ScalarType is allowed to be used by the end user, who can call this function directly, naively reinterpreting the enum is no longer okay when the extension's ScalarType differs from libtorch's ScalarType. I think erroring is acceptable because these other types are infrequently used anyway, but maybe I am wrong about that. e.g., @swolchok, are the Bits ScalarTypes used in ET?
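To illustrate the hazard in isolation (all names and numeric codes below are hypothetical, not the real shim constants): an enum baked into the extension binary may not share values with the enum in libtorch, so an explicit switch that errors on unmapped dtypes is the safe translation, whereas a reinterpret-cast of the raw value would silently mean the wrong dtype.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical extension-side enum whose numeric values need not match
// libtorch's ScalarType across binaries.
enum class ExtScalarType : int8_t { Float = 0, Long = 1, QInt8 = 2 };

// Assumed stable shim codes (illustrative values only).
constexpr int32_t kShimFloat = 6;
constexpr int32_t kShimLong = 4;

// Explicit translation: only dtypes in the stable set convert; anything
// else errors instead of being silently reinterpreted.
int32_t to_shim_dtype(ExtScalarType t) {
  switch (t) {
    case ExtScalarType::Float: return kShimFloat;
    case ExtScalarType::Long:  return kShimLong;
    default:
      throw std::runtime_error("dtype not supported by the stable ABI");
  }
}
```

Under this scheme a call with QInt8, which previously "worked" via reinterpretation, now throws, which is the BC break being discussed.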
Answering your specific question:
I haven't (yet?) made any attempt to use PyTorch's ScalarType in ExecuTorch. ExecuTorch has https://github.com/pytorch/executorch/blob/52b45e2d2ac244b13a36ddf5d21a9ebe8d8aa17e/runtime/core/portable_type/scalar_type.h#L132
PyTorch's ScalarType will get used in ExecuTorch's ATen mode, though. https://github.com/pytorch/executorch/blob/52b45e2d2ac244b13a36ddf5d21a9ebe8d8aa17e/runtime/core/exec_aten/exec_aten.h#L82
I don't know what the Bits ScalarTypes even are, but ExecuTorch seems to have its own versions of them that it uses: https://github.com/search?q=repo%3Apytorch%2Fexecutorch+ScalarType%3A%3ABits+language%3AC%2B%2B&type=code&l=C%2B%2B
In general: it is not backward compatible to change functionality such that a call that previously succeeded (and really did work fine) is now an error.
I've concluded it is okay to BC-break here, given that GitHub search yields 0 users of the narrow use case this code would break. Updated the PR body accordingly.
this code is copy-pasted EXCEPT for the specializations for torch::headeronly::ScalarType
…ar_type" This change _modifies_ the from/to behavior between ScalarType and StableValue! Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t). I then changed the test to test the scalar_type API. This code change required some refactoring because of circular dependencies. [ghstack-poisoned]
@@ -0,0 +1,342 @@
#pragma once |
I firmly dislike naming things "utils" because it is synonymous with "stuff" and helps neither predict their current contents nor limit their future contents. Instead I would consider a specific name, like say StableIValueConversions.h .
renamed!!!!!! stableivalue_conversions.h
struct FromImpl<ScalarType> {
  static StableIValue call(ScalarType val) {
    switch (val) {
      case ScalarType::Byte:
This feels like it could benefit from a #define of sorts that iterates over known dtypes (if dtypes are part of the stable API...)
I was trying to figure out a way, but it did not seem worth it for this PR. (Also there's no "is dtype part of stable API" function we can call yet).
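For reference, the kind of define the reviewer is suggesting is an X-macro: one list of (enum case, shim getter) pairs that generates the switch, so adding a dtype touches a single line. This is a hedged sketch with made-up getters and values, not the real shim API.

```cpp
#include <cstdint>
#include <stdexcept>

enum class ScalarType : int8_t { Byte, Char, Short };

// Hypothetical shim getters; the real ones query libtorch through the C shim.
inline int32_t shim_uint8() { return 0; }
inline int32_t shim_int8()  { return 1; }
inline int32_t shim_int16() { return 2; }

// The single list that drives the switch below.
#define FORALL_STABLE_DTYPES(_) \
  _(Byte, shim_uint8)           \
  _(Char, shim_int8)            \
  _(Short, shim_int16)

inline int32_t to_shim(ScalarType t) {
  switch (t) {
#define DEFINE_CASE(name, getter) \
  case ScalarType::name:          \
    return getter();
    FORALL_STABLE_DTYPES(DEFINE_CASE)
#undef DEFINE_CASE
  }
  throw std::runtime_error("unsupported dtype");
}
```

The same FORALL list could also generate the reverse mapping, which is the main appeal over two hand-maintained switches.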
auto inner_val = to<T>(*sivp);

// free the memory associated with StableIValue* sivp
delete sivp;
This feels very suspicious.. Why not pass the value as unique_ptr?
I'm not sure about passing it around as std::unique_ptr, but you could certainly declare std::unique_ptr<StableIValue> sivp = to<StableIValue*>(val); above and insulate yourself against the inner to<T> throwing an exception.
will address in a followup
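The exception-safety point can be shown in miniature (types and the throwing converter here are simplified stand-ins, not the real to<T>): taking ownership in a unique_ptr before converting frees the pointee on every path, whereas a raw `delete` after the conversion is skipped if the converter throws.

```cpp
#include <memory>
#include <stdexcept>

struct StableIValue { int payload; };

// A converter that may throw, as to<T> can.
inline int to_int(const StableIValue& v) {
  if (v.payload < 0) throw std::runtime_error("bad value");
  return v.payload;
}

// Ownership is taken before conversion, so the pointee is freed even on the
// throwing path; with a trailing raw `delete`, a throw here would leak.
inline int unwrap(StableIValue* raw) {
  std::unique_ptr<StableIValue> sivp(raw);
  return to_int(*sivp);
}
```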
@@ -0,0 +1,342 @@
#pragma once |
nit: is "stableivalue_conversions.h" really the conventional formatting here? I would've expected "StableIValueConversions.h", as with torch/headeronly/core/ScalarType.h
I think in general it would be good to document the header format, as I've seen CamelCase.h, snake_case.h, and something-weird.h (for example tensor-inl.h in this PR). For example, aoti follows snake_case.
idk the conventional format, but I was operating under the convention of using Caps when there was a struct definition of the same name, and lowercase for more thematic naming (like "everything in this file relates to ____"). The -dashes I'm copying from c10's Half.h and Half-inl.h, where the dash means the file is a continuation: these files would be together if possible but are broken apart for some other reason.
I'm happy to follow an existing header notation though, if there is one.
return from(aoti_torch_dtype_uint8());
case ScalarType::Char:
  return from(aoti_torch_dtype_int8());
case ScalarType::Short:
Q: Why not add unsigned types here?
I was following the convention of the actual enum order
int32_t shim_scalartype = to<int32_t>(val);
if (shim_scalartype == aoti_torch_dtype_uint8()) {
Is it possible to cast it to some enum and use a switch statement? (That would force devs to add options there when a new dtype is added.) Otherwise this statement will be the result of a never-ending series of "Added missing XYZ" fixes.
Well, the whole reason to have this is that the ScalarType enum baked into the user binary isn't necessarily the same as the libtorch binary's ScalarType, which is why I'm passing ints through the shim.
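There is also a mechanical obstacle to the suggested switch: case labels must be compile-time constants, while the shim dtype codes arrive as runtime values from calls across the ABI boundary. A hedged sketch with made-up getters (not the real aoti_torch_dtype_* functions):

```cpp
#include <cstdint>
#include <stdexcept>

enum class ScalarType : int8_t { Byte, Char };

// Hypothetical shim getters returning the stable integer codes; the real
// ones are C functions resolved against whatever libtorch is loaded.
inline int32_t shim_uint8() { return 0; }
inline int32_t shim_int8()  { return 1; }

// The incoming value is a runtime int32_t from another binary; getter calls
// are not constant expressions, so an if/else chain is used rather than a
// switch over this binary's enum.
inline ScalarType to_scalar_type(int32_t shim) {
  if (shim == shim_uint8()) return ScalarType::Byte;
  if (shim == shim_int8())  return ScalarType::Char;
  throw std::runtime_error("unsupported shim dtype");
}
```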
return ScalarType::Float8_e5m2fnuz;
} else if (shim_scalartype == aoti_torch_dtype_float8_e4m3fnuz()) {
  return ScalarType::Float8_e4m3fnuz;
} else if (shim_scalartype == aoti_torch_dtype_uint16()) {
Why do you support unsigned dtypes here but not in the previous switch statement?
They're supported above too
@@ -0,0 +1,24 @@
#pragma once |
See my above comment. Use either CamelCase or snake_case
okay, I'll switch to snake_case for these
void* data_ptr() const {
  void* data_ptr;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(ath_.get(), &data_ptr));
  return data_ptr;
}

int64_t dim() const {
  int64_t dim;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(ath_.get(), &dim));
  return dim;
}

int64_t numel() const {
  int64_t numel;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(ath_.get(), &numel));
  return numel;
}
Please, please use a macro to get rid of the copy-pasta:
Suggested change:

#define _DEF_CONST_ACCESSOR_METHOD(NAME, DTYPE) \
  DTYPE NAME() const { \
    DTYPE rc; \
    TORCH_ERROR_CODE_CHECK(aoti_torch_get_##NAME(ath_.get(), &rc)); \
    return rc; \
  }

_DEF_CONST_ACCESSOR_METHOD(data_ptr, void*);
_DEF_CONST_ACCESSOR_METHOD(dim, int64_t);
_DEF_CONST_ACCESSOR_METHOD(numel, int64_t);
#undef _DEF_CONST_ACCESSOR_METHOD
It is more readable currently to see these APIs as is. If we need to add more, I will consider using a preprocessor macro.
…ar_type"

TL;DR: Moving to ScalarType in user extensions and removing deprecated dtypes.

This change _modifies_ the from/to behavior between ScalarType and StableIValue! Whereas before, user extensions could only pass around obfuscated dtypes appearing as abstract int32_ts, now users can confidently use torch::headeronly::ScalarType in their extensions for major scalar types. This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear.

Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t). I then changed the test to test the scalar_type API.

This code change required some refactoring because of circular dependencies.

## BC Breaking note

This commit is (narrowly) BC-breaking for unpopular dtypes: `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2`, in the narrow use case where an extension retrieves a Tensor dtype of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change significantly justify BC-breaking this API.

[ghstack-poisoned]
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #159508 Approved by: https://github.com/janeyx99 ghstack dependencies: #160557
…torch#160557) TL;DR: Moving to ScalarType in user extensions and removing deprecated dtypes. This change _modifies_ the from/to behavior between ScalarType and StableValue! Whereas before, user extensions could only in abstract pass around obfuscated dtypes appearing as int32_ts, now, users can confidently use torch::headeronly::ScalarType in their extensions for major scalar types. This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear. Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t). I then changed the test to test the scalar_type API. This code change required some refactoring because of circular dependencies. ## BC Breaking note This commit is (narrowly) BC-breaking for unpopular dtypes: `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2` in the narrow use case where an extension retrieves a Tensor dtype of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change significantly justify BC-breaking this API. Pull Request resolved: pytorch#160557 Approved by: https://github.com/mikaylagawarecki, https://github.com/malfet
…9508) Pull Request resolved: pytorch#159508 Approved by: https://github.com/janeyx99 ghstack dependencies: pytorch#160557
TL;DR: Moving to ScalarType in user extensions and removing deprecated dtypes.
This change modifies the from/to behavior between ScalarType and StableIValue! Whereas before, user extensions could only pass around obfuscated dtypes appearing as abstract int32_ts, now users can confidently use torch::headeronly::ScalarType in their extensions for major scalar types. This PR enables ABI stability by adding a translation layer through the shim, so that even if the ScalarType enum values change in the future, user extensions need not fear.
Then we add a Tensor scalar_type API which reuses the from/to logic to return to the user a nice ScalarType (vs an abstracted int32_t).
I then changed the test to test the scalar_type API.
This code change required some refactoring because of circular dependencies.
## BC Breaking note

This commit is (narrowly) BC-breaking for unpopular dtypes: `quint*`s, `qint*`s, `Bits*`, `dummy_uint*`s, `dummy_int*`s, `Float8_e8m0fnu`, and `Float4_e2m1fn_x2`, in the narrow use case where an extension retrieves a Tensor dtype of the above and passes it into `aoti_torch_call_dispatcher`. As of now, I believe there are 0 users of this use case, so the benefits of this change significantly justify BC-breaking this API.

Stack from ghstack (oldest at bottom):