[ONNX] Support None in fx.args as torchlib inputs by titaiwangms · Pull Request #108708 · pytorch/pytorch · GitHub

Conversation

@titaiwangms (Collaborator) commented Sep 6, 2023

Stack from ghstack (oldest at bottom):

Prior to this PR, if None is returned from an intermediate node, it crashes the export: None is not expected to be passed into `_fill_tensor_shape_type`, so beartype raises an error. That function fills in the shape and type of a TorchScriptTensor according to the node's info from the FX graph.

This was discovered after microsoft/onnxscript#1043 was supported. The op specifically generates None in one of its inputs, but the only output from it that is consumed is the first one (not None).
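In effect, the fix is to skip None entries when filling in shape/type info. A minimal sketch of the idea (the name, signature, and loop here are illustrative only, not the exporter's exact code):

```python
from typing import Optional, Sequence

import torch


def _fill_tensor_shape_type_sketch(
    onnx_values: Sequence[object],
    expected_values: Sequence[Optional[torch.Tensor]],
) -> None:
    # Illustrative only: propagate shape/dtype from FX node info onto the
    # corresponding ONNX (TorchScriptTensor) values.
    for onnx_value, expected_value in zip(onnx_values, expected_values):
        if expected_value is None:
            # There is no shape/type to propagate from None (e.g. an unused
            # output of a multi-output op), so skip it instead of letting
            # beartype reject the unexpected None.
            continue
        # ... set onnx_value's shape and dtype from expected_value as before ...
```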

Reference test from a TorchBench model:

```python
def test_nanogpt(self):
    import sys

    sys.path.append("/home/titaiwang")

    from nanoGPT.model import GPT, GPTConfig

    # Load the model
    kwargs = {
        "block_size": 256,
        "vocab_size": 8096,  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
        "n_layer": 2,
        "n_head": 2,
        "n_embd": 128,
        "dropout": 0.0,
        "bias": False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    }
    config = GPTConfig(**kwargs)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_mem_efficient=True
    ):
        model = GPT(config)
    print("Done loading model")
    inputs = torch.arange(128).view(2, 64)
    targets = torch.arange(128).view(2, 64)

    self.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
        model,
        (inputs,),
        input_kwargs={
            "targets": targets,
        },
        verbose=True,
    )
```

@pytorch-bot added the `release notes: onnx` label Sep 6, 2023
@pytorch-bot commented Sep 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108708

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit c961589 with merge base bde75eb:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness on trunk and have been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

titaiwangms added a commit that referenced this pull request Sep 6, 2023
ghstack-source-id: efca75f
Pull Request resolved: #108708
@titaiwangms marked this pull request as draft September 6, 2023 22:05
@titaiwangms added the `module: onnx` and `topic: improvements` labels Sep 11, 2023
titaiwangms added a commit that referenced this pull request Sep 12, 2023
ghstack-source-id: 60961c6
Pull Request resolved: #108708
@titaiwangms marked this pull request as ready for review September 12, 2023 20:33
@titaiwangms added the `ciflow/trunk` label Sep 12, 2023
@justinchuby self-assigned this Sep 12, 2023
@justinchuby (Collaborator):

> The op specifically generates None in one of its inputs, but the only output from it that is consumed is the first one (not None).

Do you mean an input or an output is None? If an output, which one?

```python
# TODO(titaiwang): set shape?
if isinstance(expected_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
    ...
if expected_value is None:
    # There is no shape/type from None.
    ...
```
Collaborator:

I would link to the example you showed, so readers know when this will happen.

Collaborator:

Would it be possible for us to assume it is always a scalar? Or could there be other cases?

Collaborator Author:

Hmm, it all depends on what fx.node gives us; this is purely a product of the FX graph. So I feel like what we should do is just take None into consideration in our code.
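For example, a defensive pattern along these lines (just a sketch, not the exporter's actual code; FX tracing commonly records the example value under `node.meta["val"]`, and for multi-output ops elements of it can be None):

```python
import torch


def shape_dtype_or_none(node: torch.fx.Node):
    # Sketch: look up the example value FX tracing attached to the node.
    val = node.meta.get("val")
    if val is None:
        # Nothing to propagate; the caller should skip this entry instead of
        # passing None into shape/type filling (which beartype would reject).
        return None
    if isinstance(val, torch.Tensor):
        return tuple(val.shape), val.dtype
    # SymInt/SymFloat/SymBool and other scalars carry no tensor shape here.
    return None
```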

Collaborator Author:

My guess is that on CPU those outputs are useless or not generated, so it returns None.

titaiwangms added a commit that referenced this pull request Sep 12, 2023
ghstack-source-id: 35b1cb7
Pull Request resolved: #108708
@titaiwangms (Collaborator Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot deleted the gh/titaiwangms/46/head branch September 16, 2023 14:23