Fix optimizer reloading from checkpoint (#1329)
* fix optimizer reloading from checkpoint

* comment
awni authored Aug 15, 2024
1 parent d0630ff commit ae5b5ca
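
For context, here is a minimal sketch of the checkpoint round trip this commit is meant to support, modeled on the new test_init_from_state test further down. The nn.Sequential model, the optimizer.safetensors filename, and the use of mx.save_safetensors/mx.load for serialization are illustrative assumptions, not part of the commit; the tree_flatten/tree_unflatten round trip follows the test.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
optimizer = opt.Adam(learning_rate=3e-4)
optimizer.init(model.trainable_parameters())

# Checkpoint: flatten the optimizer state into name -> array pairs and save it.
mx.save_safetensors("optimizer.safetensors", dict(tree_flatten(optimizer.state)))

# Later: build a fresh optimizer and assign the reloaded state. With this fix,
# setting .state marks the optimizer as uninitialized, so the next update()
# re-runs init() and fills in any entries the flattened checkpoint dropped.
optimizer = opt.Adam(learning_rate=3e-4)
optimizer.state = tree_unflatten(list(mx.load("optimizer.safetensors").items()))
optimizer.update(model, model.trainable_parameters())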
Showing 2 changed files with 58 additions and 12 deletions.
python/mlx/optimizers/optimizers.py (23 additions, 3 deletions)
@@ -48,8 +48,28 @@ def init(self, parameters: dict):
             >>> optimizer.state.keys()
             dict_keys(['step', 'learning_rate', 'weight', 'bias'])
         """
-        self._state.update(tree_map(lambda x: {}, parameters))
-        tree_map(self.init_single, parameters, self._state)
+
+        # Initialize the optimizer state to match the parameter state
+        def update_state(params, state):
+            if isinstance(params, (list, tuple)):
+                state = list(state)
+                for i in range(len(state)):
+                    state[i] = update_state(params[i], state[i])
+                if len(state) != len(params):
+                    state.extend(tree_map(lambda x: {}, params[len(state) :]))
+                return type(params)(state)
+            elif isinstance(params, dict):
+                for k, v in params.items():
+                    if k not in state:
+                        state[k] = tree_map(lambda x: {}, v)
+                    else:
+                        state[k] = update_state(v, state[k])
+                return state
+            else:
+                return state
+
+        update_state(parameters, self._state)
+        tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
         self._initialized = True

     def init_single(self, parameter: mx.array, state: dict):
@@ -104,7 +124,7 @@ def state(self):

     @state.setter
     def state(self, state: dict):
-        self._initialized = True
+        self._initialized = False
         self._state = state

     @property
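
The list and dict handling in the new update_state exists because tree_flatten drops empty sub-trees (Dropout, ReLU, and other parameter-free modules contribute nothing), so a state that has been round-tripped through tree_flatten/tree_unflatten can come back shorter than the parameter tree. A small illustration of the mismatch, using plain integers as stand-ins for mx.arrays:

from mlx.utils import tree_flatten, tree_unflatten

# A parameter tree with a module list where only the first entry has parameters,
# like the [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()] list in the new test below.
params = {"vals": [{"weight": 1}, {}, {}]}

flat = tree_flatten(params)
print(flat)                  # [('vals.0.weight', 1)] -- the empty entries are gone
print(tree_unflatten(flat))  # {'vals': [{'weight': 1}]} -- shorter than params["vals"]

# The new update_state() pads the reloaded list back to the length of the
# parameter list and adds missing dict keys, so init_single() can then
# initialize those blank entries lazily.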
python/tests/test_optimizers.py (35 additions, 9 deletions)
@@ -10,7 +10,7 @@
 import mlx.optimizers as opt
 import mlx.utils
 import mlx_tests
-from mlx.utils import tree_flatten, tree_map
+from mlx.utils import tree_flatten, tree_map, tree_unflatten


 def get_all_optimizers():
@@ -206,20 +206,22 @@ def test_lion(self):

     def test_adafactor(self):
         x = mx.zeros((5, 5))
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-        self.assertEqual(xp.dtype, x.dtype)
-        self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+        self.assertEqual(xp["x"].dtype, x.dtype)
+        self.assertEqual(xp["x"].shape, x.shape)

         x = mx.zeros((5, 5), mx.float16)
-        grad = mx.ones_like(x)
+        params = {"x": x}
+        grad = {"x": mx.ones_like(x)}
         optimizer = opt.Adafactor()
         for _ in range(2):
-            xp = optimizer.apply_gradients(grad, x)
-        self.assertEqual(xp.dtype, x.dtype)
-        self.assertEqual(xp.shape, x.shape)
+            xp = optimizer.apply_gradients(grad, params)
+        self.assertEqual(xp["x"].dtype, x.dtype)
+        self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)

     def test_compiled_optimizer(self):
@@ -420,6 +422,30 @@ def test_clip_grad_norm(self):
             "Gradients were not scaled correctly during clipping.",
         )

+    def test_init_from_state(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = nn.Linear(2, 2)
+                self.drop = nn.Dropout(p=0.5)
+                self.l2 = nn.Linear(2, 2)
+                self.vals = [nn.Linear(2, 2), nn.ReLU(), nn.ReLU()]
+
+        model = Model()
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.init(model.trainable_parameters())
+
+        # Flatten the state for serialization
+        state = tree_flatten(optimizer.state)
+
+        # Make a new optimizer and load the state
+        optimizer = opt.Adam(learning_rate=3e-4)
+        optimizer.state = tree_unflatten(state)
+
+        # This should work without any errors
+        grads = model.trainable_parameters()
+        optimizer.update(model, grads)
+

 if __name__ == "__main__":
     unittest.main()
