KEMBAR78
[pytree] traverse `dict` in sorted key ordering by XuehaiPan · Pull Request #114947 · pytorch/pytorch · GitHub
Skip to content

Conversation

@XuehaiPan
Copy link
Collaborator

@XuehaiPan XuehaiPan commented Dec 1, 2023

Stack from ghstack (oldest at bottom):

Fixes #114392

Python dict and defaultdict do not take the order of keys into account while comparing two dictionaries.

In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False

Before this PR, the traversing order of the dict and defaultdict nodes are in insertion order. This means if two equal dicts have the same keys but inserted in different order, the result leaves are different:

In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]

Also we will get different TreeSpec objects because the context of the TreeSpec of a dict node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False

Not only contexts differ, but the order of the children_specs is also related.


This PR makes the traversal order of dict / defaultdict follow the sorted key order. This makes the behavior of dict traversal consistent with optree and JAX pytree. It also changed the context for a dictionary:

  • old context: a list of keys in insertion order
  • new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are None. This is used to preserve the original insertion order while doing unflattening.

Some notes of the traversal order for dict type:

  1. PyTorch before this PR: traverse dict in insertion order. Preserve the key order for unflatten(flatten(dict)).
    • Pros:
      • It's intuitive.
      • Do not have overhead for sorting.
      • Do not require the keys to be sortable.
      • Preserve the key order for unflatten(flatten(dict)).
    • Cons:
      • Do not guarantee equal dict get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (**kwargs).
  2. JAX pytree: traverse dict in sorted order. Unflatten the dict back in sorted order rather than the original insertion order.
    • Pros:
      • Guarantee equal dict get equal leaves and equal treespecs.
    • Cons:
      • It's not intuitive. Need documentation.
      • Have a non-zero overhead for sorting.
      • Require the keys to be sortable.
      • Do not preserve the key order for unflatten(flatten(dict)).
  3. optree: traverse dict in sorted order. Preserve the key order for unflatten(flatten(dict)).
    • Pros:
      • Guarantee equal dict get equal leaves and equal treespecs.
      • Preserve the key order for unflatten(flatten(dict)).
    • Cons:
      • It's not intuitive if users only use tree_flatten or combine d.values() with tree_flatten(d). No concern about tree_map because we will do tree_unflatten in it.
      • Have a non-zero overhead for sorting.

cc @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @penguinwu

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114947

Note: Links to docs will display an error until the docs builds have been completed.

❌ 67 New Failures, 45 Unrelated Failures

As of commit 9158851 with merge base 92ca17d (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XuehaiPan added a commit that referenced this pull request Dec 1, 2023
ghstack-source-id: 0ecd2f2
Pull Request resolved: #114947
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Dec 1, 2023
ghstack-source-id: f904d68
Pull Request resolved: #114947
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts are differ, the order of the `children_specs` is also related.

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`.

cc zou3519

[ghstack-poisoned]
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts differ, but the order of the `children_specs` is also related.

------

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening.

Some notes of the traversal order for `dict` type:

1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - It's intuitive.
        - Do not have overhead for sorting.
        - Do not require the keys to be sortable.
        - Preserve the key order for `unflatten(flatten(dict))`.
    - Cons:
        - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`).
2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order.
    - Pros:
        - Guarantee equal `dict` get equal leaves and equal treespecs.
    - Cons:
        - It's not intuitive. Need documentation.
        - Have a non-zero overhead for sorting.
        - Require the keys to be sortable.
        - Do not preserve the key order for `unflatten(flatten(dict))`.
3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - Guarantee equal `dict` get equal leaves and equal treespecs.
        - Preserve the key order for `unflatten(flatten(dict))`.
    - Cons:
        - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it.
        - Have a non-zero overhead for sorting.

cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4

[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Jan 27, 2024
ghstack-source-id: bb2e088
Pull Request resolved: #114947
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts differ, but the order of the `children_specs` is also related.

------

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening.

Some notes of the traversal order for `dict` type:

1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - It's intuitive.
        - Do not have overhead for sorting.
        - Do not require the keys to be sortable.
        - Preserve the key order for `unflatten(flatten(dict))`.
    - Cons:
        - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`).
2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order.
    - Pros:
        - Guarantee equal `dict` get equal leaves and equal treespecs.
    - Cons:
        - It's not intuitive. Need documentation.
        - Have a non-zero overhead for sorting.
        - Require the keys to be sortable.
        - Do not preserve the key order for `unflatten(flatten(dict))`.
3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - Guarantee equal `dict` get equal leaves and equal treespecs.
        - Preserve the key order for `unflatten(flatten(dict))`.
    - Cons:
        - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it.
        - Have a non-zero overhead for sorting.

cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4

[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Jan 31, 2024
ghstack-source-id: 90afb19
Pull Request resolved: #114947
Fixes #114392

- #114392

Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only contexts differ, but the order of the `children_specs` is also related.

------

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consists of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening.

Some notes of the traversal order for `dict` type:

1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - It's intuitive.
        - Do not have overhead for sorting.
        - Do not require the keys to be sortable.
        - Preserve the key order for `unflatten(flatten(dict))`.
    - Cons:
        - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`).
2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order.
    - Pros:
        - Guarantee equal `dict` get equal leaves and equal treespecs.
    - Cons:
        - It's not intuitive. Need documentation.
        - Have a non-zero overhead for sorting.
        - Require the keys to be sortable.
        - Do not preserve the key order for `unflatten(flatten(dict))`.
3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - Guarantee equal `dict` get equal leaves and equal treespecs.
        - Preserve the key order for `unflatten(flatten(dict))`.
    - Cons:
        - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it.
        - Have a non-zero overhead for sorting.

cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4

[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Feb 14, 2024
ghstack-source-id: 99f54d2
Pull Request resolved: #114947
[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Mar 10, 2024
ghstack-source-id: 6eb084e
Pull Request resolved: #114947
[ghstack-poisoned]
[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Apr 21, 2024
ghstack-source-id: 0d68e27
Pull Request resolved: #114947
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request May 22, 2024
ghstack-source-id: 4abe7a8
Pull Request resolved: #114947
[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Jun 21, 2024
ghstack-source-id: fdffd2c
Pull Request resolved: #114947
[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Jun 22, 2024
ghstack-source-id: 4f38eac
Pull Request resolved: #114947
@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Aug 21, 2024
@github-actions github-actions bot closed this Sep 20, 2024
@github-actions github-actions bot deleted the gh/XuehaiPan/17/head branch October 21, 2024 02:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end module: pytree oncall: export open source release notes: fx release notes category Stale topic: bc breaking topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG][pytree] equal dicts do not imply equal leaves and equal treespecs

4 participants