Skip to content

Commit

Permalink
[fix] allow saving python attr on Tensor and Parameter via torch.save (
Browse files Browse the repository at this point in the history
…pytorch#81616)

Fixes: pytorch#72129

TODO:
* [x] Fix for Parameter

Benchmark
(Measurable diff for small tensors)
```
[-------------- Save and Load --------------]
                    |  After PR  |  Before PR
1 threads: ----------------------------------
      ()            |    111.7   |     106.9
      (4, 4)        |    114.4   |     109.2
      (128, 128)    |    135.2   |     128.3
      (1024, 1024)  |   1431.9   |    1431.3

Times are in microseconds (us).
```

<details>

<summary> Benchmark Script </summary>

```python
import torch
from torch.testing._internal.common_utils import BytesIOContext
from torch.utils import benchmark
import pickle

# Tensor shapes to benchmark: from a 0-dim scalar up to a 1024x1024 matrix.
shapes = ((), (4, 4), (128, 128), (1024, 1024))

results = []

def save_load_fn(t):
    """Save ``t`` with ``torch.save`` into a ``BytesIOContext`` and load it back."""
    with BytesIOContext() as f:
        torch.save(t, f)
        f.seek(0)  # rewind so torch.load reads from the start of the buffer
        torch.load(f)

# One timer per shape; every measurement runs for at least 2 seconds.
for shape in shapes:
    t = torch.randn(shape)
    results.append(benchmark.Timer(
        stmt='save_load_fn(t)',
        globals={'t': t, 'save_load_fn': save_load_fn},
        label='Save and Load',
        sub_label=f'{shape}',
        description='Before PR',
    ).blocked_autorange(min_run_time=2))

compare = benchmark.Compare(results)
compare.print()

# Persist the measurements so they can later be compared against the other
# run (see the commented-out snippet below).
with open('before_pr.pkl', 'wb') as f:
    pickle.dump(results, f)

# with open('after_pr.pkl', 'rb') as f:
#     after_pr = pickle.load(f)

# with open('before_pr.pkl', 'rb') as f:
#     before_pr = pickle.load(f)

# compare = benchmark.Compare(after_pr + before_pr)
# compare.print()
```

</details>

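NOTE: **BC-Breaking**: After this PR, regular tensors that carry Python attributes (not only Tensor subclasses) are serialised using `_rebuild_from_type_v2`; plain tensors without any Python state still take the existing fast path.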

Pull Request resolved: pytorch#81616
Approved by: https://github.com/albanD, https://github.com/kurtamohler
  • Loading branch information
kshitij12345 authored and kulinseth committed Dec 9, 2022
1 parent 46fa6d8 commit dc5e3c7
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 39 deletions.
22 changes: 22 additions & 0 deletions test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,28 @@ def test_meta_serialization(self, weights_only):

self.assertEqual(state['weight'].size(), big_model.weight.size())

def test_serialization_python_attr(self):
    """Python attributes set on a Tensor must survive a save/load round trip."""

    def _roundtrip_and_compare(tensor):
        # Attach two plain Python attributes before serializing.
        tensor.foo = 'foo'
        tensor.pi = 3.14

        with BytesIOContext() as buffer:
            torch.save(tensor, buffer)
            buffer.seek(0)
            restored = torch.load(buffer)

        self.assertEqual(tensor, restored)
        self.assertEqual(tensor.foo, restored.foo)
        self.assertEqual(tensor.pi, restored.pi)

    base = torch.zeros(3, 3)
    _roundtrip_and_compare(base)

    # Parameter does not serialize Python attributes yet, so reading them
    # back from the loaded object raises. This should start failing once
    # Parameter supports saving Python attributes.
    with self.assertRaisesRegex(AttributeError, "'Parameter' object has no attribute"):
        _roundtrip_and_compare(torch.nn.Parameter(base))

def test_weights_only_assert(self):
class HelloWorld:
def __reduce__(self):
Expand Down
43 changes: 4 additions & 39 deletions torch/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def _rebuild_from_type(func, type, args, dict):


def _rebuild_from_type_v2(func, new_type, args, state):
if new_type is Tensor:
return func(*args)

ret = func(*args)
if type(ret) is not new_type:
ret = ret.as_subclass(new_type)
Expand All @@ -70,21 +67,7 @@ def _rebuild_from_type_v2(func, new_type, args, state):
):
ret.__setstate__(state)
else:
if isinstance(state, tuple):
if not len(state) == 2:
raise RuntimeError(f"Invalid serialized state: {state}")
dict_state = state[0]
slots_state = state[1]
else:
dict_state = state
slots_state = None

for k, v in dict_state.items():
setattr(ret, k, v)

if slots_state:
for k, v in slots_state.items():
setattr(ret, k, v)
ret = torch._utils._set_obj_state(ret, state)
return ret


Expand Down Expand Up @@ -223,31 +206,13 @@ def __deepcopy__(self, memo):
return new_tensor

def __reduce_ex__(self, proto):
if type(self) is Tensor:
state = torch._utils._get_obj_state(self)
if type(self) is Tensor and not state:
# Fast path for regular tensor without Python state.
return self._reduce_ex_internal(proto)
if has_torch_function_unary(self):
return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
func, args = self._reduce_ex_internal(proto)
# Get the state of the python subclass
# This loosely mimicks the function on the object class but since Tensor do not inherit
# from it, we cannot call that function directly
# https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
getstate_fn = getattr(self, "__getstate__", None)
if getstate_fn:
state = getstate_fn()
else:
slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined]
if slots_to_save:
state = (
self.__dict__,
{
name: getattr(self, name)
for name in slots_to_save
if hasattr(self, name)
},
)
else:
state = self.__dict__
return (_rebuild_from_type_v2, (func, type(self), args, state))

def storage(self):
Expand Down
59 changes: 59 additions & 0 deletions torch/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copyreg
import sys
import traceback
import warnings
Expand Down Expand Up @@ -335,6 +336,64 @@ def _rebuild_parameter(data, requires_grad, backward_hooks):
return param


# TODO(kshitij12345): Support serializing nn.Parameter with Python Attributes.
# NOTE: We are just defining it here now for future use.
def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    """Rebuild an ``nn.Parameter`` and re-apply its saved Python state.

    Args:
        data: tensor wrapped by the new Parameter.
        requires_grad: forwarded to the ``torch.nn.Parameter`` constructor.
        backward_hooks: assigned to ``_backward_hooks`` on the new Parameter.
        state: saved Python state, applied through ``_set_obj_state``.

    Returns:
        The rebuilt Parameter with ``state`` applied.
    """
    rebuilt = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    rebuilt._backward_hooks = backward_hooks

    # Restore state on the Parameter, e.g. Python attributes.
    return _set_obj_state(rebuilt, state)


def _get_obj_state(obj):
# Get the state of the python subclass
# This loosely mimicks the function on the object class but since Tensor do not inherit
# from it, we cannot call that function directly
# https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
getstate_fn = getattr(obj, "__getstate__", None)
if getstate_fn:
state = getstate_fn()
else:
slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined]
if slots_to_save:
state = (
obj.__dict__,
{
name: getattr(obj, name)
for name in slots_to_save
if hasattr(obj, name)
},
)
else:
state = obj.__dict__

return state


def _set_obj_state(obj, state):
if isinstance(state, tuple):
if not len(state) == 2:
raise RuntimeError(f"Invalid serialized state: {state}")
dict_state = state[0]
slots_state = state[1]
else:
dict_state = state
slots_state = None

for k, v in dict_state.items():
setattr(obj, k, v)

if slots_state:
for k, v in slots_state.items():
setattr(obj, k, v)
return obj


def _import_dotted_name(name):
components = name.split(".")
obj = __import__(components[0])
Expand Down
4 changes: 4 additions & 0 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def _get_allowed_globals():
torch._utils._rebuild_sparse_csr_tensor,
]:
rc[f"torch._utils.{f.__name__}"] = f

# Handles Tensor subclasses and Tensors with Python attributes.
# NOTE: It calls into above rebuild functions for regular Tensor types.
rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2
return rc


Expand Down
71 changes: 71 additions & 0 deletions torch/csrc/jit/serialization/unpickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,21 @@ PickleOpCode Unpickler::readInstruction() {
}
stack_.emplace_back(std::move(tensor));
} break;
case PickleOpCode::SETITEM: {
// At this OpCode, stack looks like
// | Stack Bottom |
// | ...... |
// | Dict | -> (stack_size - 3)
// | Key | -> (stack_size - 2)
// | Value | -> (stack_size - 1)
auto stack_size = stack_.size();
auto dict_pos = stack_size - 3;
auto key_pos = stack_size - 2;
auto val_pos = stack_size - 1;
auto dict = stack_.at(dict_pos).toGenericDict();
dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
stack_.erase(stack_.begin() + (key_pos), stack_.end());
} break;
default: {
AT_ERROR(
"Unknown opcode for unpickling at ",
Expand All @@ -546,6 +561,23 @@ PickleOpCode Unpickler::readInstruction() {
void Unpickler::readGlobal(
const std::string& module_name,
const std::string& class_name) {
if (this->skip_next_read_global) {
// See [NOTE] skip_next_read_global
this->skip_next_read_global--;
if (this->skip_next_read_global == 1) {
// Pass through to the correct handler
} else if (this->skip_next_read_global == 0) {
// Corresponds to the type of `Tensor` being unpickled
if (module_name != "torch" || class_name != "Tensor") {
TORCH_WARN(
"Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
}
stack_.emplace_back(int64_t(globals_.size() - 1));
return;
} else {
TORCH_CHECK(false, "INVALID VALUES")
}
}
// TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
// is only here for bc-compatibility reasons
if (module_name == "__main__") {
Expand Down Expand Up @@ -631,6 +663,12 @@ void Unpickler::readGlobal(
// Unpickle a tensor
bool quantized = class_name == "_rebuild_qtensor";
rebuildTensor(quantized);
} else if (
module_name == "torch._tensor" &&
(class_name == "_rebuild_from_type_v2")) {
// Unpickle a Tensor with Python attributes or
// a Subclassed Tensor.
rebuildTensorFromTypeV2();
} else if (
module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
rebuildSparseTensor();
Expand Down Expand Up @@ -849,6 +887,39 @@ void Unpickler::rebuildTensor(bool quantized) {
});
}

// Handler for the `torch._tensor._rebuild_from_type_v2` global. Registers a
// callback in `globals_` that rebuilds only the base tensor from the pickled
// `(func, type, args, state)` tuple; the Python state is not restored.
void Unpickler::rebuildTensorFromTypeV2() {
  // [NOTE] skip_next_read_global
  // When rebuilding Tensor with Python Attr or Subclassed Tensor,
  // we receive `(func, type(self), args, state)` on stack for
  // `rebuildTensorFromTypeV2`.
  // Thus next call to readGlobal corresponds to `func` which is
  // the function to rebuild the base tensor.
  // The call after `func` to readGlobal corresponds to `type` of the
  // Tensor where we raise warning if the type is not `torch.Tensor`.
  this->skip_next_read_global = 2;
  // Index that the callback registered below will occupy in `globals_`.
  auto curr_globals_idx = globals_.size();
  globals_.emplace_back([this, curr_globals_idx] {
    // args is a tuple with following data
    // (function to rebuild base tensor, type of tensor,
    // arguments to construct base tensor, Python State (as dict))
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
    // NOTE(review): this assumes the state is a dict. The Python side can
    // also produce a `(dict, slots)` tuple for classes with slots, which
    // looks unhandled here -- TODO confirm.
    auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
    if (py_state.size() > 0) {
      TORCH_WARN(
          "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
    }
    // This calls the function to rebuild the
    // base tensor.
    // Eg. `rebuildTensor`, `rebuildSparseTensor`.
    stack_.emplace_back(base_tensor_args);
    // `curr_globals_idx + 1` is the entry registered right after this one,
    // presumably by the readGlobal call for `func` (see the note above).
    globals_[curr_globals_idx + 1]();
    // Pops the value left on the stack by the base rebuild function and
    // pushes it back as the result of this callback.
    stack_.emplace_back(pop(stack_));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
globals_.emplace_back([this] {
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/serialization/unpickler.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class TORCH_API Unpickler {
const std::string& module_name,
const std::string& class_name);
void rebuildTensor(bool quantized);
void rebuildTensorFromTypeV2();
void rebuildSparseTensor();
#ifdef USE_DISTRIBUTED
void rebuildRRef();
Expand Down Expand Up @@ -176,6 +177,9 @@ class TORCH_API Unpickler {

// See [type tag serialization]
uint64_t version_;

// See [NOTE] skip_next_read_global
uint8_t skip_next_read_global = 0;
};

void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);
Expand Down
1 change: 1 addition & 0 deletions torch/nn/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __repr__(self):
return 'Parameter containing:\n' + super(Parameter, self).__repr__()

def __reduce_ex__(self, proto):
# TODO(kshitij12345): Support saving Python Attribute
# See Note [Don't serialize hooks]
return (
torch._utils._rebuild_parameter,
Expand Down

0 comments on commit dc5e3c7

Please sign in to comment.