KEMBAR78
Fix to() on non-contiguous NJTs by jbschlosser · Pull Request #137124 · pytorch/pytorch · GitHub
Skip to content

Conversation

@jbschlosser
Copy link
Contributor

@jbschlosser jbschlosser commented Oct 1, 2024

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/137124

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit cb7ea56 with merge base 0ccd39a (image):

NEW FAILURES - The following jobs have failed:

  • linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test (gh)
    RuntimeError: recursive_directory_iterator in used pre-CXX11 binaries, see; ['std::filesystem::recursive_directory_iterator::recursion_pending() const', 'std::filesystem::recursive_directory_iterator::depth() const', 'std::filesystem::recursive_directory_iterator::options() const', 'std::filesystem::recursive_directory_iterator::operator*() const', 'std::filesystem::recursive_directory_iterator::disable_recursion_pending()', 'std::filesystem::recursive_directory_iterator::pop(std::error_code&)', 'std::filesystem::recursive_directory_iterator::pop()', 'std::filesystem::recursive_directory_iterator::pop() [clone .cold]', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&)', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&) [clone .cold]', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator&&)', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator const&)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*) [clone .cold]', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::operator++()', 'std::filesystem::recursive_directory_iterator::operator++() [clone .cold]']
  • linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test (gh)
    RuntimeError: recursive_directory_iterator in used pre-CXX11 binaries, see; ['std::filesystem::recursive_directory_iterator::recursion_pending() const', 'std::filesystem::recursive_directory_iterator::depth() const', 'std::filesystem::recursive_directory_iterator::options() const', 'std::filesystem::recursive_directory_iterator::operator*() const', 'std::filesystem::recursive_directory_iterator::disable_recursion_pending()', 'std::filesystem::recursive_directory_iterator::pop(std::error_code&)', 'std::filesystem::recursive_directory_iterator::pop()', 'std::filesystem::recursive_directory_iterator::pop() [clone .cold]', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&)', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&) [clone .cold]', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator&&)', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator const&)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*) [clone .cold]', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::operator++()', 'std::filesystem::recursive_directory_iterator::operator++() [clone .cold]']
  • linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test (gh)
    RuntimeError: recursive_directory_iterator in used pre-CXX11 binaries, see; ['std::filesystem::recursive_directory_iterator::recursion_pending() const', 'std::filesystem::recursive_directory_iterator::depth() const', 'std::filesystem::recursive_directory_iterator::options() const', 'std::filesystem::recursive_directory_iterator::operator*() const', 'std::filesystem::recursive_directory_iterator::disable_recursion_pending()', 'std::filesystem::recursive_directory_iterator::pop(std::error_code&)', 'std::filesystem::recursive_directory_iterator::pop()', 'std::filesystem::recursive_directory_iterator::pop() [clone .cold]', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&)', 'std::filesystem::recursive_directory_iterator::increment(std::error_code&) [clone .cold]', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator&&)', 'std::filesystem::recursive_directory_iterator::operator=(std::filesystem::recursive_directory_iterator const&)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*)', 'std::filesystem::recursive_directory_iterator::recursive_directory_iterator(std::filesystem::path const&, std::filesystem::directory_options, std::error_code*) [clone .cold]', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::~recursive_directory_iterator()', 'std::filesystem::recursive_directory_iterator::operator++()', 'std::filesystem::recursive_directory_iterator::operator++() [clone .cold]']

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
Called out via torchrec integration: `lengths` is not handled properly.

Future work (not related to non-contiguous NJTs): debug torch.compile problem; new nested int is allocated only for compile.

```python
import torch

def f(nt):
    return nt.to(device="cpu")

compiled_f = torch.compile(f)

nt = torch.nested.nested_tensor([
    torch.randn(2, 5),
    torch.randn(3, 5),
    torch.randn(4, 5),
], layout=torch.jagged, device="cuda")

out = f(nt)
out_compile = compiled_f(nt)
print(out.shape, out_compile.shape)
```

```
AssertionError: The values for attribute 'shape' do not match: torch.Size([7, j2]) != torch.Size([7, j1])
```

[ghstack-poisoned]
@jbschlosser jbschlosser added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 3, 2024
Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jbschlosser
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test

Details for Dev Infra team Raised by workflow job

@jbschlosser
Copy link
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: linux-binary-manywheel / manywheel-py3_9-cuda12_4-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_1-test / test, linux-binary-manywheel / manywheel-py3_9-cuda11_8-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/jbschlosser/185/head branch November 8, 2024 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: nested tensor Changes that have a direct impact on nested tensors topic: improvements topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants