Skip to content

Commit

Permalink
enable to set save path
Browse files Browse the repository at this point in the history
  • Loading branch information
lintangsutawika committed Jan 16, 2023
1 parent bcb70f8 commit 899add0
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions utils/batch_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def view_data(
args,
neox_args,
batch_fn: callable = None,
save_path: str = None,
):
# fake MPU setup (needed to init dataloader without actual GPUs or parallelism)
mpu.mock_model_parallel()
Expand All @@ -37,12 +38,14 @@ def view_data(

if args.mode == "save":
# save full batches for each step in the range (WARNING: this may consume lots of storage!)
np.save(f"./dump_data/batch{i}_bs{neox_args.train_micro_batch_size_per_gpu}", batch)
filename = f"batch{i}_bs{neox_args.train_micro_batch_size_per_gpu}"
np.save(os.path.join(save_path, filename), batch)
elif args.mode == "custom":
# dump user_defined statistic to a jsonl file (save_fn must return a dict)
log = batch_fn(batch, i)

with open("./dump_data/stats.jsonl", "w+") as f:
filename = "stats.jsonl"
with open(os.path.join(save_path, filename), "w+") as f:
f.write(json.dumps(log) + "\n")
else:
raise ValueError(f'mode={mode} not acceptable--please pass either "save" or "custom" !')
Expand Down Expand Up @@ -74,6 +77,12 @@ def view_data(
choices=["save", "custom"],
help="Choose mode: 'save' to log all batches, and 'custom' to use user-defined statistic"
)
parser.add_argument(
"--save_path",
type=str,
default=0,
help="Save path for files"
)
args = parser.parse_known_args()[0]

# init neox args
Expand All @@ -86,10 +95,11 @@ def save_fn(batch: np.array, iteration: int):
# define your own logic here
return {"iteration": iteration, "text": None}

os.makedirs(args.save_path, exist_ok=True)

view_data(
args,
neox_args,
batch_fn=save_fn,
save_path=args.save_path,
)

0 comments on commit 899add0

Please sign in to comment.