Skip to content

Commit

Permalink
Merge pull request #22 from fffffarmer/hotfix/invalid_dtype_kwargs
Browse files Browse the repository at this point in the history
fix unexpected kwarg 'bf16' and 'fp32' when initializing
  • Loading branch information
WuTheFWasThat committed Jan 23, 2024
2 parents 7055122 + 6c908db commit 6b450f2
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions train_weak_to_strong.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,35 @@
name="gpt2",
default_lr=5e-5,
eval_batch_size=32,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="gpt2-medium",
default_lr=5e-5,
eval_batch_size=32,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="gpt2-large",
default_lr=1e-5,
eval_batch_size=32,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="gpt2-xl",
default_lr=1e-5,
eval_batch_size=2,
gradient_checkpointing=True,
model_parallel=True,
custom_kwargs={
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
name="Qwen/Qwen-1_8B",
Expand All @@ -45,8 +57,7 @@
model_parallel=True,
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
Expand All @@ -58,8 +69,7 @@
# note: you will probably not be able to run this without many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
Expand All @@ -71,8 +81,7 @@
# note: you will probably not be able to run this without bf16 support and many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
),
ModelConfig(
Expand All @@ -84,8 +93,7 @@
# note: you will probably not be able to run this without bf16 support and many gpus
custom_kwargs={
"trust_remote_code": True,
"bf16": torch.cuda.is_bf16_supported(),
"fp32": not torch.cuda.is_bf16_supported(),
"torch_dtype": torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32,
},
# This model is really big, save space by using adafactor.
# Note that even then it will take up ~60GB per GPU on an 8-GPU machine.
Expand Down

0 comments on commit 6b450f2

Please sign in to comment.