Store user model to simplify ONNXProgram.{adapt_torch_*,__call__} APIs #115281
Conversation
Currently (after #114407), the user must pass the original ``model`` to APIs such as ``ONNXProgram.__call__``, ``ONNXProgram.adapt_torch_inputs_to_onnx``, and ``ONNXProgram.adapt_torch_outputs_to_onnx``. This is needed because when the model is fakefied, a non-fakefied version of the model is required so that the initializers, buffers, and constants can be extracted (and used as input to the ONNX model). That approach places an unnecessary burden on the user when the model is not fakefied, because the model already passed to ``torch.onnx.dynamo_export`` could be used to extract the ``state_dict``. This PR adds ``ONNXProgram._model_torch`` to store the user model and demotes the ``model`` argument of the aforementioned APIs to optional. As a result, in the fakefied scenario the user can still pass the required real model, while for non-fakefied models the persisted model is used implicitly to extract the model ``state_dict``. [ghstack-poisoned]
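A minimal sketch of the usability difference this PR targets. The export calls are shown commented out because they require the ONNX exporter dependencies (e.g. ``onnxscript``) to be installed; the ``TinyModel`` class and the keyword name ``model_with_state_dict`` are illustrative, with the latter taken from the diff below.

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyModel()
x = torch.randn(1, 4)

# onnx_program = torch.onnx.dynamo_export(model, x)
#
# Before this PR: the original model had to be passed back explicitly,
# even though it was already given to dynamo_export:
#   onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(x, model_with_state_dict=model)
#
# After this PR: for a non-fakefied model, the stored model is used implicitly:
#   onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(x)
#   outputs = onnx_program(x)

print(model(x).shape)  # the eager model still runs normally
```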
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115281
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 3e8612b with merge base 441ecf0. This comment was automatically generated by Dr. CI and updates every 15 minutes.
self,
*model_args,
model: Optional[
model_with_state_dict: Optional[
nit: for my understanding, do we need only the state_dict, or the whole model?
We need the model, because the lifted "constant tensors" are not part of ``ExportedProgram.state_dict`` (yet?).
I will investigate further, discuss with Meta, and check whether I can make that change to ``ExportedProgram.state_dict``, simplifying the requirement to just the ``state_dict`` (which would be more ideal).
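To illustrate why the full model (not just the ``state_dict``) can be needed: in plain ``torch``, a tensor stored as an ordinary attribute, analogous to a lifted constant tensor, does not appear in ``state_dict()``, while parameters and registered buffers do. This is a standalone illustration, not the exporter's actual extraction code.

```python
import torch

class WithConstant(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(3, 3)               # parameters: tracked in state_dict
        self.register_buffer("scale", torch.ones(3))  # buffer: tracked in state_dict
        self.const = torch.tensor([1.0, 2.0, 3.0])    # plain attribute: NOT tracked

    def forward(self, x):
        return self.fc(x) * self.scale + self.const

m = WithConstant()
keys = set(m.state_dict().keys())
print(sorted(keys))
# Parameters and the buffer appear, but the plain tensor attribute does not,
# so recovering it requires access to the module object itself.
assert "scale" in keys and "const" not in keys
```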
_fake_context: Final[Optional[ONNXFakeContext]]
_export_exception: Final[Optional[Exception]]
_model_signature: Final[Optional[torch.export.ExportGraphSignature]]
_model_torch: Final[
mark as experimental?
Tracked by #115461
🚢 to unblock CI and experiments. Let's keep exploring and revisit #115461
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…orch#114762) Fixed by pytorch#113982 Pull Request resolved: pytorch#114762 Approved by: https://github.com/BowenBao ghstack dependencies: pytorch#114407, pytorch#115281
…ytorch#115353) Pull Request resolved: pytorch#115353 Approved by: https://github.com/BowenBao ghstack dependencies: pytorch#114407, pytorch#115281, pytorch#114762
(#115281) (#115583) Pull Request resolved: #115281 Approved by: https://github.com/BowenBao ghstack dependencies: #114407
Stack from ghstack (oldest at bottom):
Currently (after #114407), the user must pass the original user ``model`` to APIs such as ``ONNXProgram.__call__``, ``ONNXProgram.adapt_torch_inputs_to_onnx``, and ``ONNXProgram.adapt_torch_outputs_to_onnx``.
This was needed because when the model is fakefied, a non-fakefied version of the model is required so that the initializers, buffers, and constants can be extracted from a real model (and used as input to the ONNX model).
That approach places an unnecessary usability burden on the user when the model is not fakefied, because the model already passed to ``torch.onnx.dynamo_export`` could be used to extract the ``state_dict``.
This PR adds the ``ONNXProgram._model_torch`` attribute to store the user model and demotes the ``model`` argument of the aforementioned APIs to optional (as opposed to required).
As a result, in the fakefied scenario the user still needs to pass the real model, but for non-fakefied models the persisted model is used implicitly to extract the model ``state_dict``, making the APIs easier to use.
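For the fakefied path described above, export happens under fake mode, so the exported program holds no real weights and a real model must still be supplied when adapting inputs. A hedged sketch of that flow, with the export calls commented out since they require the exporter dependencies (e.g. ``onnxscript``); ``TinyModel`` is a hypothetical stand-in:

```python
import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Sketch of the fake-mode flow (commented out; requires the ONNX exporter deps):
# with torch.onnx.enable_fake_mode() as fake_context:
#     fake_model = TinyModel()          # weights are fake tensors, no real storage
#     fake_x = torch.randn(1, 4)
#     export_options = torch.onnx.ExportOptions(fake_context=fake_context)
#     onnx_program = torch.onnx.dynamo_export(fake_model, fake_x,
#                                             export_options=export_options)
#
# real_model = TinyModel()              # a real, materialized model
# real_x = torch.randn(1, 4)
# Even after this PR, the fakefied case still needs the real model explicitly:
#   onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(
#       real_x, model_with_state_dict=real_model)

print(TinyModel()(torch.randn(1, 4)).shape)  # the real model runs eagerly
```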