diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py
index 059a30dc2..a8dc3f0c0 100644
--- a/megatron/text_generation_utils.py
+++ b/megatron/text_generation_utils.py
@@ -459,10 +459,10 @@ def generate_samples_from_prompt(
                 "\nPlease give smaller context (e.g. half of the "
                 "max sequence length)!",
             )
-        if not is_mp_rank_0():
-            context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT")
-            context_length = len(context_tokens)
-            terminate_runs = 0
+        # if not is_mp_rank_0():
+        #     context_tokens = neox_args.tokenizer.tokenize("EMPTY TEXT")
+        #     context_length = len(context_tokens)
+        #     terminate_runs = 0

         terminate_runs = broadcast_terminate_signal(terminate_runs)
         if terminate_runs == 1:
diff --git a/tests/model/test_model_generation.py b/tests/model/test_model_generation.py
index 0c1ed0da1..7fe991905 100644
--- a/tests/model/test_model_generation.py
+++ b/tests/model/test_model_generation.py
@@ -50,7 +50,7 @@
 )


-@pytest.mark.skip
+#@pytest.mark.skip
 @pytest.mark.parametrize("param_dict", parameters, ids=names)
 def test_train(param_dict):
     @distributed_test(world_size=param_dict.pop("world_size", 2))