From 46d179cc0d8cd357c8448c37b8f356b591dafe49 Mon Sep 17 00:00:00 2001
From: Jacob Hatef <74274091+jahatef@users.noreply.github.com>
Date: Wed, 21 Feb 2024 13:46:22 -0500
Subject: [PATCH] Memory profiling (#1153)

* Fixes distributed tests, and skips tests that are broken.

* memory profiling for gpt-neox. Only works for pp=0, pp=1+ needs DS commits.

* Update NeoXArgs docs automatically

* adds memory profiling for pipeline parallel

* Update NeoXArgs docs automatically

* fix spacing

* Update NeoXArgs docs automatically

* fix spacing again

* Update NeoXArgs docs automatically

* get rid of unwanted changes

* Update NeoXArgs docs automatically

* get rid of file

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* add nsight systems support

* remove tests changes again

* Update NeoXArgs docs automatically

* add tests

* Update NeoXArgs docs automatically

* Update training.py

* Update NeoXArgs docs automatically

* Add assertion message

* pre-commit

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions
Co-authored-by: Quentin Anthony
---
 configs/neox_arguments.md            | 50 ++++++++++++++++++-
 megatron/neox_arguments/neox_args.py | 33 +++++++++++++
 megatron/training.py                 | 74 ++++++++++++++++++++++++++++
 3 files changed, 156 insertions(+), 1 deletion(-)

diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 022fa6218..46cc91c1c 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 0eb8b39
+    Default = 8669123

     current git hash of repository

@@ -199,6 +199,54 @@ Logging Arguments


+- **memory_profiling**: bool
+
+    Default = False
+
+    Whether to take a memory snapshot of the model. Useful for debugging memory issues.
+
+
+
+- **memory_profiling_path**: str
+
+    Default = None
+
+    Path to save memory snapshot to.
+
+
+
+- **profile**: bool
+
+    Default = False
+
+    Enable nsys profiling. When using this option,
+    nsys options should be specified in commandline.
+    An example nsys commandline is
+    ```
+    nsys profile -s none -t nvtx,cuda -o <path/to/output_file>
+    --force-overwrite true
+    --capture-range=cudaProfilerApi
+    --capture-range-end=stop
+    ```
+
+
+
+- **profile_step_start**: int
+
+    Default = 10
+
+    Step to start profiling at.
+
+
+
+- **profile_step_stop**: int
+
+    Default = 12
+
+    Step to stop profiling at.
+
+
+
 ## NeoXArgsModel

 Model Arguments

diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index db1ffcc1f..7f3afffa5 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -573,6 +573,39 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     Whether to offload the buffered gradients to cpu when measuring gradient noise scale.
     """

+    memory_profiling: bool = False
+    """
+    Whether to take a memory snapshot of the model. Useful for debugging memory issues.
+    """
+
+    memory_profiling_path: str = None
+    """
+    Path to save memory snapshot to.
+    """
+
+    profile: bool = False
+    """
+    Enable nsys profiling. When using this option,
+    nsys options should be specified in commandline.
+    An example nsys commandline is
+    ```
+    nsys profile -s none -t nvtx,cuda -o <path/to/output_file>
+    --force-overwrite true
+    --capture-range=cudaProfilerApi
+    --capture-range-end=stop
+    ```
+    """
+
+    profile_step_start: int = 10
+    """
+    Step to start profiling at.
+    """
+
+    profile_step_stop: int = 12
+    """
+    Step to stop profiling at.
+ """ + @dataclass class NeoXArgsOther(NeoXArgsTemplate): diff --git a/megatron/training.py b/megatron/training.py index 5df179821..216878678 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -56,6 +56,9 @@ ) from megatron.model.gpt2_model import cross_entropy +from pickle import dump +import os + def mup_weights_reinit(neox_args, model): def has_method(o, name): @@ -368,6 +371,8 @@ def forward_step( return model.eval_batch(data_iterator, return_logits=return_logits) # Get the batch. + if neox_args.memory_profiling and neox_args.it: + torch.cuda.nvtx.range_push(f"Get batch") if timers is not None: timers("batch generator").start() tokens, labels, loss_mask, attention_mask, position_ids = get_batch( @@ -376,7 +381,11 @@ def forward_step( if timers is not None: timers("batch generator").stop() + if neox_args.memory_profiling: + torch.cuda.nvtx.range_pop() + if neox_args.memory_profiling: + torch.cuda.nvtx.range_push(f"Forward pass") outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args) if ( is_train @@ -388,6 +397,8 @@ def forward_step( loss = cross_entropy( outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy ) + if neox_args.memory_profiling: + torch.cuda.nvtx.range_pop() if return_logits: return loss, outputs return loss @@ -628,6 +639,15 @@ def get_learning_rate_scheduler(optimizer, neox_args): def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): + """Setup memory profiler""" + if neox_args.memory_profiling: + torch.cuda.memory._record_memory_history( + True, + # keep a maximum 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + trace_alloc_record_context=True, + ) + """Setup model and optimizer.""" model = get_model(neox_args=neox_args, use_cache=use_cache) optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args) @@ -727,6 +747,13 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) reduced_loss = train_step_pipe( neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator ) + if ( + neox_args.memory_profiling + and neox_args.iteration >= neox_args.profile_step_start + and neox_args.iteration <= neox_args.profile_step_stop + and torch.distributed.get_rank() == 0 + ): + save_snapshot(neox_args) else: losses = [] for _ in range(neox_args.gradient_accumulation_steps): @@ -742,6 +769,12 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) timers("forward").stop() losses.append(loss) # Calculate gradients, reduce across processes, and clip. + if ( + neox_args.profiling + and neox_args.iteration >= neox_args.profile_step_start + and neox_args.iteration <= neox_args.profile_step_stop + ): + torch.cuda.nvtx.range_push(f"Backward pass") timers("backward").start() backward_step( neox_args=neox_args, @@ -751,13 +784,38 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) loss=loss, ) timers("backward").stop() + if ( + neox_args.profiling + and neox_args.iteration >= neox_args.profile_step_start + and neox_args.iteration <= neox_args.profile_step_stop + ): + torch.cuda.nvtx.range_pop() # Update parameters. 
+            if (
+                neox_args.profile
+                and neox_args.iteration >= neox_args.profile_step_start
+                and neox_args.iteration <= neox_args.profile_step_stop
+            ):
+                torch.cuda.nvtx.range_push(f"Optimizer step")
             timers("optimizer").start()
             if neox_args.deepspeed:
                 model.step()
             else:
                 raise ValueError("Must be using deepspeed to run neox")
             timers("optimizer").stop()
+            if (
+                neox_args.profile
+                and neox_args.iteration >= neox_args.profile_step_start
+                and neox_args.iteration <= neox_args.profile_step_stop
+            ):
+                torch.cuda.nvtx.range_pop()
+            if (
+                neox_args.profile
+                and neox_args.iteration >= neox_args.profile_step_start
+                and neox_args.iteration <= neox_args.profile_step_stop
+                and torch.distributed.get_rank() == 0
+            ):
+                save_snapshot(neox_args)
         reduced_loss = {
             "lm_loss": reduce_losses(losses).mean()
         }  # reduces losses across machines for logging
@@ -819,6 +877,8 @@ def train(
     # to monitor if we've skipped many iterations in a row and trigger an early exit
     overflow_monitor = OverflowMonitor(optimizer)
     while iteration < neox_args.train_iters:
+        if neox_args.profile and iteration == neox_args.profile_step_start:
+            torch.cuda.cudart().cudaProfilerStart()
         loss_dict, skipped_iter = train_step(
             neox_args=neox_args,
             timers=timers,
@@ -827,6 +887,8 @@ def train(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
         )
+        if neox_args.profile and iteration == neox_args.profile_step_stop:
+            torch.cuda.cudart().cudaProfilerStop()
         iteration += 1
         neox_args.iteration = iteration
         if neox_args.precision == "fp16":
@@ -1033,3 +1095,15 @@ def evaluate_and_print_results(
     print_rank_0("-" * length)
     print_rank_0(string)
     print_rank_0("-" * length)
+
+
+def save_snapshot(neox_args):
+    assert (
+        neox_args.memory_profiling_path is not None
+    ), "Must pass memory_profiling_path config arg to use profiling"
+    snapshot = torch.cuda.memory._snapshot()
+    snapshot_path = os.path.join(neox_args.memory_profiling_path)
+    if not os.path.exists(snapshot_path):
+        os.makedirs(snapshot_path)
+    with open(os.path.join(snapshot_path, "mem_snapshot.pickle"), "wb") as f:
+        dump(snapshot, f)
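
Note on inspecting the output (not part of the patch): once `memory_profiling` is enabled, `save_snapshot()` above writes `mem_snapshot.pickle` under `memory_profiling_path`. A minimal sketch of how that file could be rendered offline, assuming a PyTorch build that ships the private `torch.cuda._memory_viz` helpers; the file paths here are placeholders:

```
# Render a pickled allocator snapshot (written by save_snapshot) to an HTML timeline.
import pickle

from torch.cuda._memory_viz import trace_plot  # private helper; API may change between releases

with open("mem_snapshot.pickle", "rb") as f:  # file produced under memory_profiling_path
    snapshot = pickle.load(f)

with open("mem_snapshot.html", "w") as f:
    f.write(trace_plot(snapshot))  # open the resulting HTML file in a browser
```

The same pickle can also be loaded interactively at https://pytorch.org/memory_viz.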