Store user model to simplify ONNXProgram.{adapt_torch_*,__call__} APIs by thiagocrepaldi · Pull Request #115281 · pytorch/pytorch

Conversation

@thiagocrepaldi (Collaborator) commented Dec 6, 2023

Stack from ghstack (oldest at bottom):

Currently (after #114407), the user must pass the original user model to APIs such as ONNXProgram.__call__, ONNXProgram.adapt_torch_inputs_to_onnx, and ONNXProgram.adapt_torch_outputs_to_onnx.

This was needed because when the model is fakefied, a non-fakefied version of the model is required so that initializers, buffers, and constants can be extracted from a real model (and used as inputs to the ONNX model).
That approach places an unnecessary usability burden on the user when the model is not fakefied, because the model already passed to torch.onnx.dynamo_export could be used to extract the state_dict.

This PR adds an ONNXProgram._model_torch attribute to store the user model and demotes the model argument of the aforementioned APIs from required to optional.

As a result, for the fakefied model scenario the user still needs to pass a real model explicitly, but for non-fakefied models the persisted model is implicitly used to extract the model state_dict, making the APIs easier to use.
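A minimal sketch of the non-fakefied usage this enables (not taken from the PR itself; TinyModel and the shapes are made up for illustration, and onnxruntime is assumed to be installed for ONNXProgram.__call__):

```python
import torch

class TinyModel(torch.nn.Module):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
x = torch.randn(1, 16)

# The model passed here is now stored on the returned ONNXProgram...
onnx_program = torch.onnx.dynamo_export(model, x)

# ...so it no longer has to be passed back to these APIs.
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(x)
onnx_outputs = onnx_program(x)  # runs the exported ONNX model via onnxruntime
```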

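For the fakefied scenario, a real model still has to be supplied explicitly. A hedged sketch, reusing TinyModel from the previous snippet and assuming model_with_state_dict is the final keyword name (it is the name visible in the review excerpt below):

```python
import torch

# Export under fake mode: weights are fake tensors, so the exported program
# carries no real initializer data. TinyModel is the class defined above.
with torch.onnx.enable_fake_mode() as fake_context:
    fake_model = TinyModel()
    fake_input = torch.randn(1, 16)

export_options = torch.onnx.ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(
    fake_model, fake_input, export_options=export_options
)

# A real model must still be handed back so its parameters, buffers, and
# constants can be extracted and fed to the ONNX model.
real_model = TinyModel()  # real weights would normally be loaded here
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(
    torch.randn(1, 16), model_with_state_dict=real_model
)
```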
@pytorch-bot bot commented Dec 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115281

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3e8612b with merge base 441ecf0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Review context excerpt:

    self,
    *model_args,
    model: Optional[
    model_with_state_dict: Optional[
Collaborator commented:

nit: for my understanding, do we need only state_dict or the model?

@thiagocrepaldi (Author) replied Dec 7, 2023:

We need the model, because the lifted "constant tensors" are not part of ExportedProgram.state_dict (yet?).

I will investigate further and discuss with Meta whether that change can be made to ExportedProgram.state_dict, simplifying the requirement to state_dict only (which would be ideal).
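To illustrate the point about lifted constants (a made-up module, not from the PR): plain tensor attributes, which become lifted constant tensors after export, do not show up in state_dict, while parameters and buffers do.

```python
import torch

class WithConstant(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4))   # in state_dict
        self.register_buffer("running", torch.zeros(4))   # in state_dict
        self.scale = torch.tensor(0.5)                     # NOT in state_dict

    def forward(self, x):
        return (x * self.weight + self.running) * self.scale

print(list(WithConstant().state_dict().keys()))
# ['weight', 'running'] -- 'scale' is missing, which is why the full model
# (not just its state_dict) is needed to recover such constants.
```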

Thiago Crepaldi added 4 commits December 7, 2023 03:15
Thiago Crepaldi added 2 commits December 7, 2023 16:42
Review context excerpt:

    _fake_context: Final[Optional[ONNXFakeContext]]
    _export_exception: Final[Optional[Exception]]
    _model_signature: Final[Optional[torch.export.ExportGraphSignature]]
    _model_torch: Final[
Collaborator commented:

mark as experimental?

@thiagocrepaldi (Author) replied:

Tracked by #115461

@BowenBao (Collaborator) left a comment:

🚢 to unblock CI and experiments. Let's keep exploring and revisit #115461

@thiagocrepaldi (Author) commented:

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Dec 9, 2023
@pytorchmergebot commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Dec 9, 2023
pytorchmergebot pushed a commit that referenced this pull request Dec 9, 2023
@facebook-github-bot deleted the gh/thiagocrepaldi/15/head branch December 12, 2023 15:30
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Jan 6, 2024
huydhn pushed a commit that referenced this pull request Jan 8, 2024 (#115583)

Labels

- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- open source
- release notes: onnx (torch.onnx related changes that should show up in the release notes)
