Patch summary
-------------
mup/test/__main__.py: add `get_mlp_infshapes{1,3,4,5}meta` helpers that build
some of the base/delta/target MLPs with `device='meta'`, and assert in
`test_set_base_shape` that they yield the same infshapes as the existing
non-meta helpers.

mup/test/models.py: add a `device='cpu'` keyword to `_generate_MLP` and
`_generate_CNN`, forward it to every parameterised layer (Linear, Conv2d,
BatchNorm, MuReadout), and build the base model with `device='meta'` in
`generate_MLP` / `generate_CNN`.

NOTE(review): leading indentation and blank context lines were reconstructed
from the hunk line counts; if a context line fails to match, apply with
`git apply --ignore-whitespace`.

diff --git a/mup/test/__main__.py b/mup/test/__main__.py
index fd5b5ea..5b448a8 100644
--- a/mup/test/__main__.py
+++ b/mup/test/__main__.py
@@ -27,6 +27,13 @@ def get_mlp_infshapes1(self):
         target_model = _generate_MLP(128, True, True, True)
         set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file)
         return get_infshapes(target_model)
+
+    def get_mlp_infshapes1meta(self):
+        base_model = _generate_MLP(64, True, True, True, device='meta')
+        delta_model = _generate_MLP(65, True, True, True, device='meta')
+        target_model = _generate_MLP(128, True, True, True)
+        set_base_shapes(target_model, base_model, delta=delta_model, savefile=self.mlp_base_shapes_file)
+        return get_infshapes(target_model)
 
     def get_mlp_infshapes2(self):
         target_model = _generate_MLP(128, True, True, True)
@@ -40,6 +47,14 @@ def get_mlp_infshapes3(self):
         target_model = _generate_MLP(128, True, True, True)
         set_base_shapes(target_model, base_infshapes)
         return get_infshapes(target_model)
+
+    def get_mlp_infshapes3meta(self):
+        base_model = _generate_MLP(64, True, True, True, device='meta')
+        delta_model = _generate_MLP(65, True, True, True, device='meta')
+        base_infshapes = make_base_shapes(base_model, delta_model)
+        target_model = _generate_MLP(128, True, True, True)
+        set_base_shapes(target_model, base_infshapes)
+        return get_infshapes(target_model)
 
     def get_mlp_infshapes4(self):
         base_model = _generate_MLP(64, True, True, True)
@@ -48,6 +63,13 @@ def get_mlp_infshapes4(self):
         set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model))
         return get_infshapes(target_model)
 
+    def get_mlp_infshapes4meta(self):
+        base_model = _generate_MLP(64, True, True, True)
+        delta_model = _generate_MLP(65, True, True, True, device='meta')
+        target_model = _generate_MLP(128, True, True, True, device='meta')
+        set_base_shapes(target_model, get_shapes(base_model), delta=get_shapes(delta_model))
+        return get_infshapes(target_model)
+
     def get_mlp_infshapes5(self):
         delta_model = _generate_MLP(65, True, True, True)
         target_model = _generate_MLP(128, True, True, True)
@@ -55,6 +77,13 @@ def get_mlp_infshapes5(self):
         set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model))
         return get_infshapes(target_model)
 
+    def get_mlp_infshapes5meta(self):
+        delta_model = _generate_MLP(65, True, True, True, device='meta')
+        target_model = _generate_MLP(128, True, True, True)
+        # `delta` here doesn't do anything because of base shape file
+        set_base_shapes(target_model, self.mlp_base_shapes_file, delta=get_shapes(delta_model))
+        return get_infshapes(target_model)
+
     def get_mlp_infshapes_bad(self):
         base_model = _generate_MLP(64, True, True, True)
         target_model = _generate_MLP(128, True, True, True)
@@ -62,10 +91,14 @@ def get_mlp_infshapes_bad(self):
         return get_infshapes(target_model)
 
     def test_set_base_shape(self):
+        self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes1meta())
         self.assertEqual(self.get_mlp_infshapes1(), self.get_mlp_infshapes2())
         self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes2())
         self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes4())
+        self.assertEqual(self.get_mlp_infshapes3(), self.get_mlp_infshapes3meta())
+        self.assertEqual(self.get_mlp_infshapes4(), self.get_mlp_infshapes4meta())
         self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes4())
+        self.assertEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes5meta())
         self.assertNotEqual(self.get_mlp_infshapes5(), self.get_mlp_infshapes_bad())
 
 
diff --git a/mup/test/models.py b/mup/test/models.py
index 49e91ab..d931fbf 100644
--- a/mup/test/models.py
+++ b/mup/test/models.py
@@ -35,19 +35,19 @@ def init_model(model, sampler):
     k: partial(init_model, sampler=s) for k, s in samplers.items()
 }
 
-def _generate_MLP(width, bias=True, mup=True, batchnorm=False):
-    mods = [Linear(3072, width, bias=bias),
+def _generate_MLP(width, bias=True, mup=True, batchnorm=False, device='cpu'):
+    mods = [Linear(3072, width, bias=bias, device=device),
             nn.ReLU(),
-            Linear(width, width, bias=bias),
+            Linear(width, width, bias=bias, device=device),
             nn.ReLU()
     ]
     if mup:
-        mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False))
+        mods.append(MuReadout(width, 10, bias=bias, readout_zero_init=False, device=device))
     else:
-        mods.append(Linear(width, 10, bias=bias))
+        mods.append(Linear(width, 10, bias=bias, device=device))
     if batchnorm:
-        mods.insert(1, nn.BatchNorm1d(width))
-        mods.insert(4, nn.BatchNorm1d(width))
+        mods.insert(1, nn.BatchNorm1d(width, device=device))
+        mods.insert(4, nn.BatchNorm1d(width, device=device))
     model = nn.Sequential(*mods)
     return model
 
@@ -58,7 +58,7 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
         return set_base_shapes(model, None)
     # it's important we make `model` first, because of random seed
     model = _generate_MLP(width, bias, mup, batchnorm)
-    base_model = _generate_MLP(base_width, bias, mup, batchnorm)
+    base_model = _generate_MLP(base_width, bias, mup, batchnorm, device='meta')
     set_base_shapes(model, base_model)
     init_methods[init](model)
     if readout_zero_init:
@@ -73,29 +73,29 @@ def generate_MLP(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
     return model
 
 
-def _generate_CNN(width, bias=True, mup=True, batchnorm=False):
+def _generate_CNN(width, bias=True, mup=True, batchnorm=False, device='cpu'):
     mods = [
-        nn.Conv2d(3, width, kernel_size=5, bias=bias),
+        nn.Conv2d(3, width, kernel_size=5, bias=bias, device=device),
         nn.ReLU(inplace=True),
         nn.MaxPool2d(kernel_size=2, stride=2),
-        nn.Conv2d(width, 2*width, kernel_size=5, bias=bias),
+        nn.Conv2d(width, 2*width, kernel_size=5, bias=bias, device=device),
         nn.ReLU(inplace=True),
         nn.MaxPool2d(kernel_size=2, stride=2),
         nn.Flatten(),
-        nn.Linear(2*width*25, width*16, bias=bias),
+        nn.Linear(2*width*25, width*16, bias=bias, device=device),
         nn.ReLU(inplace=True),
-        nn.Linear(width*16, width*10, bias=bias),
+        nn.Linear(width*16, width*10, bias=bias, device=device),
         nn.ReLU(inplace=True),
     ]
     if mup:
-        mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False))
+        mods.append(MuReadout(width*10, 10, bias=bias, readout_zero_init=False, device=device))
     else:
-        mods.append(nn.Linear(width*10, 10, bias=bias))
+        mods.append(nn.Linear(width*10, 10, bias=bias, device=device))
     if batchnorm:
-        mods.insert(1, nn.BatchNorm2d(width))
-        mods.insert(5, nn.BatchNorm2d(2*width))
-        mods.insert(10, nn.BatchNorm1d(16*width))
-        mods.insert(13, nn.BatchNorm1d(10*width))
+        mods.insert(1, nn.BatchNorm2d(width, device=device))
+        mods.insert(5, nn.BatchNorm2d(2*width, device=device))
+        mods.insert(10, nn.BatchNorm1d(16*width, device=device))
+        mods.insert(13, nn.BatchNorm1d(10*width, device=device))
     return nn.Sequential(*mods)
 
 def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=False, init='default', bias_zero_init=False, base_width=8):
@@ -105,7 +105,7 @@ def generate_CNN(width, bias=True, mup=True, readout_zero_init=True, batchnorm=F
         return set_base_shapes(model, None)
     # it's important we make `model` first, because of random seed
     model = _generate_CNN(width, bias, mup, batchnorm)
-    base_model = _generate_CNN(base_width, bias, mup, batchnorm)
+    base_model = _generate_CNN(base_width, bias, mup, batchnorm, device='meta')
     set_base_shapes(model, base_model)
     init_methods[init](model)
     if readout_zero_init: