Memory profiling #1153

Merged: 32 commits, Feb 21, 2024
Changes from 1 commit
Commits (32)
f5fd54c
Fixes distributed tests, and skips tests that are broken.
jahatef Feb 14, 2024
4a4a934
Merge branch 'main' of github.com:EleutherAI/gpt-neox into main
jahatef Feb 18, 2024
f63593b
memory profiling for gpt-neox. Only works for pp=0, pp=1+ needs DS co…
jahatef Feb 20, 2024
4ed9d42
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
89efc48
adds memory profiling for pipeline parallel
jahatef Feb 21, 2024
95f31f0
Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int…
jahatef Feb 21, 2024
9551afe
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
4135743
fix spacing
jahatef Feb 21, 2024
7b0cdaf
Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int…
jahatef Feb 21, 2024
45aea7a
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
3bff276
fix spacing again
jahatef Feb 21, 2024
2452697
Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int…
jahatef Feb 21, 2024
d9c7e4b
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
7af1c9d
get rid of unwanted changes
jahatef Feb 21, 2024
47f76af
Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int…
jahatef Feb 21, 2024
7994909
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
a2893db
get rid of file
jahatef Feb 21, 2024
db8b70b
Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int…
jahatef Feb 21, 2024
80b1e30
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
7467632
Merge branch 'main' into memory_profiling
Quentin-Anthony Feb 21, 2024
5c51c43
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
20bc950
add nsight systems support
jahatef Feb 21, 2024
fd0b471
Merge branch 'memory_profiling' of github.com:EleutherAI/gpt-neox int…
jahatef Feb 21, 2024
87bca9d
remove tests changes again
jahatef Feb 21, 2024
65ce859
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
49cf95d
add tests
jahatef Feb 21, 2024
ab8126d
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
ae2c61d
Update training.py
jahatef Feb 21, 2024
21eba94
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
edfcdaf
Add assertion message
Quentin-Anthony Feb 21, 2024
8669123
pre-commit
Quentin-Anthony Feb 21, 2024
80aa4cb
Update NeoXArgs docs automatically
invalid-email-address Feb 21, 2024
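
The commits above wire memory profiling and Nsight Systems support into the training loop, but only the first commit's diff is rendered below. As a rough sketch of the underlying mechanics (the flag names, step bounds, and snapshot path here are assumptions for illustration, not the exact neox_args added by this PR), PyTorch's CUDA memory-history snapshots and NVTX ranges are typically hooked into a training step like this:

    # Illustrative sketch only: shows how PyTorch memory-history snapshots and
    # Nsight Systems (NVTX) ranges are commonly wrapped around a training step.
    # The attribute names ("memory_profiling", "profile", the step bounds, and
    # the snapshot filename) are assumptions, not the arguments added by this PR.
    import torch


    def train_step_with_profiling(run_step, step, args):
        """run_step: callable performing one forward/backward/optimizer step."""
        if getattr(args, "memory_profiling", False) and step == getattr(args, "profile_step_start", 10):
            # start recording CUDA caching-allocator events (available in recent PyTorch releases)
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if getattr(args, "profile", False):
            torch.cuda.nvtx.range_push(f"train_step_{step}")

        loss = run_step()

        if getattr(args, "profile", False):
            torch.cuda.nvtx.range_pop()

        if getattr(args, "memory_profiling", False) and step == getattr(args, "profile_step_stop", 12):
            # dump a snapshot that can be inspected with pytorch.org/memory_viz
            rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
            torch.cuda.memory._dump_snapshot(f"mem_snapshot_rank{rank}.pickle")
            torch.cuda.memory._record_memory_history(enabled=None)

        return loss
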
Fixes distributed tests, and skips tests that are broken.
jahatef committed Feb 14, 2024
commit f5fd54c19e457919755f90be014bdb5d8c128b60
429 changes: 326 additions & 103 deletions tests/common.py

Large diffs are not rendered by default.
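
tests/common.py is where this commit swaps the @distributed_test decorator for a DistributedTest base class (the same pattern DeepSpeed's test suite uses), but that diff is collapsed above. The sketch below is only an illustration of the pattern under that assumption, not the actual contents of tests/common.py: the pytest_runtest_call hook in the new tests/conftest.py (shown further down) instantiates the test class and calls it with the pytest request, and the class launches one worker process per rank.

    # Illustrative sketch only: NOT the tests/common.py added by this PR.
    # Shows the general shape of a DistributedTest base class. conftest.py's
    # pytest_runtest_call hook sees is_dist_test, instantiates the class, and
    # calls it; __call__ then runs the test method once per rank.
    import os

    import torch.distributed as dist
    import torch.multiprocessing as mp


    class DistributedTest:
        is_dist_test = True  # flag checked by pytest_runtest_call in conftest.py
        world_size = 2

        def __call__(self, request):
            # launch one process per rank, each running the requested test method
            mp.spawn(
                self._dist_run,
                args=(request.function.__name__,),
                nprocs=self.world_size,
                join=True,
            )

        def _dist_run(self, rank, test_name):
            os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
            os.environ.setdefault("MASTER_PORT", "29500")
            dist.init_process_group("gloo", rank=rank, world_size=self.world_size)
            try:
                getattr(self, test_name)()
            finally:
                dist.destroy_process_group()
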

86 changes: 86 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,86 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# tests directory-specific settings - this file is run automatically by pytest before any tests are run

import sys
import pytest
import os
from os.path import abspath, dirname, join
import torch
import warnings

# Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small)
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'

# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)


def pytest_configure(config):
    # config.option.color = "yes"
    # config.option.durations = 0
    # config.option.durations_min = 1
    config.option.verbose = True


def pytest_addoption(parser):
    parser.addoption("--torch_ver", default=None, type=str)
    parser.addoption("--cuda_ver", default=None, type=str)


def validate_version(expected, found):
    version_depth = expected.count('.') + 1
    found = '.'.join(found.split('.')[:version_depth])
    return found == expected


@pytest.fixture(scope="session", autouse=True)
def check_environment(pytestconfig):
    expected_torch_version = pytestconfig.getoption("torch_ver")
    expected_cuda_version = pytestconfig.getoption("cuda_ver")
    if expected_torch_version is None:
        warnings.warn(
            "Running test without verifying torch version, please provide an expected torch version with --torch_ver")
    elif not validate_version(expected_torch_version, torch.__version__):
        pytest.exit(
            f"expected torch version {expected_torch_version} did not match found torch version {torch.__version__}",
            returncode=2)
    if expected_cuda_version is None:
        warnings.warn(
            "Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver")
    elif not validate_version(expected_cuda_version, torch.version.cuda):
        pytest.exit(
            f"expected cuda version {expected_cuda_version} did not match found cuda version {torch.version.cuda}",
            returncode=2)


# Override of pytest "runtest" for DistributedTest class
# This hook is run before the default pytest_runtest_call
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # We want to use our own launching function for distributed tests
    if getattr(item.cls, "is_dist_test", False):
        dist_test_class = item.cls()
        dist_test_class(item._request)
        item.runtest = lambda: True  # Dummy function so test is not run twice


# We allow DistributedTest to reuse distributed environments. When the last
# test for a class is run, we want to make sure those distributed environments
# are destroyed.
def pytest_runtest_teardown(item, nextitem):
    if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
        dist_test_class = item.cls()
        for num_procs, pool in dist_test_class._pool_cache.items():
            dist_test_class._close_pool(pool, num_procs, force=True)


@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    if getattr(fixturedef.func, "is_dist_fixture", False):
        dist_fixture_class = fixturedef.func()
        dist_fixture_class(request)
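
The --torch_ver and --cuda_ver options defined above let CI pin the environment the suite expects; when they are omitted, check_environment only warns. A minimal invocation example (the version numbers are placeholders, not requirements stated by this PR):

    # Run the suite while asserting the expected torch / CUDA versions.
    # Equivalent to: pytest --torch_ver 2.1 --cuda_ver 12.1 tests/
    import pytest

    if __name__ == "__main__":
        raise SystemExit(pytest.main(["--torch_ver", "2.1", "--cuda_ver", "12.1", "tests/"]))
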
4 changes: 0 additions & 4 deletions tests/model/__init__.py
@@ -11,7 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .test_model_instantiation import run_test_model_instantiation
-from .test_model_train import run_train_test
-from .test_model_checkpoint import run_checkpoint_test
95 changes: 46 additions & 49 deletions tests/model/test_model_checkpoint.py
@@ -24,7 +24,7 @@

 import pytest
 from tests.common import (
-    distributed_test,
+    DistributedTest,
     clear_test_dirs,
     model_setup,
     binary,
@@ -73,60 +73,57 @@ def test_train(param_dict):
     d = tempfile.mkdtemp()
     param_dict["save"] = d

-    @distributed_test(world_size=2)
-    def wrapper():
-        run_checkpoint_test(param_dict=param_dict)
-
-    wrapper()
-
-
-def run_checkpoint_test(yaml_list=None, param_dict=None):
-
-    from megatron.checkpointing import load_checkpoint
-    from megatron.checkpointing import save_checkpoint
-
-    model, optimizer, lr_scheduler, args_loaded = model_setup(
-        yaml_list, param_dict, clear_data=True
-    )
-
-    # save model checkpoint
-    save_checkpoint(
-        neox_args=args_loaded,
-        iteration=42,
-        model=model,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-    )
-
-    # reload model from checkpoint
-    (
-        reloaded_model,
-        reloaded_optimizer,
-        reloaded_lr_scheduler,
-        args_reloaded,
-    ) = model_setup(yaml_list, param_dict, clear_data=False)
-    iteration = load_checkpoint(
-        neox_args=args_reloaded,
-        model=reloaded_model,
-        optimizer=reloaded_optimizer,
-        lr_scheduler=reloaded_lr_scheduler,
-    )
-
-    # ensure same checkpoint is loaded
-    assert (
-        iteration == 42
-    ), "run_checkpoint_test() iteration loaded from checkpoint correct"
-
-    # check all weight groups are the same
-    for idx, ((n1, p1), (n2, p2)) in enumerate(
-        zip(
-            list(model.module.named_parameters()),
-            list(reloaded_model.module.named_parameters()),
-        )
-    ):
-        assert n1 == n2
-        params_equal = (p1 == p2).all().item()
-        assert params_equal, "run_checkpoint_test() params equal: " + str(n1)
+    t1 = test_run_checkpoint_test_class()
+    t1.run_checkpoint_test(param_dict=param_dict)
+
+
+class test_run_checkpoint_test_class(DistributedTest):
+    def run_checkpoint_test(yaml_list=None, param_dict=None):
+
+        from megatron.checkpointing import load_checkpoint
+        from megatron.checkpointing import save_checkpoint
+
+        model, optimizer, lr_scheduler, args_loaded = model_setup(
+            yaml_list, param_dict, clear_data=True
+        )
+
+        # save model checkpoint
+        save_checkpoint(
+            neox_args=args_loaded,
+            iteration=42,
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+        )
+
+        # reload model from checkpoint
+        (
+            reloaded_model,
+            reloaded_optimizer,
+            reloaded_lr_scheduler,
+            args_reloaded,
+        ) = model_setup(yaml_list, param_dict, clear_data=False)
+        iteration = load_checkpoint(
+            neox_args=args_reloaded,
+            model=reloaded_model,
+            optimizer=reloaded_optimizer,
+            lr_scheduler=reloaded_lr_scheduler,
+        )
+
+        # ensure same checkpoint is loaded
+        assert (
+            iteration == 42
+        ), "run_checkpoint_test() iteration loaded from checkpoint correct"
+
+        # check all weight groups are the same
+        for idx, ((n1, p1), (n2, p2)) in enumerate(
+            zip(
+                list(model.module.named_parameters()),
+                list(reloaded_model.module.named_parameters()),
+            )
+        ):
+            assert n1 == n2
+            params_equal = (p1 == p2).all().item()
+            assert params_equal, "run_checkpoint_test() params equal: " + str(n1)


 if __name__ == "__main__":
78 changes: 38 additions & 40 deletions tests/model/test_model_generation.py
@@ -22,7 +22,7 @@

 import os
 import pytest
-from tests.common import distributed_test, model_setup, parametrize
+from tests.common import DistributedTest, model_setup, parametrize

 PARAMS_TO_TEST = {
     "pipe_parallel_size,model_parallel_size,world_size": [
@@ -67,47 +67,45 @@
 @pytest.mark.skip
 @pytest.mark.parametrize("param_dict", parameters, ids=names)
 def test_train(param_dict):
-    @distributed_test(world_size=param_dict.pop("world_size", 2))
-    def wrapper():
-        run_generate_test(param_dict=param_dict, prompt=param_dict.pop("prompt"))
-
-    wrapper()
-
-
-def run_generate_test(param_dict, prompt):
-    from megatron.text_generation_utils import generate_samples_from_prompt
-    from megatron.utils import is_mp_rank_0
-
-    fixed_params = {
-        "num_samples": 3,
-        "maximum_tokens": 50,
-        "make_vocab_size_divisible_by": 2,
-        "sample_output_file": "test_sample_output.txt",
-        "checkpoint_activations": False,
-        "partition_activations": False,
-        "no_load_optim": True,
-    }
-
-    param_dict.update(fixed_params)
-    # TODO: we don't need to reinstantiate the model every time if we're only changing sampling settings - should be a workaround for this
-    model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
-    model.eval()
-
-    prompts = [prompt for _ in range(args_loaded.num_samples)]
-    output = generate_samples_from_prompt(
-        neox_args=args_loaded,
-        model=model,
-        text=prompts,
-        maximum_tokens=args_loaded.maximum_tokens,
-        recompute=False,
-        temperature=args_loaded.temperature,
-        top_k=args_loaded.top_k,
-        top_p=args_loaded.top_p,
-    )
-
-    # outputs only get generated on mp rank 0
-    if is_mp_rank_0():
-        assert len(output) == len(prompts)
-        for prompt, out in zip(prompts, output):
-            assert prompt == out["context"]
-            assert len(out["text"]) > 0
+    t1 = run_generate_test_class()
+    t1.run_generate_test(param_dict, param_dict.pop("prompt"))
+
+
+class run_generate_test_class(DistributedTest):
+    world_size = 2
+    def run_generate_test(param_dict, prompt):
+        from megatron.text_generation_utils import generate_samples_from_prompt
+        from megatron.utils import is_mp_rank_0
+
+        fixed_params = {
+            "num_samples": 3,
+            "maximum_tokens": 50,
+            "make_vocab_size_divisible_by": 2,
+            "sample_output_file": "test_sample_output.txt",
+            "checkpoint_activations": False,
+            "partition_activations": False,
+            "no_load_optim": True,
+        }
+
+        param_dict.update(fixed_params)
+        # TODO: we don't need to reinstantiate the model every time if we're only changing sampling settings - should be a workaround for this
+        model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
+        model.eval()
+
+        prompts = [prompt for _ in range(args_loaded.num_samples)]
+        output = generate_samples_from_prompt(
+            neox_args=args_loaded,
+            model=model,
+            text=prompts,
+            maximum_tokens=args_loaded.maximum_tokens,
+            recompute=False,
+            temperature=args_loaded.temperature,
+            top_k=args_loaded.top_k,
+            top_p=args_loaded.top_p,
+        )
+
+        # outputs only get generated on mp rank 0
+        if is_mp_rank_0():
+            assert len(output) == len(prompts)
+            for prompt, out in zip(prompts, output):
+                assert prompt == out["context"]
+                assert len(out["text"]) > 0
53 changes: 23 additions & 30 deletions tests/model/test_model_instantiation.py
@@ -21,7 +21,7 @@
 import torch
 import os
 from tests.common import (
-    distributed_test,
+    DistributedTest,
     model_setup,
     clear_test_dirs,
     parametrize,
@@ -74,18 +74,13 @@
     PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
 )

-
 @pytest.mark.xfail(
     reason="Either fused kernels are not installed, or Cannot re-initialize CUDA in forked subprocess'"
 )
 @pytest.mark.parametrize("param_dict", parameters, ids=names)
 def test_instantiate(param_dict):
-    @distributed_test(world_size=param_dict.pop("world_size", 2))
-    def wrapper():
-        run_test_model_instantiation(param_dict=param_dict)
-
-    wrapper()
-
+    t1 = test_instantiate_optimizers_class()
+    t1.run_test_model_instantiation(param_dict)

 OPTIMIZER_PARAMS = {
     "optimizer": [
@@ -102,30 +97,28 @@ def wrapper():
     OPTIMIZER_PARAMS, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
 )

-
 @pytest.mark.xfail(
     reason="Either fused kernels are not installed, or 'Cannot re-initialize CUDA in forked subprocess'"
 )
 @pytest.mark.parametrize("param_dict", opt_params, ids=opt_name)
 def test_instantiate_optimizers(param_dict):
-    @distributed_test(world_size=2)
-    def wrapper():
-        run_test_model_instantiation(param_dict=param_dict)
-
-    wrapper()
-
-
-def run_test_model_instantiation(yaml_list=None, param_dict=None):
-    from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine
-
-    model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)
-    if args_loaded.pipe_parallel_size < 2:
-        assert isinstance(model, DeepSpeedEngine), "test model instantiation " + str(
-            yaml_list
-        )
-    else:
-        assert isinstance(model, PipelineEngine), "test model instantiation " + str(
-            yaml_list
-        )
-    if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0:
-        clear_test_dirs()
+    t1 = test_instantiate_optimizers_class()
+    t1.run_test_model_instantiation(param_dict)
+
+class test_instantiate_optimizers_class(DistributedTest):
+    world_size = 2
+
+    def run_test_model_instantiation(yaml_list=None, param_dict=None):
+        from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine
+
+        model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)
+        if args_loaded.pipe_parallel_size < 2:
+            assert isinstance(model, DeepSpeedEngine), "test model instantiation " + str(
+                yaml_list
+            )
+        else:
+            assert isinstance(model, PipelineEngine), "test model instantiation " + str(
+                yaml_list
+            )
+        if torch.distributed.get_world_size() == 1 or torch.distributed.get_rank() == 0:
+            clear_test_dirs()