-
Notifications
You must be signed in to change notification settings - Fork 21.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pytree] traverse dict
in sorted key ordering
#114947
Draft
XuehaiPan
wants to merge
60
commits into
gh/XuehaiPan/17/base
Choose a base branch
from
gh/XuehaiPan/17/head
base: gh/XuehaiPan/17/base
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
+83
−21
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[ghstack-poisoned]
XuehaiPan
added a commit
that referenced
this pull request
Dec 1, 2023
ghstack-source-id: 0ecd2f2ec0baf9b5a90d7b810876fff9a37f101b Pull Request resolved: #114947
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
XuehaiPan
added a commit
that referenced
this pull request
Dec 1, 2023
ghstack-source-id: f904d68a47e8a091afa13d26116cd21c16592869 Pull Request resolved: #114947
XuehaiPan
commented
Dec 1, 2023
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
This was referenced Dec 2, 2023
Closed
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts are differ, the order of the `children_specs` is also related. This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. cc zou3519 [ghstack-poisoned]
XuehaiPan
commented
Dec 2, 2023
XuehaiPan
added a commit
that referenced
this pull request
Jan 18, 2024
ghstack-source-id: b588d9835d827886846f37eb1375eaaae8bfbddb Pull Request resolved: #114947
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts differ, but the order of the `children_specs` is also related. ------ This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening. Some notes of the traversal order for `dict` type: 1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - It's intuitive. - Do not have overhead for sorting. - Do not require the keys to be sortable. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`). 2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Cons: - It's not intuitive. Need documentation. - Have a non-zero overhead for sorting. - Require the keys to be sortable. - Do not preserve the key order for `unflatten(flatten(dict))`. 3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it. - Have a non-zero overhead for sorting. cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 [ghstack-poisoned]
XuehaiPan
added a commit
that referenced
this pull request
Jan 20, 2024
ghstack-source-id: 07ebc925efc5156685f8d7bd99c51cf717b81014 Pull Request resolved: #114947
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts differ, but the order of the `children_specs` is also related. ------ This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening. Some notes of the traversal order for `dict` type: 1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - It's intuitive. - Do not have overhead for sorting. - Do not require the keys to be sortable. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`). 2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Cons: - It's not intuitive. Need documentation. - Have a non-zero overhead for sorting. - Require the keys to be sortable. - Do not preserve the key order for `unflatten(flatten(dict))`. 3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it. - Have a non-zero overhead for sorting. cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 [ghstack-poisoned]
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts differ, but the order of the `children_specs` is also related. ------ This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening. Some notes of the traversal order for `dict` type: 1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - It's intuitive. - Do not have overhead for sorting. - Do not require the keys to be sortable. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`). 2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Cons: - It's not intuitive. Need documentation. - Have a non-zero overhead for sorting. - Require the keys to be sortable. - Do not preserve the key order for `unflatten(flatten(dict))`. 3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it. - Have a non-zero overhead for sorting. cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 [ghstack-poisoned]
XuehaiPan
added a commit
that referenced
this pull request
Jan 27, 2024
ghstack-source-id: bb2e088ae2e68bd6f064bf435ce790c25e7e7aee Pull Request resolved: #114947
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts differ, but the order of the `children_specs` is also related. ------ This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening. Some notes of the traversal order for `dict` type: 1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - It's intuitive. - Do not have overhead for sorting. - Do not require the keys to be sortable. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`). 2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Cons: - It's not intuitive. Need documentation. - Have a non-zero overhead for sorting. - Require the keys to be sortable. - Do not preserve the key order for `unflatten(flatten(dict))`. 3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it. - Have a non-zero overhead for sorting. cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 [ghstack-poisoned]
XuehaiPan
added a commit
that referenced
this pull request
Jan 31, 2024
ghstack-source-id: 90afb190a27bf0913f35d6750155447c5dc1b7c8 Pull Request resolved: #114947
Fixes #114392 - #114392 Python `dict` and `defaultdict` do not take the order of keys into account while comparing two dictionaries. ```python In [1]: from collections import * In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1} Out[2]: True In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1}) Out[3]: True In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1}) Out[4]: False ``` Before this PR, the traversing order of the `dict` and `defaultdict` nodes are in insertion order. This means if two equal `dict`s have the same keys but inserted in different order, the result leaves are different: ```python In [5]: import torch.utils._pytree as pytree In [6]: pytree.tree_leaves({'a': 1, 'b': 2}) Out[6]: [1, 2] In [7]: pytree.tree_leaves({'b': 2, 'a': 1}) Out[7]: [2, 1] ``` Also we will get different `TreeSpec` objects because the context of the `TreeSpec` of a `dict` node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements. ```python In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2}) In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1}) In [10]: spec1 Out[10]: TreeSpec(dict, ['a', 'b'], [*, *]) In [11]: spec2 Out[11]: TreeSpec(dict, ['b', 'a'], [*, *]) In [12]: spec1 == spec2 Out[12]: False In [13]: spec1.context Out[13]: ['a', 'b'] In [14]: spec2.context Out[14]: ['b', 'a'] In [15]: spec1.context == spec2.context Out[15]: False ``` Not only contexts differ, but the order of the `children_specs` is also related. ------ This PR makes the traversal order of `dict` / `defaultdict` follow the sorted key order. This makes the behavior of `dict` traversal consistent with optree and JAX pytree. It also changed the context for a dictionary: - old context: a list of keys in insertion order - new context: a list consists of two elements: 1. a list of sorted keys 2. a dictionary with the original keys in insertion order, but all values are `None`. This is used to preserve the original insertion order while doing unflattening. Some notes of the traversal order for `dict` type: 1. PyTorch before this PR: traverse `dict` in insertion order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - It's intuitive. - Do not have overhead for sorting. - Do not require the keys to be sortable. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - Do not guarantee equal `dict` get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (`**kwargs`). 2. JAX pytree: traverse `dict` in sorted order. Unflatten the `dict` back in sorted order rather than the original insertion order. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Cons: - It's not intuitive. Need documentation. - Have a non-zero overhead for sorting. - Require the keys to be sortable. - Do not preserve the key order for `unflatten(flatten(dict))`. 3. optree: traverse `dict` in sorted order. Preserve the key order for `unflatten(flatten(dict))`. - Pros: - Guarantee equal `dict` get equal leaves and equal treespecs. - Preserve the key order for `unflatten(flatten(dict))`. - Cons: - It's not intuitive if users only use `tree_flatten` or combine `d.values()` with `tree_flatten(d)`. No concern about `tree_map` because we will do `tree_unflatten` in it. - Have a non-zero overhead for sorting. cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4 [ghstack-poisoned]
XuehaiPan
added a commit
that referenced
this pull request
Feb 14, 2024
ghstack-source-id: 99f54d2a525a318ab4a7ce00f35e590ddb5139f0 Pull Request resolved: #114947
XuehaiPan
added a commit
that referenced
this pull request
Mar 10, 2024
ghstack-source-id: 6eb084ea9fbaceac644cdb2035a6ff311dbc5046 Pull Request resolved: #114947
XuehaiPan
added a commit
that referenced
this pull request
Apr 21, 2024
ghstack-source-id: 0d68e27c2105767e8f28100058a8caf8f62e91cc Pull Request resolved: #114947
XuehaiPan
added a commit
that referenced
this pull request
May 22, 2024
ghstack-source-id: 4abe7a89ac4c049431d81f0d7093e16dede3307f Pull Request resolved: #114947
XuehaiPan
added a commit
that referenced
this pull request
Jun 21, 2024
ghstack-source-id: fdffd2ca0048f02e6ea77c0091c0dd88458d720f Pull Request resolved: #114947
XuehaiPan
added a commit
that referenced
this pull request
Jun 22, 2024
ghstack-source-id: 4f38eac3dcc4d1e6691a425f9b2b91708451cc6a Pull Request resolved: #114947
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
ciflow/inductor
ciflow/trunk
Trigger trunk jobs on your pull request
keep-going
Don't stop on first failure, keep running tests until the end
module: pytree
oncall: export
open source
release notes: fx
release notes category
topic: bc breaking
topic category
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Stack from ghstack (oldest at bottom):
dict
in sorted key ordering #114947context
andchildren_specs
as private implementation details #116375children_specs
access #116374Fixes #114392
dict
s do not imply equal leaves and equal treespecs #114392Python
dict
anddefaultdict
do not take the order of keys into account while comparing two dictionaries.Before this PR, the traversing order of the
dict
anddefaultdict
nodes are in insertion order. This means if two equaldict
s have the same keys but inserted in different order, the result leaves are different:Also we will get different
TreeSpec
objects because the context of theTreeSpec
of adict
node is a list of keys in insertion order. But comparing between these two lists of keys will consider the order of elements.Not only contexts differ, but the order of the
children_specs
is also related.This PR makes the traversal order of
dict
/defaultdict
follow the sorted key order. This makes the behavior ofdict
traversal consistent with optree and JAX pytree. It also changed the context for a dictionary:None
. This is used to preserve the original insertion order while doing unflattening.Some notes of the traversal order for
dict
type:dict
in insertion order. Preserve the key order forunflatten(flatten(dict))
.unflatten(flatten(dict))
.dict
get equal leaves and equal treespecs. Might be bad for flattening function keyword arguments (**kwargs
).dict
in sorted order. Unflatten thedict
back in sorted order rather than the original insertion order.dict
get equal leaves and equal treespecs.unflatten(flatten(dict))
.dict
in sorted order. Preserve the key order forunflatten(flatten(dict))
.dict
get equal leaves and equal treespecs.unflatten(flatten(dict))
.tree_flatten
or combined.values()
withtree_flatten(d)
. No concern abouttree_map
because we will dotree_unflatten
in it.cc @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4