Memory profiling (#1153)
* Fixes distributed tests, and skips tests that are broken.

* memory profiling for gpt-neox. Only works for pp=0, pp=1+ needs DS commits.

* Update NeoXArgs docs automatically

* adds memory profiling for pipeline parallel

* Update NeoXArgs docs automatically

* fix spacing

* Update NeoXArgs docs automatically

* fix spacing again

* Update NeoXArgs docs automatically

* get rid of unwanted changes

* Update NeoXArgs docs automatically

* get rid of file

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* add nsight systems support

* remove tests changes again

* Update NeoXArgs docs automatically

* add tests

* Update NeoXArgs docs automatically

* Update training.py

* Update NeoXArgs docs automatically

* Add assertion message

* pre-commit

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
3 people committed Feb 21, 2024
1 parent 412cf6e commit 46d179c
Showing 3 changed files with 156 additions and 1 deletion.
50 changes: 49 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 0eb8b39
Default = 8669123

current git hash of repository

@@ -199,6 +199,54 @@ Logging Arguments



- **memory_profiling**: bool

Default = False

Whether to take a memory snapshot of the model. Useful for debugging memory issues.



- **memory_profiling_path**: str

Default = None

Path to save memory snapshot to.



- **profile**: bool

Default = False

Enable nsys profiling. When using this option,
nsys options should be specified on the command line.
An example nsys command line is:
```
nsys profile -s none -t nvtx,cuda -o <path/to/output_file> \
--force-overwrite true \
--capture-range=cudaProfilerApi \
--capture-range-end=stop
```



- **profile_step_start**: int

Default = 10

Step to start profiling at.



- **profile_step_stop**: int

Default = 12

Step to stop profiling at.
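
For readers new to this workflow, the sketch below shows the pattern these options drive inside the training loop: nsys launches the process but stays idle until it calls `cudaProfilerStart`, then stops capturing at `cudaProfilerStop`, so only the chosen step range lands in the report. This is a standalone illustration rather than code from this commit; `PROFILE_STEP_START`, `PROFILE_STEP_STOP`, and `dummy_step` are made-up stand-ins for `profile_step_start`, `profile_step_stop`, and a real training step.

```
import torch

# Stand-ins for the profile_step_start / profile_step_stop config args (illustrative values).
PROFILE_STEP_START, PROFILE_STEP_STOP = 10, 12


def dummy_step():
    # Placeholder GPU work so the capture window contains visible kernels.
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b


for step in range(20):
    if step == PROFILE_STEP_START:
        # With `nsys profile --capture-range=cudaProfilerApi`, capture begins at this call.
        torch.cuda.cudart().cudaProfilerStart()
    torch.cuda.nvtx.range_push("train step")  # appears as a named range on the nsys timeline
    dummy_step()
    torch.cuda.nvtx.range_pop()
    if step == PROFILE_STEP_STOP:
        # With `--capture-range-end=stop`, nsys stops collecting at this call.
        torch.cuda.cudart().cudaProfilerStop()
```

Launched under the nsys command shown above (with the training command appended), only the work between those two calls is recorded.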



## NeoXArgsModel

Model Arguments
33 changes: 33 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -573,6 +573,39 @@ class NeoXArgsLogging(NeoXArgsTemplate):
Whether to offload the buffered gradients to cpu when measuring gradient noise scale.
"""

memory_profiling: bool = False
"""
Whether to take a memory snapshot of the model. Useful for debugging memory issues.
"""

memory_profiling_path: str = None
"""
Path to save memory snapshot to.
"""

profile: bool = False
"""
Enable nsys profiling. When using this option,
nsys options should be specified on the command line.
An example nsys command line is:
```
nsys profile -s none -t nvtx,cuda -o <path/to/output_file> \
--force-overwrite true \
--capture-range=cudaProfilerApi \
--capture-range-end=stop
```
"""

profile_step_start: int = 10
"""
Step to start profiling at.
"""

profile_step_stop: int = 12
"""
Step to stop profiling at.
"""


@dataclass
class NeoXArgsOther(NeoXArgsTemplate):
74 changes: 74 additions & 0 deletions megatron/training.py
@@ -56,6 +56,9 @@
)
from megatron.model.gpt2_model import cross_entropy

from pickle import dump
import os


def mup_weights_reinit(neox_args, model):
def has_method(o, name):
@@ -368,6 +371,8 @@ def forward_step(
return model.eval_batch(data_iterator, return_logits=return_logits)

# Get the batch.
if neox_args.memory_profiling and neox_args.iteration:
torch.cuda.nvtx.range_push(f"Get batch")
if timers is not None:
timers("batch generator").start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
@@ -376,7 +381,11 @@ def forward_step(

if timers is not None:
timers("batch generator").stop()
if neox_args.memory_profiling:
torch.cuda.nvtx.range_pop()

if neox_args.memory_profiling:
torch.cuda.nvtx.range_push(f"Forward pass")
outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if (
is_train
@@ -388,6 +397,8 @@ def forward_step(
loss = cross_entropy(
outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
)
if neox_args.memory_profiling:
torch.cuda.nvtx.range_pop()
if return_logits:
return loss, outputs
return loss
@@ -628,6 +639,15 @@ def get_learning_rate_scheduler(optimizer, neox_args):


def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
"""Setup memory profiler"""
if neox_args.memory_profiling:
torch.cuda.memory._record_memory_history(
True,
# keep a maximum 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000,
trace_alloc_record_context=True,
)

"""Setup model and optimizer."""
model = get_model(neox_args=neox_args, use_cache=use_cache)
optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
@@ -727,6 +747,13 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
reduced_loss = train_step_pipe(
neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
)
if (
neox_args.memory_profiling
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
and torch.distributed.get_rank() == 0
):
save_snapshot(neox_args)
else:
losses = []
for _ in range(neox_args.gradient_accumulation_steps):
@@ -742,6 +769,12 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
timers("forward").stop()
losses.append(loss)
# Calculate gradients, reduce across processes, and clip.
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_push(f"Backward pass")
timers("backward").start()
backward_step(
neox_args=neox_args,
@@ -751,13 +784,38 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
loss=loss,
)
timers("backward").stop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_pop()
# Update parameters.
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_push(f"Optimizer step")
timers("optimizer").start()
if neox_args.deepspeed:
model.step()
else:
raise ValueError("Must be using deepspeed to run neox")
timers("optimizer").stop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_pop()
if (
neox_args.memory_profiling
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
and torch.distributed.get_rank() == 0
):
save_snapshot(neox_args)
reduced_loss = {
"lm_loss": reduce_losses(losses).mean()
} # reduces losses across machines for logging
@@ -819,6 +877,8 @@ def train(
# to monitor if we've skipped many iterations in a row and trigger an early exit
overflow_monitor = OverflowMonitor(optimizer)
while iteration < neox_args.train_iters:
if neox_args.profile and iteration == neox_args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
loss_dict, skipped_iter = train_step(
neox_args=neox_args,
timers=timers,
@@ -827,6 +887,8 @@
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
if neox_args.profile and iteration == neox_args.profile_step_stop:
torch.cuda.cudart().cudaProfilerStop()
iteration += 1
neox_args.iteration = iteration
if neox_args.precision == "fp16":
@@ -1033,3 +1095,15 @@ def evaluate_and_print_results(
print_rank_0("-" * length)
print_rank_0(string)
print_rank_0("-" * length)


def save_snapshot(neox_args):
assert (
neox_args.memory_profiling_path is not None
), "Must pass memory_profiling_path config arg to use profiling"
snapshot = torch.cuda.memory._snapshot()
snapshot_path = os.path.join(neox_args.memory_profiling_path)
if not os.path.exists(snapshot_path):
os.makedirs(snapshot_path)
with open(os.path.join(snapshot_path, "mem_snapshot.pickle"), "wb") as f:
dump(snapshot, f)
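
As a companion to `save_snapshot`, the sketch below shows one way to inspect the resulting `mem_snapshot.pickle` offline; the same file can also be loaded into the interactive viewer at https://pytorch.org/memory_viz. This is an illustration under stated assumptions, not part of the commit: the snapshot layout (`segments`, `device_traces`) follows what `torch.cuda.memory._snapshot()` returns in recent PyTorch releases, which is a private API and may change, and `<memory_profiling_path>` stands in for whatever the config arg was set to.

```
import os
import pickle

# Assumed location: replace <memory_profiling_path> with the configured memory_profiling_path.
snapshot_file = os.path.join("<memory_profiling_path>", "mem_snapshot.pickle")

with open(snapshot_file, "rb") as f:
    snapshot = pickle.load(f)

# Allocator segments: how much caching-allocator memory each segment holds and how much is in use.
for seg in snapshot["segments"]:
    print(seg["device"], seg["segment_type"], seg["total_size"], seg["allocated_size"])

# Per-device alloc/free events recorded because _record_memory_history() was enabled at setup.
print("trace events:", sum(len(trace) for trace in snapshot.get("device_traces", [])))
```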
