Fx with meta by michaelbenayoun · Pull Request #16836 · huggingface/transformers · GitHub

Conversation

@michaelbenayoun
Member

@michaelbenayoun michaelbenayoun commented Apr 19, 2022

What does this PR do?

This PR simplifies and improves the way tracing works with torch.fx.

Instead of recording concrete values via a forward pass on the original model, metadata is attached to the proxies: either tensors on the meta device (which spares us actual computation, since only shape inference is performed) or any other type such as torch.Size and built-in types. On top of allowing us to trace very big models, this gives much more flexibility and should allow us to support many new architectures.
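The meta-device behavior described above can be sketched in plain PyTorch (a minimal illustration, not the PR's code): tensors on the `meta` device carry only metadata such as shape and dtype, so operations on them propagate shapes without performing any real computation.

```python
import torch

# Tensors on the "meta" device allocate no storage; operations on them
# only perform shape/dtype inference, never actual computation.
a = torch.empty(512, 1024, device="meta")
b = torch.empty(1024, 256, device="meta")

out = torch.matmul(a, b)
print(out.shape)   # torch.Size([512, 256])
print(out.device)  # meta
```

This is what makes it possible to trace very large models: no weights ever need to be materialized in memory.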

A big thanks to @jamesr66a, who provided the basis for tracing with meta tensors; I simply extended what was already done for our purposes.

@jamesr66a @pbelevich I would love your review and feedback!

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@michaelbenayoun michaelbenayoun marked this pull request as ready for review April 27, 2022 10:14
Collaborator

@sgugger sgugger left a comment


Not sure why the changes in type annotations of Electra, GPT-Neo and RoBERTa are included in this PR, they should go in a separate PR that's not focused on ONNX.

The actual code changes are way too complex for me to fully parse, so I'll trust you have tested them on all existing models. Make sure you run all slow ONNX tests before merging this PR, though. I'd also like to know why one of the tests was removed.

@michaelbenayoun
Member Author

@sgugger is it ready now?

Collaborator

@sgugger sgugger left a comment


On my side, yes, but let's wait for @LysandreJik's approval as well, since this is a large change.

Member

@LysandreJik LysandreJik left a comment


Big change! Superficially reviewed, but it looks good to me. I haven't tried it out myself.

Should additional tests be implemented to ensure that everything works as intended? If you think current tests provide enough coverage, happy to merge it like this.

@michaelbenayoun
Member Author

Since it does not change the API and does not provide any new feature, the current set of tests is enough.
I also ran the tests with torch==1.10.2 (we have TORCH_FX_REQUIRED_VERSION = version.parse("1.10")), and they pass. The torch version could have been a limitation since we use the meta device, but everything is okay.
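The version gate mentioned above can be sketched as follows. This is a minimal illustration using `packaging.version`, with a hypothetical helper name; the actual check lives inside transformers' fx utilities.

```python
from packaging import version

# The PR's constraint: torch.fx tracing requires torch >= 1.10.
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

def is_torch_fx_compatible(torch_version: str) -> bool:
    # Compare only the base release so builds like "1.10.2" or
    # "1.10.0+cu113" satisfy the "1.10" requirement.
    base = version.parse(torch_version).base_version
    return version.parse(base) >= TORCH_FX_REQUIRED_VERSION

print(is_torch_fx_compatible("1.10.2"))  # True
print(is_torch_fx_compatible("1.9.1"))   # False
```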

@michaelbenayoun michaelbenayoun merged commit 2c2a216 into huggingface:main May 2, 2022
@michaelbenayoun michaelbenayoun deleted the fx_with_meta branch May 2, 2022 09:47
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
* Add meta proxy

* Uses meta data to trace data dependent control-flow

* Remove commented class

* Handles torch creating functions

* Added type annotation to fix tracing

* Tracing works for everything but T5 and GPT-J

* Almost all previously supported models pass

* All architectures can be traced except T5

* Intermediate commit to have a trace of the comparison operators for HFProxy

* Everything works, except loss computation

* Everything works

* Removed unused import

* Overridden methods do not use underlying ops (linear and torch.matmul), and model attributes are copied to the traced version

* Fix torch_matmul_override

* Change attributes reference to deepcopy

* Remove breakpoint and add torch_index_override

* Small fix

* Fix typo

* Replace asserts by explicit exceptions
```python
return super().__eq__(other)

@property
def dtype(self):
    return self.tracer.root.dtype
```
Contributor


This line causes the following code to be dead. Is this intended?


```python
def __contains__(self, key):
    return False
```
```python
return self.tracer.create_proxy("call_method", "size", (self,), {})
```
Contributor


This implementation does not provide any way to get concrete shape information from the MetaTensor, so almost all of the tests are failing with errors from passing tensor shapes to indexing (when `is_torch_fx_available()` is True).
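The failure mode described here can be reproduced in isolation (a minimal illustration, not the transformers code): slicing a sequence with an object that is neither an int nor defines `__index__` raises exactly the TypeError seen in the test logs below.

```python
class ProxyDim:
    """Stand-in for a traced dimension that carries no concrete value
    and does not implement __index__."""

seq = list(range(8))
try:
    seq[:ProxyDim()]  # shapes fed into slicing must resolve to integers
except TypeError as err:
    print(err)  # slice indices must be integers or None or have an __index__ method
```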

Contributor


Failing tests for reference:

```
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/albert/test_modeling_albert.py::AlbertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __...
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/distilbert/test_modeling_distilbert.py::DistilBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __...
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ m...
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ m...
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ me...
FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __inde...
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or ha...
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __inde...
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or ha...
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __...
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_torch_fx - TypeError: slice indices must be integers or None or have an __index__ method
FAILED tests/models/mobilebert/test_modeling_mobilebert.py::MobileBertModelTest::test_torch_fx_output_loss - TypeError: slice indices must be integers or None or have an __...
FAILED tests/models/roberta/test_modeling_roberta.py::RobertaModelTest::test_torch_fx - AssertionError: Couldn't trace module.
FAILED tests/models/roberta/test_modeling_roberta.py::RobertaModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module.
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx - AssertionError: Couldn't trace module.
FAILED tests/models/t5/test_modeling_t5.py::T5ModelTest::test_torch_fx_output_loss - AssertionError: Couldn't trace module.
```

elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022