Skip to content

Commit

Permalink
Add sweep command
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed May 10, 2024
1 parent a8e3e8d commit c99e534
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
17 changes: 17 additions & 0 deletions w2s/sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from argparse import ArgumentParser

from w2s.ds_registry import _REGISTRY
from w2s.train import TrainConfig, train


def main():
parser = ArgumentParser()
parser.add_argument("rank", type=int)
args = parser.parse_args()

DATASETS = sorted(_REGISTRY.keys())
train(TrainConfig("Qwen/Qwen1.5-0.5B", DATASETS[args.rank]))


if __name__ == "__main__":
main()
5 changes: 2 additions & 3 deletions w2s/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def move_best_ckpt(trainer: Trainer):
print(f"Best model (loss {perf:.3f}) saved at: {dest}")


def main():
cfg = parse(TrainConfig)
def train(cfg: TrainConfig):
lora_cfg = LoraConfig(target_modules=LORA_MODULES)

STRONG_NAME = "meta-llama/Meta-Llama-3-8B"
Expand Down Expand Up @@ -241,4 +240,4 @@ def strong_processor(examples):


if __name__ == "__main__":
main()
train(parse(TrainConfig))

0 comments on commit c99e534

Please sign in to comment.