fix unexpected kwargs 'bf16' and 'fp32' when initializing
zhxieml committed Dec 22, 2023
1 parent 0a6d4ab commit c01b150
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions train_weak_to_strong.py
@@ -20,26 +20,23 @@
         default_lr=5e-5,
         eval_batch_size=32,
         custom_kwargs={
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
         name="gpt2-medium",
         default_lr=5e-5,
         eval_batch_size=32,
         custom_kwargs={
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
         name="gpt2-large",
         default_lr=1e-5,
         eval_batch_size=32,
         custom_kwargs={
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
@@ -49,8 +46,7 @@
         gradient_checkpointing=True,
         model_parallel=True,
         custom_kwargs={
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
@@ -61,8 +57,7 @@
         model_parallel=True,
         custom_kwargs={
             "trust_remote_code": True,
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
@@ -74,8 +69,7 @@
         # note: you will probably not be able to run this without many gpus
         custom_kwargs={
             "trust_remote_code": True,
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
@@ -87,8 +81,7 @@
         # note: you will probably not be able to run this without bf16 support and many gpus
         custom_kwargs={
             "trust_remote_code": True,
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
     ),
     ModelConfig(
@@ -100,8 +93,7 @@
         # note: you will probably not be able to run this without bf16 support and many gpus
         custom_kwargs={
             "trust_remote_code": True,
-            "bf16": torch.cuda.is_bf16_supported(),
-            "fp32": not torch.cuda.is_bf16_supported(),
+            "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
         },
         # This model is really big, save space by using adafactor.
         # Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
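For context: each ModelConfig's custom_kwargs are presumably forwarded to Hugging Face transformers' from_pretrained when the model is loaded. from_pretrained accepts torch_dtype, but it has no bf16 or fp32 parameters, so the old keys fell through to the model's __init__ and triggered the "unexpected keyword argument" error in the commit title. A minimal sketch of that loading path, assuming the standard AutoModelForCausalLM API (the "gpt2" checkpoint is just an illustration):

import torch
from transformers import AutoModelForCausalLM

# Same dtype selection as the configs above: prefer bfloat16 where the GPU
# supports it, otherwise fall back to float32. (Assumes a CUDA build of
# PyTorch, as the training script does.)
custom_kwargs = {
    "torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
}

# torch_dtype is a documented from_pretrained parameter, so the weights are
# loaded and cast accordingly.
model = AutoModelForCausalLM.from_pretrained("gpt2", **custom_kwargs)
print(model.dtype)  # torch.bfloat16 on bf16-capable GPUs, else torch.float32

# By contrast, "bf16"/"fp32" are not from_pretrained parameters; kwargs it
# does not consume are passed on to the model's __init__, which rejects them:
#   AutoModelForCausalLM.from_pretrained("gpt2", bf16=True)
#   -> TypeError: __init__() got an unexpected keyword argument 'bf16'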
