[pytree] traverse dict in sorted key ordering #114947

Draft
wants to merge 60 commits into
base: gh/XuehaiPan/17/base

XuehaiPan
Collaborator

@XuehaiPan XuehaiPan commented Dec 1, 2023

Stack from ghstack (oldest at bottom):

Fixes #114392

Python's dict and defaultdict ignore key order when two dictionaries are compared for equality, whereas OrderedDict does not:

In [1]: from collections import *

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False

Before this PR, dict and defaultdict nodes are traversed in insertion order. This means that two equal dicts whose keys were inserted in different orders produce their leaves in different orders:

In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]

We also get different TreeSpec objects, because the context of a dict node's TreeSpec is a list of keys in insertion order, and comparing two lists of keys takes the order of elements into account.

In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False

Not only do the contexts differ: the order of the children_specs follows the key order as well, so it differs too.
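To make the practical impact concrete, here is a minimal pure-Python sketch of the pre-PR behavior applied to keyword arguments. `flatten_insertion_order` and `capture_kwargs` are hypothetical helpers written for illustration only; they are not part of `torch.utils._pytree` and only model a single dict level.

```python
def flatten_insertion_order(d):
    """Toy model of the pre-PR behavior: leaves and context follow insertion order."""
    context = list(d)                  # keys in insertion order
    leaves = [d[k] for k in context]   # values in the same order
    return leaves, context


def capture_kwargs(**kwargs):
    # **kwargs preserves the order in which keywords appear at the call site.
    return kwargs


kw1 = capture_kwargs(a=1, b=2)
kw2 = capture_kwargs(b=2, a=1)

print(kw1 == kw2)                    # True: the two calls pass the same arguments
print(flatten_insertion_order(kw1))  # ([1, 2], ['a', 'b'])
print(flatten_insertion_order(kw2))  # ([2, 1], ['b', 'a'])
```

Two calls that differ only in keyword order therefore flatten to different leaves and different contexts under this model, which is the situation the PR sets out to avoid.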


This PR makes the traversal order of dict / defaultdict follow the sorted key order, which makes dict traversal consistent with optree and JAX pytree. It also changes the context for a dictionary:

  • old context: a list of keys in insertion order
  • new context: a list consisting of two elements:
    1. a list of sorted keys
    2. a dictionary holding the original keys in insertion order, with every value set to None. It is used to restore the original insertion order when unflattening.
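The two-element context can be sketched in plain Python as follows. This is an illustration of the idea under stated assumptions, not the code in this PR: `flatten_dict` and `unflatten_dict` are hypothetical names, they handle a single dict level only, and they assume the keys can be ordered with `sorted()`.

```python
def flatten_dict(d):
    """Flatten one dict level in sorted key order.

    Returns (leaves, context), where context mirrors the layout described
    above: [sorted_keys, {original_key: None, ...}].
    """
    sorted_keys = sorted(d)
    leaves = [d[k] for k in sorted_keys]
    # dict.fromkeys keeps the insertion order of d and maps every key to None.
    context = [sorted_keys, dict.fromkeys(d)]
    return leaves, context


def unflatten_dict(leaves, context):
    """Rebuild the dict, restoring the original insertion order of the keys."""
    sorted_keys, original_order = context
    by_key = dict(zip(sorted_keys, leaves))
    return {k: by_key[k] for k in original_order}


d1 = {'b': 2, 'a': 1}
d2 = {'a': 1, 'b': 2}

leaves1, ctx1 = flatten_dict(d1)
leaves2, ctx2 = flatten_dict(d2)

print(leaves1 == leaves2)                   # True: both are [1, 2]
print(ctx1 == ctx2)                         # True: dict equality ignores key order
print(list(unflatten_dict(leaves1, ctx1)))  # ['b', 'a']
```

In this sketch, storing the original order as a dict rather than a list matters: the two contexts still compare equal because dict equality ignores key order, while iteration over the stored dict still recovers the insertion order.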

Some notes on the traversal order for the dict type:

  1. PyTorch before this PR: traverses dict in insertion order and preserves the key order for unflatten(flatten(dict)).
    • Pros:
      • It is intuitive.
      • No sorting overhead.
      • Keys do not need to be sortable.
      • Preserves the key order for unflatten(flatten(dict)).
    • Cons:
      • Does not guarantee that equal dicts yield equal leaves and equal treespecs. This can be a problem when flattening function keyword arguments (**kwargs).
  2. JAX pytree: traverses dict in sorted key order and unflattens the dict in sorted order rather than the original insertion order.
    • Pros:
      • Guarantees that equal dicts yield equal leaves and equal treespecs.
    • Cons:
      • It is not intuitive and needs documentation.
      • Sorting adds a non-zero overhead.
      • Requires the keys to be sortable.
      • Does not preserve the key order for unflatten(flatten(dict)).
  3. optree: traverses dict in sorted key order and preserves the key order for unflatten(flatten(dict)).
    • Pros:
      • Guarantees that equal dicts yield equal leaves and equal treespecs.
      • Preserves the key order for unflatten(flatten(dict)).
    • Cons:
      • It is not intuitive for users who only call tree_flatten, or who combine d.values() with tree_flatten(d). tree_map is not a concern because it calls tree_unflatten internally.
      • Sorting adds a non-zero overhead.
cc @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

pytorch-bot bot commented Dec 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114947

Note: Links to docs will display an error until the docs builds have been completed.

❌ 67 New Failures, 45 Unrelated Failures

As of commit 9158851 with merge base 92ca17d (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

XuehaiPan added a commit that referenced this pull request Dec 1, 2023
ghstack-source-id: 0ecd2f2ec0baf9b5a90d7b810876fff9a37f101b
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Dec 1, 2023
ghstack-source-id: f904d68a47e8a091afa13d26116cd21c16592869
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Jan 18, 2024
ghstack-source-id: b588d9835d827886846f37eb1375eaaae8bfbddb
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Jan 20, 2024
ghstack-source-id: 07ebc925efc5156685f8d7bd99c51cf717b81014
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Jan 27, 2024
ghstack-source-id: bb2e088ae2e68bd6f064bf435ce790c25e7e7aee
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Jan 31, 2024
ghstack-source-id: 90afb190a27bf0913f35d6750155447c5dc1b7c8
Pull Request resolved: #114947
Fixes #114392

- #114392

Python `dict` and `defaultdict` ignore key order when two dictionaries are compared for equality, whereas `OrderedDict` takes it into account.

```python
In [1]: from collections import OrderedDict, defaultdict

In [2]: {'a': 1, 'b': 2} == {'b': 2, 'a': 1}
Out[2]: True

In [3]: defaultdict(int, {'a': 1, 'b': 2}) == defaultdict(int, {'b': 2, 'a': 1})
Out[3]: True

In [4]: OrderedDict({'a': 1, 'b': 2}) == OrderedDict({'b': 2, 'a': 1})
Out[4]: False
```

Before this PR, `dict` and `defaultdict` nodes are traversed in insertion order. As a result, two equal `dict`s whose keys were inserted in different orders produce different leaves:

```python
In [5]: import torch.utils._pytree as pytree

In [6]: pytree.tree_leaves({'a': 1, 'b': 2})
Out[6]: [1, 2]

In [7]: pytree.tree_leaves({'b': 2, 'a': 1})
Out[7]: [2, 1]
```

We also get different `TreeSpec` objects, because the context of a `dict` node's `TreeSpec` is a list of keys in insertion order, and comparing two lists takes the order of their elements into account.

```python
In [8]: spec1 = pytree.tree_structure({'a': 1, 'b': 2})

In [9]: spec2 = pytree.tree_structure({'b': 2, 'a': 1})

In [10]: spec1
Out[10]: 
TreeSpec(dict, ['a', 'b'], [*,
  *])

In [11]: spec2
Out[11]: 
TreeSpec(dict, ['b', 'a'], [*,
  *])

In [12]: spec1 == spec2
Out[12]: False

In [13]: spec1.context
Out[13]: ['a', 'b']

In [14]: spec2.context
Out[14]: ['b', 'a']

In [15]: spec1.context == spec2.context
Out[15]: False
```

The contexts are not the only difference: the order of the `children_specs` also follows the key order.

------

This PR makes `dict` / `defaultdict` traversal follow sorted key order, which is consistent with optree and JAX pytree. It also changes the context stored for a dictionary:

- old context: a list of keys in insertion order
- new context: a list consisting of two elements:
    1. a list of sorted keys
    2. a dictionary with the original keys in insertion order and all values set to `None`. It is used to restore the original insertion order when unflattening.
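
The two-element context can be illustrated with a minimal plain-Python sketch. The helper names `flatten_dict` and `unflatten_dict` are hypothetical and exist only for this illustration; they are not the functions in `torch.utils._pytree`, and the sketch handles a single flat `dict` rather than a nested tree.

```python
# Hypothetical helpers for illustration only; not the PR's implementation.
def flatten_dict(d):
    """Return (leaves, context) with the leaves in sorted key order."""
    sorted_keys = sorted(d)
    leaves = [d[k] for k in sorted_keys]
    # Context: [sorted keys, {original key: None} in insertion order].
    context = [sorted_keys, dict.fromkeys(d)]
    return leaves, context


def unflatten_dict(leaves, context):
    """Rebuild the dict, restoring the original insertion order."""
    sorted_keys, original_order = context
    result = dict(original_order)  # keys in insertion order, values None
    # Updating existing keys does not change their position in the dict.
    result.update(zip(sorted_keys, leaves))
    return result


d1 = {"b": 2, "a": 1}
d2 = {"a": 1, "b": 2}

leaves1, ctx1 = flatten_dict(d1)
leaves2, ctx2 = flatten_dict(d2)

# Equal dicts give equal leaves and equal contexts, because the second
# context element is a dict and dict equality ignores key order.
assert leaves1 == leaves2
assert ctx1 == ctx2

# The round trip still restores each dict's own insertion order.
assert list(unflatten_dict(leaves1, ctx1)) == ["b", "a"]
assert list(unflatten_dict(leaves2, ctx2)) == ["a", "b"]
```

Storing the insertion order as a dictionary rather than a second list is what keeps the contexts of equal dicts equal in this sketch.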

Some notes on the traversal order for the `dict` type:

1. PyTorch before this PR: traverses `dict` in insertion order and preserves the key order for `unflatten(flatten(dict))`.
    - Pros:
        - Intuitive.
        - No overhead for sorting.
        - Does not require the keys to be sortable.
        - Preserves the key order for `unflatten(flatten(dict))`.
    - Cons:
        - Does not guarantee that equal `dict`s produce equal leaves and equal treespecs. This might be bad for flattening function keyword arguments (`**kwargs`).
2. JAX pytree: traverses `dict` in sorted key order and unflattens the `dict` in sorted order rather than the original insertion order.
    - Pros:
        - Guarantees that equal `dict`s produce equal leaves and equal treespecs.
    - Cons:
        - Not intuitive; needs documentation.
        - Has a non-zero overhead for sorting.
        - Requires the keys to be sortable.
        - Does not preserve the key order for `unflatten(flatten(dict))`.
3. optree: traverses `dict` in sorted key order and preserves the key order for `unflatten(flatten(dict))`.
    - Pros:
        - Guarantees that equal `dict`s produce equal leaves and equal treespecs.
        - Preserves the key order for `unflatten(flatten(dict))`.
    - Cons:
        - Not intuitive if users only call `tree_flatten`, or combine `d.values()` with `tree_flatten(d)`. `tree_map` is not a concern because it calls `tree_unflatten` internally.
        - Has a non-zero overhead for sorting.
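
Two of the cons listed above can be reproduced in plain Python without calling any pytree library. This is only a sketch of the underlying behavior: the sorted traversal is emulated with `sorted(d)`, and how each library actually treats unsortable keys is not shown here.

```python
# Con: `d.values()` follows insertion order, so it no longer lines up
# with leaves that are produced in sorted key order.
d = {"b": 2, "a": 1}
insertion_order_leaves = list(d.values())
sorted_order_leaves = [d[k] for k in sorted(d)]
assert insertion_order_leaves == [2, 1]
assert sorted_order_leaves == [1, 2]

# Con: sorting requires mutually comparable keys. Mixing `int` and `str`
# keys makes a plain `sorted` call raise `TypeError` in Python 3.
mixed = {1: "x", "a": "y"}
try:
    sorted(mixed)
except TypeError:
    keys_sortable = False
else:
    keys_sortable = True
assert keys_sortable is False
```

Code that zips `d.values()` against flattened leaves is therefore the main place where a sorted traversal can surprise users.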

cc zou3519 avikchaudhuri gmagogsfm zhxchen17 tugsbayasgalan angelayi suo ydwu4

XuehaiPan added a commit that referenced this pull request Feb 14, 2024
ghstack-source-id: 99f54d2a525a318ab4a7ce00f35e590ddb5139f0
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Mar 10, 2024
ghstack-source-id: 6eb084ea9fbaceac644cdb2035a6ff311dbc5046
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Apr 21, 2024
ghstack-source-id: 0d68e27c2105767e8f28100058a8caf8f62e91cc
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request May 22, 2024
ghstack-source-id: 4abe7a89ac4c049431d81f0d7093e16dede3307f
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Jun 21, 2024
ghstack-source-id: fdffd2ca0048f02e6ea77c0091c0dd88458d720f
Pull Request resolved: #114947
XuehaiPan added a commit that referenced this pull request Jun 22, 2024
ghstack-source-id: 4f38eac3dcc4d1e6691a425f9b2b91708451cc6a
Pull Request resolved: #114947
Labels
ciflow/inductor, ciflow/trunk, keep-going, module: pytree, oncall: export, open source, release notes: fx, topic: bc breaking
Development

Successfully merging this pull request may close these issues.

[BUG][pytree] equal dicts do not imply equal leaves and equal treespecs
4 participants