Enable multiline prompts by setting a custom prompt end (EleutherAI#754)
* add multiline prompt support to input-file mode

* reroll the input in the interactive mode

* add comments about customizable prompt ends

* reformat files modified in EleutherAI#754 with Black

* add comments about 'input' stripping '\n's

* Formatting and spelling

Co-authored-by: Quentin Anthony <[email protected]>
kabachuha and Quentin-Anthony committed Dec 25, 2022
1 parent 18c1cbe commit 93f4efd
Showing 10 changed files with 58 additions and 32 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -146,10 +146,10 @@ To reproduce our evaluation numbers on, for example, TriviaQA and PIQA use:

You can add an arbitrary list of evaluation tasks here, for details of all tasks available, see [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness).

For more details on each entry point, see the [Training and Finetuning](#training-and-finetuning), [Inference](#inference) and [Evaluation](#evaluation)
# Configuration

GPT-NeoX parameters are defined in a YAML configuration file which is passed to the deepy.py launcher. We have provided some example .yaml files in [configs](./configs/), including one for GPT-NeoX-20B, and example configuration files for other model sizes.

These files are generally complete, but non-optimal. For example, depending on your specific GPU configuration, you may need to change some settings such as `pipe-parallel-size`, `model-parallel-size` to increase or decrease the degree of parallelisation, `train_micro_batch_size_per_gpu` or `gradient-accumulation-steps` to modify batch size related settings, or the `zero_optimization` dict to modify how optimizer states are parallelised across workers.

9 changes: 4 additions & 5 deletions configs/neox_arguments.md
@@ -958,7 +958,7 @@ Text Generation arguments

- **eval_results_prefix**: str

Default =

prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json

@@ -1155,10 +1155,10 @@ Training Arguments

Acts as a multiplier on either the "log" or "linear" checkpoint spacing.

With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at
steps [20, 40, 60, 80, 100].

With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at
steps [1, 2, 4, 8, 16, 32, 64, 100].

Note that the last checkpoint step is always saved.
@@ -1572,7 +1572,7 @@ Args for deepspeed config

Default = None





@@ -1706,4 +1706,3 @@ Args for deepspeed runner (deepspeed.launcher.runner).
Default = None

Adds a `--comment` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometimes necessary for cluster rules, or so I've heard.

1 change: 1 addition & 0 deletions configs/text_generation.yml
@@ -6,6 +6,7 @@

# Params for all
"maximum_tokens": 102,
"prompt_end": "\n",
"temperature": 1.0,
"top_p": 0.0,
"top_k": 0,
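
The new `prompt_end` option controls how the generation entry points delimit individual prompts. With the default `"\n"` every line of a sample input file is one prompt; pointing it at another marker lets a single prompt span several lines, which is what this commit enables. A minimal standalone sketch of the splitting behaviour (not the exact NeoX code; it mirrors the split/strip/filter steps added to `generate_samples_input_from_file` further down in this commit, and the `<|end|>` marker and `prompts.txt` file name are illustrative assumptions):

```python
# Sketch of the split/strip/filter steps applied to a prompt file. The
# "<|end|>" delimiter is only an example of a custom prompt_end value.
def load_prompts(path: str, prompt_end: str = "\n") -> list:
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    prompts = text.split(prompt_end)        # split on the configured delimiter
    prompts = [p.strip() for p in prompts]  # drop surrounding whitespace
    return [p for p in prompts if len(p) > 0]


if __name__ == "__main__":
    raw = "def add(a, b):\n    return a + b\n<|end|>\ndef mul(a, b):\n    return a * b\n<|end|>"
    with open("prompts.txt", "w", encoding="utf-8") as f:
        f.write(raw)
    # With prompt_end="\n" this file would yield six one-line prompts;
    # with "<|end|>" it yields two multiline prompts.
    for p in load_prompts("prompts.txt", prompt_end="<|end|>"):
        print(repr(p))
```
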
2 changes: 2 additions & 0 deletions generate.py
@@ -62,6 +62,7 @@ def main():
input_file=neox_args.sample_input_file,
output_file=neox_args.sample_output_file,
maximum_tokens=neox_args.maximum_tokens,
prompt_end=neox_args.prompt_end,
recompute=neox_args.recompute,
temperature=neox_args.temperature,
top_k=neox_args.top_k,
@@ -75,6 +76,7 @@
recompute=neox_args.recompute,
temperature=neox_args.temperature,
maximum_tokens=neox_args.maximum_tokens,
prompt_end=neox_args.prompt_end,
top_k=neox_args.top_k,
top_p=neox_args.top_p,
)
6 changes: 3 additions & 3 deletions megatron/model/transformer.py
@@ -663,15 +663,15 @@ def forward(self, x, attention_mask, layer_past=None):
# to save communication time (we can do a single allreduce after we add mlp / attn outputs).
# due to a bug, the two layernorms are not tied in GPT-NeoX-20B. This is non-desirable, but
# we preserve the functionality for backwards compatibility

residual = x
# applies the correct normalization depending on if the norms are tied
if self.gpt_j_tied:
x1, x2 = self.input_layernorm(x), self.post_attention_layernorm(x)
else:
x = self.input_layernorm(x)
x1, x2 = x, x

# attention operator
attention_output, attention_bias = self.attention(
x1, attention_mask, layer_past=layer_past
@@ -699,7 +699,7 @@ def forward(self, x, attention_mask, layer_past=None):
)

# output = (x + attn(ln(x)) + mlp(ln(x))
output = residual + self.reduce(output)
else:
# pseudocode:
# x = x + attn(ln1(x))
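
The comments in this block describe the GPT-J parallel residual: attention and MLP read normalised copies of the same input and their outputs are summed onto one residual stream, i.e. output = x + attn(ln(x)) + mlp(ln(x)). A toy sketch of that structure with tied vs. untied layernorms (illustrative only; the flag name, the `Linear` stand-ins for attention/MLP, and the omission of the allreduce are my simplifications, not the exact NeoX wiring):

```python
import torch
import torch.nn as nn


class ToyParallelResidualBlock(nn.Module):
    """Illustrative GPT-J-style block: out = x + attn(norm(x)) + mlp(norm(x))."""

    def __init__(self, dim: int, tied_layernorms: bool = False):
        super().__init__()
        self.tied_layernorms = tied_layernorms
        self.ln_attn = nn.LayerNorm(dim)
        self.ln_mlp = nn.LayerNorm(dim)
        # stand-ins for the real attention / MLP sub-modules
        self.attn = nn.Linear(dim, dim)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        if self.tied_layernorms:
            # a single normalisation feeds both branches
            x1 = x2 = self.ln_attn(x)
        else:
            # two independent layernorms (the GPT-NeoX-20B checkpoints were
            # trained this way, so the option is kept for compatibility)
            x1, x2 = self.ln_attn(x), self.ln_mlp(x)
        # both branch outputs are added onto the same residual stream
        return residual + self.attn(x1) + self.mlp(x2)


x = torch.randn(2, 8, 64)
print(ToyParallelResidualBlock(64)(x).shape)  # torch.Size([2, 8, 64])
```
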
16 changes: 10 additions & 6 deletions megatron/neox_arguments/arguments.py
@@ -741,18 +741,18 @@ def calculate_derived(self):
save_iters = set(self.extra_save_iters)
else:
save_iters = set()
step = self.checkpoint_factor  # don't save step 0 or 1
while step < self.train_iters:
save_iters.add(step)
if self.checkpoint_scale == "log":
step *= self.checkpoint_factor
elif self.checkpoint_scale == "linear":
step += self.checkpoint_factor

save_iters = list(save_iters)
save_iters.sort()

self.update_values(
{
"save_iters": save_iters,
@@ -848,7 +848,7 @@ def calculate_derived(self):
if self.sparsity_config is None:
# Can't have a default value as an empty dict so need to set it here
self.update_value("sparsity_config", {})

# Adding equal dataset weights if none are provided
if self.train_data_paths and (self.train_data_weights is None):
self.train_data_weights = [1.0] * len(self.train_data_paths)
@@ -947,7 +947,11 @@ def validate_values(self):
raise ValueError(error_message)
return False

if (
self.save is not None
and self.checkpoint_factor is None
and self.extra_save_iters is None
):
error_message = (
self.__class__.__name__
+ ".validate_values() checkpoint_factor or extra_save_iters must be defined if save is defined"
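
The `calculate_derived` hunk above builds the checkpoint schedule from `checkpoint_factor`, `checkpoint_scale`, and `train_iters`. For reference, a standalone sketch of that schedule (a simplification that assumes already-validated arguments; the function name and defaults are illustrative, not part of the repo):

```python
def checkpoint_schedule(
    train_iters: int,
    checkpoint_factor: int,
    checkpoint_scale: str = "linear",
    extra_save_iters=(),
) -> list:
    """Return the sorted iteration numbers at which checkpoints are saved."""
    save_iters = set(extra_save_iters)
    step = checkpoint_factor  # don't save step 0 or 1
    while step < train_iters:
        save_iters.add(step)
        if checkpoint_scale == "log":
            step *= checkpoint_factor
        elif checkpoint_scale == "linear":
            step += checkpoint_factor
    return sorted(save_iters)


# The docs above also list the final step (e.g. 100) because the last
# training iteration is always checkpointed separately.
print(checkpoint_schedule(100, 20, "linear"))  # [20, 40, 60, 80]
print(checkpoint_schedule(100, 2, "log"))      # [2, 4, 8, 16, 32, 64]
```
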
11 changes: 8 additions & 3 deletions megatron/neox_arguments/neox_args.py
@@ -332,7 +332,7 @@ class NeoXArgsModel(NeoXArgsTemplate):
x = ln(x)
x = x + attn(x) + mlp(x)
"""

gpt_j_tied: bool = False
"""
If false, we use
@@ -785,10 +785,10 @@ class NeoXArgsTraining(NeoXArgsTemplate):
"""
Acts as a multiplier on either the "log" or "linear" checkpoint spacing.
With `checkpoint-scale="linear"`, `checkpoint-factor=20`, and `train-iters=100`, checkpoints will be saved at
steps [20, 40, 60, 80, 100].
With `checkpoint-scale="log"`, `checkpoint-factor=2`, and `train-iters=100`, checkpoints will be saved at
steps [1, 2, 4, 8, 16, 32, 64, 100].
Note that the last checkpoint step is always saved.
@@ -1008,6 +1008,11 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
maximum number of tokens to be generated
"""

prompt_end: str = "\n"
"""
a single prompt's end. Defaults to newline
"""

sample_input_file: str = None
"""
Get input from file instead of interactive mode, each line is an input.
24 changes: 21 additions & 3 deletions megatron/text_generation_utils.py
@@ -552,6 +552,7 @@ def generate_samples_input_from_file(
output_file=None,
eos_token_id: int = None,
maximum_tokens: int = 64,
prompt_end: str = "\n",
recompute: bool = False,
temperature: float = 0.0,
top_k: int = 0,
@@ -570,6 +571,7 @@
eos_token_id: end of text token at which completion is terminated, even if max_tokes count has not been reached
maximum_tokens: maximum number of tokens to be generated
prompt_end: end of a single input prompt. Defaults to newline character '\n'. Other prompt-end sequences may be useful when generating indent-aware completions (e.g. code)
recompute: flag indicating whether a cache is used for already forwarded tokens (true) or whether all tokens are recomputed at every iteration (false)
@@ -592,8 +594,9 @@
print_rank_0(
"generate_samples_input_from_file() loading input from {}".format(input_file)
)
with open(input_file, "r") as f:
prompts = f.readlines()
with open(input_file, "r", encoding="utf-8") as f:
prompts = f.read()
prompts = prompts.split(prompt_end)
prompts = [p.strip() for p in prompts]
prompts = [p for p in prompts if len(p) > 0]
print_rank_0(
@@ -654,6 +657,7 @@ def generate_samples_unconditional(
eos_token_id: end of text token at which completion is terminated, even if max_tokes count has not been reached
maximum_tokens: maximum number of tokens to be generated
prompt_end: end of a single input prompt. Defaults to newline character '\n'. Other prompt-end sequences may be useful when generating indent-aware completions (e.g. code). The interactive mode will reroll the user-input request until the stop-char is met
recompute: flag indicating whether a cache is used for already forwarded tokens (true) or whether all tokens are recomputed at every iteration (false)
@@ -699,6 +703,7 @@ def generate_samples_interactive(
neox_args,
model,
maximum_tokens: int = 64,
prompt_end: str = "\n",
eos_token_id: int = None,
recompute: bool = False,
temperature: float = 0.0,
@@ -738,7 +743,20 @@ def generate_samples_interactive(

if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
os.system("clear")
raw_text = input("Context prompt >>> ")
raw_text = ""
while True:
current_input = input("Context prompt >>> ")
if (
prompt_end == "\n"
): # we need to handle '\n' case as 'input' strips it and leads to lines being squashed
raw_text += current_input
break
if prompt_end in current_input:
raw_text += current_input.split(prompt_end)[0]
break
raw_text += (
current_input + "\n"
) # re-add newline since we stripped it on input
context_tokens = neox_args.tokenizer.tokenize(raw_text)
if len(context_tokens) == 0:
context_tokens = [neox_args.tokenizer.eod]
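
Because `input()` strips the trailing newline, multi-line prompts typed at the console would otherwise be squashed onto one line; the loop above therefore keeps reading lines ("rerolling") until the configured `prompt_end` appears, re-adding the stripped newlines in between. A condensed standalone sketch of that accumulation logic (illustrative, not the exact NeoX function; the `read_line` parameter and the fake input source are assumptions added so the example runs without a console):

```python
def read_multiline_prompt(prompt_end: str = "\n", read_line=input) -> str:
    """Accumulate input until prompt_end is seen (or a single line for "\n")."""
    raw_text = ""
    while True:
        current_input = read_line("Context prompt >>> ")
        if prompt_end == "\n":
            # input() already stripped the newline, so one line is a full prompt
            raw_text += current_input
            break
        if prompt_end in current_input:
            # keep only the text before the stop sequence
            raw_text += current_input.split(prompt_end)[0]
            break
        # re-add the newline that input() stripped, then keep reading
        raw_text += current_input + "\n"
    return raw_text


# Example with a fake input source instead of the console:
lines = iter(["def fib(n):", "    ...", "<|stop|>"])
print(repr(read_multiline_prompt("<|stop|>", read_line=lambda _prompt: next(lines))))
# 'def fib(n):\n    ...\n'
```
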
7 changes: 2 additions & 5 deletions megatron/training.py
@@ -112,7 +112,7 @@ def pretrain(neox_args):
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)

iteration = train(
neox_args=neox_args,
timers=timers,
@@ -615,10 +615,7 @@ def train(
)

# Checkpointing
if neox_args.save and iteration in neox_args.save_iters:
save_checkpoint(
neox_args=neox_args,
iteration=iteration,
10 changes: 5 additions & 5 deletions tools/convert_to_hf.py
@@ -101,16 +101,16 @@ def __init__(self, neox_config):
1 # pad defaulting to 1. follows convention from GPT-NeoX-20b tokenizer
)


# TODO: change the default value here based on discussion regarding `gpt_j_tied` config parameter's default
use_tied_lns = get_key(neox_config, "gpt-j-tied", False)

if use_tied_lns:
raise NotImplementedError(
"""ERROR: Huggingface Transformers does not yet support a single shared layernorm
"""ERROR: Huggingface Transformers does not yet support a single shared layernorm
per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals.
See https://github.com/EleutherAI/gpt-neox/pull/481 for further details.""")

See https://github.com/EleutherAI/gpt-neox/pull/481 for further details."""
)

# set all config values.
hf_config = GPTNeoXConfig(
vocab_size=args.padded_vocab_size,
