
Commit

Add dataset_samples arg for alpaca_sample and dummy datasets
warner-benjamin committed Mar 19, 2024
1 parent d7818ec commit 049515f
Showing 1 changed file with 5 additions and 4 deletions.
train.py: 9 changes (5 additions, 4 deletions)
@@ -307,12 +307,12 @@ def get_dataloader(tokenizer:PreTrainedTokenizerFast, args:Dict):
     if args["dataset"] == "alpaca":
         dataset = load_dataset("yahma/alpaca-cleaned")['train']
     elif args["dataset"] == "alpaca_sample":
-        dataset = load_dataset("yahma/alpaca-cleaned", split="train[:512]")
+        dataset = load_dataset("yahma/alpaca-cleaned", split=f"train[:{args['dataset_samples']}]")
     elif args["dataset"] == "dummy":
         dataset = Dataset.from_dict({
-            'instruction': ["instruction"]*512,
-            'input': ["input"]*512,
-            'output': ["output"*10000]*512} # A long output to test memory usage (gets truncated)
+            'instruction': ["instruction"]*args["dataset_samples"],
+            'input': ["input"]*args["dataset_samples"],
+            'output': ["output"*10000]*args["dataset_samples"]} # A long output to test memory usage (gets truncated)
         )
     elif args["dataset"] == "guanaco":
         dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
@@ -951,6 +951,7 @@ def main(
     gradient_accumulation_steps: int = 1, # How many steps to accumulate gradients over (increases effective batch size)
     num_epochs: int = 1, # How many epochs of training to do
     dataset: Param("", choices=["alpaca", "alpaca_sample", "dummy", "guanaco", "sql"]) = "alpaca_sample", # alpaca, alpaca_sample (for a 128-sample test) or "dummy" for 16 long dummy samples
+    dataset_samples: int = 512, # Number of samples in an epoch if using "alpaca_sample" or "dummy" dataset
     sharding_strategy: Param("", choices=["full_shard", "shard_grad_op", "ddp", "hybrid_full_shard", "hybrid_shard_grad_op"]) = "full_shard", # Sharding strategy for FSDP
     use_gradient_checkpointing: bool_arg = True, # Use FSDP's activation checkpointing
     reentrant_checkpointing: bool_arg = False, # Use re-entrant autograd activation checkpointing. Setting to True can use less GPU memory with BNB QLoRA
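
Since dataset_samples now sets the number of examples per epoch for both test datasets, it also determines how many optimizer steps an epoch takes. A rough sketch of that arithmetic, assuming a standard data-parallel setup (the function and argument names below are illustrative, not taken from train.py):

import math

def steps_per_epoch(dataset_samples: int, batch_size: int, world_size: int,
                    gradient_accumulation_steps: int) -> int:
    # Each optimizer step consumes batch_size * world_size * gradient_accumulation_steps samples
    samples_per_step = batch_size * world_size * gradient_accumulation_steps
    return math.ceil(dataset_samples / samples_per_step)

# e.g. the new default of 512 samples on 2 GPUs with a per-device batch size of 8 and no accumulation:
print(steps_per_epoch(512, batch_size=8, world_size=2, gradient_accumulation_steps=1))  # 32

Assuming the script exposes main() through a CLI that maps keyword arguments to flags (as the Param annotations suggest), this would typically be driven by something like --dataset dummy --dataset_samples 128; that flag mapping is an inference, not part of this diff.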
