Fx with meta #16836
Conversation
The documentation is not available anymore as the PR was closed or merged.
…, and model attributes are copied to the traced version
Not sure why the changes in type annotations of Electra, GPT-Neo and RoBERTa are included in this PR; they should go in a separate PR that's not focused on ONNX.
The actual code changes are way too complex for me to fully parse, so I'll trust you have tested them on all existing models. Make sure you run all the slow ONNX tests before merging this PR, though. I'd also like to know why one of the tests was removed.
@sgugger is it ready now?
On my side, yes, but let's wait for @LysandreJik's approval as well, since this is a lot of changes.
Big change! Superficially reviewed, but it looks good to me. I haven't tried it out myself.
Should additional tests be implemented to ensure that everything works as intended? If you think the current tests provide enough coverage, I'm happy to merge it like this.
Since it does not change the API and does not provide any new feature, the current set of tests is enough.
* Add meta proxy
* Uses meta data to trace data dependent control-flow
* Remove commented class
* Handles torch creating functions
* Added type annotation to fix tracing
* Tracing works for everything but T5 and GPT-J
* Almost all previously supported models pass
* All architectures can be traced except T5
* Intermediate commit to have a trace of the comparison operators for HFProxy
* Everything works, except loss computation
* Everything works
* Removed unused import
* Overriden methods do not use underlying ops (linear and torch.matmul), and model attributes are copied to the traced version
* Fix torch_matmul_override
* Change attributes reference to deepcopy
* Remove breakpoint and add torch_index_override
* Small fix
* Fix typo
* Replace asserts by explicit exceptions
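The commit list above mentions tracing data-dependent control flow; for context, here is a minimal, generic torch.fx sketch of the limitation being addressed (not code from this PR). Vanilla symbolic tracing fails as soon as a proxy is used in a Python condition, because a Proxy has no concrete boolean value:

import torch
from torch import fx

class Gate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: during vanilla fx tracing, x.sum() > 0 is a Proxy,
        # and bool(Proxy) raises "symbolically traced variables cannot be used
        # as inputs to control flow".
        if x.sum() > 0:
            return x + 1
        return x - 1

fx.symbolic_trace(Gate())  # raises torch.fx.proxy.TraceError

Attaching metadata (e.g. a meta tensor carrying shape and dtype) to each proxy is what lets such conditions be evaluated concretely during tracing.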
return super().__eq__(other)

@property
def dtype(self):
    return self.tracer.root.dtype
This line causes the following code to be dead. Is this intended?
def __contains__(self, key):
    return False

return self.tracer.create_proxy("call_method", "size", (self,), {})
This implementation does not provide any way to get concrete shape information from the MetaTensor, and thus almost all of the tests fail with errors when tensor shapes are passed to indexing (when is_torch_fx_available is True).
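For illustration, the error below comes from using an object that has no __index__ method as a slice bound; a traced size that carries no concrete integer behaves like the placeholder class in this generic Python sketch (not code from this PR):

class OpaqueSize:
    """Stand-in for a traced size() result that carries no concrete integer."""
    pass

seq_length = OpaqueSize()
hidden = list(range(8))
hidden[:seq_length]
# TypeError: slice indices must be integers or None or have an __index__ method

Model code that slices a concrete tensor or buffer by the sequence length hits the same error when that length is a traced proxy rather than a concrete int.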
Failing tests for reference:
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __...
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ m...
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ me...
FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __inde...
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or ha...
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __...
FAILED tests/models/roberta/test_modeling_roberta.py::RobertaModelTest::test_torch_fx - AssertionError: Couldn't trace module.
FAILED tests/models/roberta/test_modeling_roberta.py::RobertaModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module.
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx - AssertionError: Couldn't trace module.
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module.
What does this PR do?
This PR simplifies and improves the way tracing works with torch.fx.
Instead of recording concrete values via a forward pass on the original model, metadata is attached to the proxies: either tensors on the meta device (which saves us from performing actual computations, since only shape inference is needed) or any other type such as torch.Size and builtin types. On top of making it possible to trace very big models, this gives much more flexibility and should allow us to support many new architectures.

A big thanks to @jamesr66a, as he was the one who provided the basis for tracing with meta tensors; I simply extended what was already done to fit our purposes.
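As a quick illustration of the meta device mentioned above (a generic PyTorch sketch, not code from this PR): meta tensors carry only shape and dtype, so operations on them perform shape inference without allocating storage or doing any computation.

import torch

# Tensors on the "meta" device have shape and dtype but no data.
x = torch.empty(2, 128, 768, device="meta")
w = torch.empty(768, 768, device="meta")

# Operations are dispatched for shape/dtype inference only; no math is done.
y = torch.matmul(x, w)
print(y.shape)   # torch.Size([2, 128, 768])
print(y.device)  # meta

Attaching such tensors to the fx proxies is what lets shape- and data-dependent code paths be resolved during tracing without running the real model.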
@jamesr66a @pbelevich I would love your review and feedback!