Skip to content

Commit

Permalink
add tests for meta tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardjhu committed Mar 19, 2022
1 parent a2fec5f commit 7758dae
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
33 changes: 33 additions & 0 deletions mup/test/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def get_mlp_infshapes1(self):
target_model = _generate_MLP(128, True, True, True)
set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file)
return get_infshapes(target_model)

def get_mlp_infshapes1meta(self):
    """Meta-device variant of `get_mlp_infshapes1`: base and delta models
    are built with device='meta', the target on the default device.

    Also writes the base shapes to the shared savefile, like infshapes1.
    """
    base = _generate_MLP(64, bias=True, mup=True, batchnorm=True, device='meta')
    delta = _generate_MLP(65, bias=True, mup=True, batchnorm=True, device='meta')
    target = _generate_MLP(128, bias=True, mup=True, batchnorm=True)
    set_base_shapes(target, base, delta=delta, savefile=self.mlp_base_shapes_file)
    return get_infshapes(target)

def get_mlp_infshapes2(self):
target_model = _generate_MLP(128, True, True, True)
Expand All @@ -40,6 +47,14 @@ def get_mlp_infshapes3(self):
target_model = _generate_MLP(128, True, True, True)
set_base_shapes(target_model, base_infshapes)
return get_infshapes(target_model)

def get_mlp_infshapes3meta(self):
    """Meta-device variant of `get_mlp_infshapes3`: derive base shapes via
    `make_base_shapes` from models built with device='meta'."""
    base = _generate_MLP(64, bias=True, mup=True, batchnorm=True, device='meta')
    delta = _generate_MLP(65, bias=True, mup=True, batchnorm=True, device='meta')
    shapes = make_base_shapes(base, delta)
    target = _generate_MLP(128, bias=True, mup=True, batchnorm=True)
    set_base_shapes(target, shapes)
    return get_infshapes(target)

def get_mlp_infshapes4(self):
base_model = _generate_MLP(64, True, True, True)
Expand All @@ -48,24 +63,42 @@ def get_mlp_infshapes4(self):
set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model))
return get_infshapes(target_model)

def get_mlp_infshapes4meta(self):
    """Variant of `get_mlp_infshapes4` with delta and target on the meta
    device, passing shapes via `get_shapes`."""
    # NOTE(review): unlike the other *meta variants, base_model is NOT on
    # 'meta' here — presumably deliberate, to exercise mixed devices; confirm.
    base = _generate_MLP(64, bias=True, mup=True, batchnorm=True)
    delta = _generate_MLP(65, bias=True, mup=True, batchnorm=True, device='meta')
    target = _generate_MLP(128, bias=True, mup=True, batchnorm=True, device='meta')
    set_base_shapes(target, get_shapes(base), delta=get_shapes(delta))
    return get_infshapes(target)

def get_mlp_infshapes5(self):
    """Load base shapes from the savefile (written by `get_mlp_infshapes1`)."""
    delta = _generate_MLP(65, bias=True, mup=True, batchnorm=True)
    target = _generate_MLP(128, bias=True, mup=True, batchnorm=True)
    # The base-shape file takes precedence, so `delta` is effectively ignored.
    set_base_shapes(target, self.mlp_base_shapes_file, delta=get_shapes(delta))
    return get_infshapes(target)

def get_mlp_infshapes5meta(self):
    """Same as `get_mlp_infshapes5`, but the (unused) delta model is built
    on the meta device."""
    delta = _generate_MLP(65, bias=True, mup=True, batchnorm=True, device='meta')
    target = _generate_MLP(128, bias=True, mup=True, batchnorm=True)
    # The base-shape file takes precedence, so `delta` is effectively ignored.
    set_base_shapes(target, self.mlp_base_shapes_file, delta=get_shapes(delta))
    return get_infshapes(target)

def get_mlp_infshapes_bad(self):
    """Deliberately wrong setup: the same model is passed as both base and
    delta, so the resulting infshapes should NOT match the correct ones."""
    model64 = _generate_MLP(64, bias=True, mup=True, batchnorm=True)
    target = _generate_MLP(128, bias=True, mup=True, batchnorm=True)
    set_base_shapes(target, model64, delta=model64)
    return get_infshapes(target)

def test_set_base_shape(self):
    """All correct ways of supplying base shapes (models, shape dicts,
    `make_base_shapes` output, savefile) must yield the same infshapes,
    and meta-device models must match their concrete counterparts.

    Call order matters: `get_mlp_infshapes1`/`1meta` write the base-shape
    savefile that `get_mlp_infshapes5`/`5meta` later read.
    """
    # meta-device variants agree with their concrete counterparts
    self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes1meta())
    # every correct construction route agrees with the others
    self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes2())
    self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes2())
    self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes4())
    self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes3meta())
    self.assertEqual(self.get_mlp_infshapes4(), self.get_mlp_infshapes4meta())
    self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes4())
    self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes5meta())
    # base == delta is an invalid setup and must give different infshapes
    self.assertNotEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes_bad())


Expand Down
40 changes: 20 additions & 20 deletions mup/test/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ def init_model(model, sampler):
k: partial(init_model, sampler=s) for k, s in samplers.items()
}

def _generate_MLP(width, bias=True, mup=True, batchnorm=False, device='cpu'):
    """Build a 3-layer MLP over flattened 3072-dim inputs.

    Args:
        width: hidden width of the two internal Linear layers.
        bias: whether Linear/readout layers carry a bias term.
        mup: if True, the output layer is a `MuReadout`; otherwise a plain
            `Linear`.
        batchnorm: if True, insert BatchNorm1d after each hidden Linear
            (before its ReLU).
        device: device passed to every parameterized module; 'meta' builds
            a shape-only model without allocating real parameter storage.

    Returns:
        An `nn.Sequential` model.
    """
    # NOTE: this span in the diff view interleaved the pre-commit lines
    # (without `device`) with the post-commit ones; this is the clean
    # post-commit definition.
    mods = [Linear(3072, width, bias=bias, device=device),
            nn.ReLU(),
            Linear(width, width, bias=bias, device=device),
            nn.ReLU()
            ]
    if mup:
        mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False, device=device))
    else:
        mods.append(Linear(width, 10, bias=bias, device=device))
    if batchnorm:
        # positions 1 and 4 place each BatchNorm1d right after its Linear
        mods.insert(1, nn.BatchNorm1d(width, device=device))
        mods.insert(4, nn.BatchNorm1d(width, device=device))
    model = nn.Sequential(*mods)
    return model

Expand All @@ -58,7 +58,7 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
return set_base_shapes(model, None)
# it's important we make `model` first, because of random seed
model = _generate_MLP(width, bias, mup, batchnorm)
base_model = _generate_MLP(base_width, bias, mup, batchnorm)
base_model = _generate_MLP(base_width, bias, mup, batchnorm, device='meta')
set_base_shapes(model, base_model)
init_methods[init](model)
if readout_zero_init:
Expand All @@ -73,29 +73,29 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
return model


def _generate_CNN(width, bias=True, mup=True, batchnorm=False, device='cpu'):
    """Build a small LeNet-style CNN for 3-channel image inputs.

    Args:
        width: base channel width; later layers scale it (2x, 16x, 10x).
        bias: whether conv/linear/readout layers carry a bias term.
        mup: if True, the output layer is a `MuReadout`; otherwise a plain
            `nn.Linear`.
        batchnorm: if True, insert a BatchNorm right after each conv and
            each hidden Linear.
        device: device passed to every parameterized module; 'meta' builds
            a shape-only model without allocating real parameter storage.

    Returns:
        An `nn.Sequential` model.
    """
    # NOTE: this span in the diff view interleaved the pre-commit lines
    # (without `device`) with the post-commit ones; this is the clean
    # post-commit definition.
    mods = [
        nn.Conv2d(3, width, kernel_size=5, bias=bias, device=device),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(width, 2*width, kernel_size=5, bias=bias, device=device),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(2*width*25, width*16, bias=bias, device=device),
        nn.ReLU(inplace=True),
        nn.Linear(width*16, width*10, bias=bias, device=device),
        nn.ReLU(inplace=True),
    ]
    if mup:
        mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False, device=device))
    else:
        mods.append(nn.Linear(width*10, 10, bias=bias, device=device))
    if batchnorm:
        # insertion indices place each norm directly after its conv/linear
        mods.insert(1, nn.BatchNorm2d(width, device=device))
        mods.insert(5, nn.BatchNorm2d(2*width, device=device))
        mods.insert(10, nn.BatchNorm1d(16*width, device=device))
        mods.insert(13, nn.BatchNorm1d(10*width, device=device))
    return nn.Sequential(*mods)

def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=8):
Expand All @@ -105,7 +105,7 @@ def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
return set_base_shapes(model, None)
# it's important we make `model` first, because of random seed
model = _generate_CNN(width, bias, mup, batchnorm)
base_model = _generate_CNN(base_width, bias, mup, batchnorm)
base_model = _generate_CNN(base_width, bias, mup, batchnorm, device='meta')
set_base_shapes(model, base_model)
init_methods[init](model)
if readout_zero_init:
Expand Down

0 comments on commit 7758dae

Please sign in to comment.