Skip to content

Commit

Permalink
moved compile helper to initialize
Browse files Browse the repository at this point in the history
  • Loading branch information
mshoeybi committed Dec 30, 2020
1 parent a495871 commit 242770d
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 15 deletions.
2 changes: 0 additions & 2 deletions megatron/data/bert_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ def get_samples_mapping_(indexed_dataset,
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
# First compile and then import.
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
Expand Down
7 changes: 0 additions & 7 deletions megatron/data/blendable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,6 @@ def __init__(self, datasets, weights):
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

if torch.distributed.get_rank() == 0:
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
tmp = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(tmp, group=mpu.get_data_parallel_group())

from megatron.data import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
Expand Down
2 changes: 0 additions & 2 deletions megatron/data/gpt2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time = time.time()
# Use C++ implementation for speed.
# First compile and then import.
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
Expand Down
4 changes: 0 additions & 4 deletions megatron/data/realm_dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
print_rank_0(' > building samples index mapping for {} ...'.format(
name))

# compile/bind the C++ helper code
from megatron.data.dataset_utils import compile_helper
compile_helper()

from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
Expand Down
10 changes: 10 additions & 0 deletions megatron/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def finish_mpu_init():

# Autoresume.
_init_autoresume()

# Compile dataset C++ code.
try:
from megatron.data import helpers
except:
if torch.distributed.get_rank() == 0:
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
torch.distributed.barrier()

# No continuation function
return None
Expand Down

0 comments on commit 242770d

Please sign in to comment.