[pytree] traverse dict in sorted key ordering
#114947
Fixes #114392

Python `dict` and `defaultdict` do not take the order of keys into account when comparing two dictionaries.

```python
In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, `dict` and `defaultdict` nodes are traversed in insertion order. This means that if two equal `dict`s have the same keys inserted in different orders, the resulting leaves differ:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

We also get different `TreeSpec` objects, because the context of a `dict` node's `TreeSpec` is a list of keys in insertion order, and comparing two lists of keys takes the order of elements into account.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: TreeSpec(dict, ['a', 'b'], [*, *])

In [11]: spec2
Out[11]: TreeSpec(dict, ['b', 'a'], [*, *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

Not only do the contexts differ; the order of the `children_specs` is affected as well.

------

This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changes the context for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consisting of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order, but with all values set to `None`. This is used to preserve the original insertion order during unflattening.

Some notes on the traversal order for the `dict` type:

1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - It's intuitive.
        - No overhead for sorting.
        - Does not require the keys to be sortable.
        - Preserves the key order for `unflatten(flatten(dict))`.
    - Cons:
        - Does not guarantee that equal `dict`s get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`).
2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order.
    - Pros:
        - Guarantees that equal `dict`s get equal leaves and equal treespecs.
    - Cons:
        - It's not intuitive. Needs documentation.
        - Has a non-zero overhead for sorting.
        - Requires the keys to be sortable.
        - Does not preserve the key order for `unflatten(flatten(dict))`.
3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`.
    - Pros:
        - Guarantees that equal `dict`s get equal leaves and equal treespecs.
        - Preserves the key order for `unflatten(flatten(dict))`.
    - Cons:
        - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it.
        - Has a non-zero overhead for sorting.

cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4
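The new context layout described above can be illustrated with a small, self-contained sketch. This is not the code from the PR: `dict_flatten` and `dict_unflatten` are hypothetical helper names, and the sketch assumes all keys are mutually sortable. It only shows how a two-element context (sorted keys plus a `None`-valued dictionary in insertion order) gives equal leaves and equal contexts for equal dictionaries while still restoring the original key order on unflatten.

```python
from typing import Any, Dict, List, Tuple


def dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], List[Any]]:
    """Hypothetical helper: flatten a dict in sorted key order."""
    sorted_keys = sorted(d)                # assumes the keys are sortable
    values = [d[k] for k in sorted_keys]   # leaves follow sorted key order
    original_order = dict.fromkeys(d)      # same keys, insertion order, all values None
    context = [sorted_keys, original_order]
    return values, context


def dict_unflatten(values: List[Any], context: List[Any]) -> Dict[Any, Any]:
    """Hypothetical helper: rebuild the dict in its original insertion order."""
    sorted_keys, original_order = context
    result = dict(original_order)          # copy so the context is not mutated
    result.update(zip(sorted_keys, values))  # updating existing keys keeps their position
    return result


d1 = {'a': 1, 'b': 2}
d2 = {'b': 2, 'a': 1}

leaves1, ctx1 = dict_flatten(d1)
leaves2, ctx2 = dict_flatten(d2)

print(leaves1 == leaves2)  # equal dicts now give equal leaves
print(ctx1 == ctx2)        # dict equality ignores order, so the contexts compare equal
print(list(dict_unflatten(leaves2, ctx2)))  # key order of d2 is restored
```

Because the second context element is a plain `dict`, comparing two contexts ignores insertion order, while unflattening can still iterate it to recover that order.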
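Two of the cons listed for sorted traversal can be reproduced with plain Python, without any pytree library. The snippet below is only an illustration under the assumption that traversal uses the built-in `sorted`; how the PR itself treats unsortable keys is not stated here.

```python
# Con 1: leaves in sorted key order no longer line up with d.values().
d = {'b': 2, 'a': 1}
leaves_sorted = [d[k] for k in sorted(d)]   # what a sorted traversal would yield
values_insertion = list(d.values())         # what d.values() yields

print(leaves_sorted)
print(values_insertion)

# Con 2: sorting requires mutually comparable keys.
mixed = {1: 'x', 'a': 'y'}
try:
    sorted(mixed)
    sortable = True
except TypeError:
    # int and str keys cannot be ordered against each other in Python 3
    sortable = False

print(sortable)
```

Code that zips `d.values()` with the output of a flatten call would pair up the wrong elements for `d` above, which is the situation the optree note warns about.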
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Stack from ghstack (oldest at bottom):

- [pytree] traverse `dict` in sorted key ordering #114947
- … `context` and `children_specs` as private implementation details #116375
- … `children_specs` access #116374

Fixes #114392
- … `dict`s do not imply equal leaves and equal treespecs #114392

cc @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @penguinwu