Memory profiling #1153

Merged: 32 commits from memory_profiling into main, Feb 21, 2024

Commits
f5fd54c  Fixes distributed tests, and skips tests that are broken. (jahatef, Feb 14, 2024)
4a4a934  Merge branch 'main' of github.com:EleutherAI/gpt-neox into main (jahatef, Feb 18, 2024)
f63593b  memory profiling for gpt-neox. Only works for pp=0, pp=1+ needs DS co… (jahatef, Feb 20, 2024)
4ed9d42  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
89efc48  adds memory profiling for pipeline parallel (jahatef, Feb 21, 2024)
95f31f0  Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int… (jahatef, Feb 21, 2024)
9551afe  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
4135743  fix spacing (jahatef, Feb 21, 2024)
7b0cdaf  Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int… (jahatef, Feb 21, 2024)
45aea7a  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
3bff276  fix spacing again (jahatef, Feb 21, 2024)
2452697  Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int… (jahatef, Feb 21, 2024)
d9c7e4b  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
7af1c9d  get rid of unwanted changes (jahatef, Feb 21, 2024)
47f76af  Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int… (jahatef, Feb 21, 2024)
7994909  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
a2893db  get rid of file (jahatef, Feb 21, 2024)
db8b70b  Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int… (jahatef, Feb 21, 2024)
80b1e30  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
7467632  Merge branch 'main' into memory_profiling (Quentin-Anthony, Feb 21, 2024)
5c51c43  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
20bc950  add nsight systems support (jahatef, Feb 21, 2024)
fd0b471  Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int… (jahatef, Feb 21, 2024)
87bca9d  remove tests changes again (jahatef, Feb 21, 2024)
65ce859  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
49cf95d  add tests (jahatef, Feb 21, 2024)
ab8126d  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
ae2c61d  Update training.py (jahatef, Feb 21, 2024)
21eba94  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
edfcdaf  Add assertion message (Quentin-Anthony, Feb 21, 2024)
8669123  pre-commit (Quentin-Anthony, Feb 21, 2024)
80aa4cb  Update NeoXArgs docs automatically (invalid-email-address, Feb 21, 2024)
Files changed
50 changes: 49 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 0eb8b39
Default = 8669123

current git hash of repository

@@ -199,6 +199,54 @@ Logging Arguments



- **memory_profiling**: bool

Default = False

Whether to take a memory snapshot of the model. Useful for debugging memory issues.



- **memory_profiling_path**: str

Default = None

Path to save memory snapshot to.



- **profile**: bool

Default = False

Enable nsys profiling. When using this option,
nsys options should be specified on the command line.
An example nsys command line:
```
nsys profile -s none -t nvtx,cuda -o <path/to/output_file> \
  --force-overwrite true \
  --capture-range=cudaProfilerApi \
  --capture-range-end=stop
```



- **profile_step_start**: int

Default = 10

Step to start profiling at.



- **profile_step_stop**: int

Default = 12

Step to stop profiling at.



## NeoXArgsModel

Model Arguments
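As a side note, the `--capture-range=cudaProfilerApi` option above means nsys stays idle until the process calls `cudaProfilerStart`, which is exactly what `profile_step_start`/`profile_step_stop` drive. A minimal standalone sketch of that pattern (hypothetical step window and toy model, not code from this PR):

```
import torch

PROFILE_STEP_START, PROFILE_STEP_STOP = 10, 12  # hypothetical step window

model = torch.nn.Linear(1024, 1024).cuda()
data = torch.randn(8, 1024, device="cuda")

for step in range(20):
    if step == PROFILE_STEP_START:
        torch.cuda.cudart().cudaProfilerStart()  # nsys begins capturing here
    model(data).sum().backward()
    if step == PROFILE_STEP_STOP:
        # with --capture-range-end=stop, the nsys capture ends here
        torch.cuda.cudart().cudaProfilerStop()
```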
33 changes: 33 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -573,6 +573,39 @@ class NeoXArgsLogging(NeoXArgsTemplate):
Whether to offload the buffered gradients to cpu when measuring gradient noise scale.
"""

memory_profiling: bool = False
"""
Whether to take a memory snapshot of the model. Useful for debugging memory issues.
"""

memory_profiling_path: str = None
"""
Path to save memory snapshot to.
"""

profile: bool = False
"""
Enable nsys profiling. When using this option,
nsys options should be specified on the command line.
An example nsys command line:
```
nsys profile -s none -t nvtx,cuda -o <path/to/output_file> \
  --force-overwrite true \
  --capture-range=cudaProfilerApi \
  --capture-range-end=stop
```
"""

profile_step_start: int = 10
"""
Step to start profiling at.
"""

profile_step_stop: int = 12
"""
Step to stop profiling at.
"""


@dataclass
class NeoXArgsOther(NeoXArgsTemplate):
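For background, the flags above drive PyTorch's (private) allocator-history API. A minimal sketch of that flow outside gpt-neox, using the same calls the PR makes (newer PyTorch releases expose a slightly different `_record_memory_history(enabled=...)` signature):

```
import torch
from pickle import dump

# Begin recording allocator events (signature as used in this PR).
torch.cuda.memory._record_memory_history(
    True,
    trace_alloc_max_entries=100000,  # ring buffer of alloc/free events
    trace_alloc_record_context=True,  # record a Python stack per allocation
)

x = torch.randn(4096, 4096, device="cuda")  # this allocation is now traced

# Snapshot the allocator state plus the recorded trace and pickle it; the
# file can be dropped onto https://pytorch.org/memory_viz for inspection.
with open("mem_snapshot.pickle", "wb") as f:
    dump(torch.cuda.memory._snapshot(), f)
```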
74 changes: 74 additions & 0 deletions megatron/training.py
@@ -56,6 +56,9 @@
)
from megatron.model.gpt2_model import cross_entropy

from pickle import dump
import os


def mup_weights_reinit(neox_args, model):
def has_method(o, name):
@@ -368,6 +371,8 @@ def forward_step(
return model.eval_batch(data_iterator, return_logits=return_logits)

# Get the batch.
if neox_args.memory_profiling:
torch.cuda.nvtx.range_push("Get batch")
if timers is not None:
timers("batch generator").start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
@@ -376,7 +381,11 @@

if timers is not None:
timers("batch generator").stop()
if neox_args.memory_profiling:
torch.cuda.nvtx.range_pop()

if neox_args.memory_profiling:
torch.cuda.nvtx.range_push("Forward pass")
outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if (
is_train
@@ -388,6 +397,8 @@
loss = cross_entropy(
outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
)
if neox_args.memory_profiling:
torch.cuda.nvtx.range_pop()
if return_logits:
return loss, outputs
return loss
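The NVTX ranges pushed above are close to free when no profiler is attached; under nsys they group the enclosed kernel launches into a labeled region on the timeline. A self-contained sketch of the pattern:

```
import torch

lin = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

torch.cuda.nvtx.range_push("Forward pass")  # open a named region on the nsys timeline
y = lin(x)
torch.cuda.nvtx.range_pop()  # close the innermost open region
```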
@@ -628,6 +639,15 @@ def get_learning_rate_scheduler(optimizer, neox_args):


def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
"""Setup memory profiler"""
if neox_args.memory_profiling:
torch.cuda.memory._record_memory_history(
True,
# keep a maximum 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000,
trace_alloc_record_context=True,
)

"""Setup model and optimizer."""
model = get_model(neox_args=neox_args, use_cache=use_cache)
optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
@@ -727,6 +747,13 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
reduced_loss = train_step_pipe(
neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
)
if (
neox_args.memory_profiling
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
and torch.distributed.get_rank() == 0
):
save_snapshot(neox_args)
else:
losses = []
for _ in range(neox_args.gradient_accumulation_steps):
@@ -742,6 +769,12 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
timers("forward").stop()
losses.append(loss)
# Calculate gradients, reduce across processes, and clip.
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_push("Backward pass")
timers("backward").start()
backward_step(
neox_args=neox_args,
@@ -751,13 +784,38 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler)
loss=loss,
)
timers("backward").stop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_pop()
# Update parameters.
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_push("Optimizer step")
timers("optimizer").start()
if neox_args.deepspeed:
model.step()
else:
raise ValueError("Must be using deepspeed to run neox")
timers("optimizer").stop()
if (
neox_args.profile
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
):
torch.cuda.nvtx.range_pop()
if (
neox_args.memory_profiling
and neox_args.iteration >= neox_args.profile_step_start
and neox_args.iteration <= neox_args.profile_step_stop
and torch.distributed.get_rank() == 0
):
save_snapshot(neox_args)
reduced_loss = {
"lm_loss": reduce_losses(losses).mean()
} # reduces losses across machines for logging
@@ -819,6 +877,8 @@ def train(
# to monitor if we've skipped many iterations in a row and trigger an early exit
overflow_monitor = OverflowMonitor(optimizer)
while iteration < neox_args.train_iters:
if neox_args.profile and iteration == neox_args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
loss_dict, skipped_iter = train_step(
neox_args=neox_args,
timers=timers,
@@ -827,6 +887,8 @@
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
if neox_args.profile and iteration == neox_args.profile_step_stop:
torch.cuda.cudart().cudaProfilerStop()
iteration += 1
neox_args.iteration = iteration
if neox_args.precision == "fp16":
@@ -1033,3 +1095,15 @@ def evaluate_and_print_results(
print_rank_0("-" * length)
print_rank_0(string)
print_rank_0("-" * length)


def save_snapshot(neox_args):
assert (
neox_args.memory_profiling_path is not None
), "Must pass memory_profiling_path config arg to use profiling"
snapshot = torch.cuda.memory._snapshot()
snapshot_path = neox_args.memory_profiling_path
if not os.path.exists(snapshot_path):
os.makedirs(snapshot_path)
with open(os.path.join(snapshot_path, "mem_snapshot.pickle"), "wb") as f:
dump(snapshot, f)
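
To inspect the resulting mem_snapshot.pickle offline, the file can be dragged into https://pytorch.org/memory_viz; recent PyTorch releases also ship a `python -m torch.cuda._memory_viz` CLI. A hedged sketch of reading it directly (the snapshot layout is a private PyTorch format, so field names may differ across releases):

```
from pickle import load

with open("mem_snapshot.pickle", "rb") as f:
    snapshot = load(f)

# "segments" lists regions obtained from cudaMalloc; each segment is split
# into blocks that are in use ("active_allocated") or cached for reuse.
reserved = sum(seg["total_size"] for seg in snapshot["segments"])
active = sum(
    blk["size"]
    for seg in snapshot["segments"]
    for blk in seg["blocks"]
    if blk["state"] == "active_allocated"
)
print(f"reserved: {reserved / 2**20:.1f} MiB, active: {active / 2**20:.1f} MiB")
```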