From 89ed7636beed09aaadc7a8391c4a05f1f6fc5bf8 Mon Sep 17 00:00:00 2001
From: Edward Hu
Date: Sat, 19 Mar 2022 16:21:58 -0400
Subject: [PATCH] add usage of the meta flag to README

---
 README.md | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index f209b41..4d60d58 100644
--- a/README.md
+++ b/README.md
@@ -76,11 +76,14 @@ class MyModel(nn.Module):
 
 ### Instantiate a base model
 base_model = MyModel(width=1)
+### Optionally, use `device='meta'` to avoid instantiating the model
+### This requires you to pass the `device` flag down to all submodules
+# base_model = MyModel(width=1, device='meta')
 ### Instantiate a "delta" model that differs from the base model
 ### in all dimensions ("widths") that one wishes to scale.
 ### Here it's simple, but e.g., in a Transformer, you may want to scale
 ### both nhead and dhead, so the delta model should differ in both.
-delta_model = MyModel(width=2)
+delta_model = MyModel(width=2)  # Optionally pass `device='meta'` here as well to avoid instantiating
 
 ### Instantiate the target model (the model you actually want to train).
 ### This should be the same as the base model except
@@ -123,7 +126,8 @@ optimizer = MuSGD(model.parameters(), lr=0.1)
 ```
 Note the base and delta models *do not need to be trained* ---
 we are only extracting parameter shape information from them.
-Ideally, we can do so without instantiating the model parameters at all, like in [JAX](https://github.com/google/jax), but unfortunately we currently can't do that in PyTorch --- but upvote [this PyTorch issue](https://github.com/pytorch/pytorch/issues/74143) if you want to see this feature happening!
+We can therefore optionally avoid instantiating these potentially large models at all by passing `device='meta'` to their constructors.
+However, you need to make sure the `device` flag is passed down to the constructors of all submodules.
 
 ## How `mup` Works Under the Hood
 
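For context beyond the patch itself, the sketch below illustrates what "passing the `device` flag down to all submodules" means in practice. It is a minimal, hypothetical example, not part of this patch: the layer name `hidden` and the dimensions `d_in`/`d_out` are made up for illustration, and it assumes `MuReadout` forwards its keyword arguments (including `device`) to the `nn.Linear` it subclasses.

```python
import torch
import torch.nn as nn
from mup import MuReadout, set_base_shapes, MuSGD

class MyModel(nn.Module):
    """Illustrative model that forwards `device` to every submodule."""
    def __init__(self, width, d_in=3, d_out=10, device=None):
        super().__init__()
        # Every submodule constructor receives the same `device` argument,
        # so `device='meta'` builds shape-only parameters with no storage.
        self.hidden = nn.Linear(d_in, width, device=device)
        self.readout = MuReadout(width, d_out, device=device)

    def forward(self, x):
        return self.readout(torch.relu(self.hidden(x)))

# Base and delta models on the meta device: their parameter shapes are
# recorded, but no parameter memory is allocated.
base_model = MyModel(width=1, device='meta')
delta_model = MyModel(width=2, device='meta')

# The target model is instantiated normally, then given its base shapes.
model = MyModel(width=100)
set_base_shapes(model, base_model, delta=delta_model)
optimizer = MuSGD(model.parameters(), lr=0.1)
```

With the flag threaded through like this, only `model` ever allocates parameters; `base_model` and `delta_model` exist purely to supply shape information to `set_base_shapes`.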