add logits and is_done to return values of stream_tokens
Samuel Weinbach committed May 6, 2021
1 parent a1ce559 commit 186406c
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions megatron/text_generation_utils.py
@@ -173,7 +173,14 @@ def stream_tokens(neox_args, model, context_tokens: List[List[int]], eos_token_i
note: greedy decoding is used if temperature is 0.0, top_k is 0 and top_p is 0.0
yields: tokens (completions from model), token_generation_start_index (token index per batch item for the first generated token), token_generation_end_index (token index per batch item for the last generated token)
yields: (
tokens (completions from model),
token_generation_start_index (token index per batch item for the first generated token),
token_generation_end_index (token index per batch item for the last generated token),
logits (logits computed so far, zeros otherwise),
is_done (flag for each batch item indicating whether an eod token was generated)
)
* each iteration adds a generated token to the context_tokens
* output contains both context_tokens from input and generated tokens
* if batch items have different lengths, the iterator will start at the first completion and otherwise return the unchanged input context tokens
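
For orientation, a minimal consumer sketch of the extended yield tuple (not part of this commit; the helper name is hypothetical, and calling stream_tokens with only the keyword arguments shown assumes the remaining sampling arguments keep their defaults):

from megatron.text_generation_utils import stream_tokens

def collect_stream(neox_args, model, context_tokens, eos_token_id):
    # hypothetical helper: gather every yield of the generator into a list
    results = []
    for tokens, start_index, end_index, logits, is_done in stream_tokens(
        neox_args=neox_args,
        model=model,
        context_tokens=context_tokens,
        eos_token_id=eos_token_id,
    ):
        # logits: (batch_size, seq_length, padded_vocab_size), zeros for positions
        # not yet computed; is_done: one boolean per batch item
        results.append((tokens, start_index, end_index, logits, is_done))
    return results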
@@ -215,6 +222,8 @@ def stream_tokens(neox_args, model, context_tokens: List[List[int]], eos_token_i
token_index_to_generate + maximum_tokens -1
)

all_logits = torch.zeros((batch_size, neox_args.seq_length, neox_args.padded_vocab_size))

with torch.no_grad():
# initialize generation variables
state_is_done = torch.zeros([batch_size]).byte().cuda()
@@ -242,6 +251,11 @@ def stream_tokens(neox_args, model, context_tokens: List[List[int]], eos_token_i

logits, layer_past = forward_model(neox_args, model, model_inputs)

if recompute or (token_index_to_generate == first_token_index_to_generate):
all_logits[:, :token_index_to_generate, :] = logits[:, :token_index_to_generate, :]
else:
all_logits[:, token_index_to_generate - 1, :] = logits[:, 0, :] # only one token is computed

# TODO: we are replicating computation across all machines here, which is really unnecessary,
# we should probably just do it on one then communicate the results?

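A toy illustration (made-up sizes, plain torch, independent of neox_args) of how the two branches above fill the pre-allocated all_logits buffer: with recompute, or on the first generated token, the forward pass returns logits for the whole prefix; on later incremental steps the cached key/values mean only a single new position comes back:

import torch

batch_size, seq_length, vocab_size = 2, 8, 10   # toy sizes, not neox_args values
all_logits = torch.zeros(batch_size, seq_length, vocab_size)

# recompute / first generated token: logits cover every position up to token_index_to_generate
token_index_to_generate = 5
full_logits = torch.randn(batch_size, token_index_to_generate, vocab_size)
all_logits[:, :token_index_to_generate, :] = full_logits

# later incremental step: only the single newly computed position is returned,
# so it is written at token_index_to_generate - 1
token_index_to_generate = 6
step_logits = torch.randn(batch_size, 1, vocab_size)
all_logits[:, token_index_to_generate - 1, :] = step_logits[:, 0, :]
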
@@ -270,7 +284,7 @@ def stream_tokens(neox_args, model, context_tokens: List[List[int]], eos_token_i
token_index_to_generate += 1


yield context_tokens, token_generation_start_index, token_generation_end_index
yield context_tokens, token_generation_start_index, token_generation_end_index, all_logits, state_is_done.bool()
if torch.all(state_is_done): break

def generate_samples_from_prompt(neox_args, model, text: Union[List[str], str], eos_token_id: int = None,
@@ -345,7 +359,7 @@ def generate_samples_from_prompt(neox_args, model, text: Union[List[str], str],
if terminate_runs == 1:
return generated_texts

for batch_context_tokens, batch_token_generation_start_index, batch_token_generation_end_index in stream_tokens(
for batch_context_tokens, batch_token_generation_start_index, batch_token_generation_end_index, batch_logits, is_done in stream_tokens(
neox_args=neox_args,
model=model,
context_tokens=[context_tokens],
@@ -361,7 +375,8 @@ def generate_samples_from_prompt(neox_args, model, text: Union[List[str], str],
batch_context_tokens = batch_context_tokens.cpu().numpy().tolist()
batch_token_generation_start_index = batch_token_generation_start_index.cpu().numpy().tolist()
batch_token_generation_end_index = batch_token_generation_end_index.cpu().numpy().tolist()
for tokens, start_index, end_index in zip(batch_context_tokens, batch_token_generation_start_index, batch_token_generation_end_index):
batch_is_done = is_done.cpu().numpy().tolist()
for tokens, start_index, end_index, is_done in zip(batch_context_tokens, batch_token_generation_start_index, batch_token_generation_end_index, batch_is_done):
if end_index >= start_index:
generated_tokens = tokens[start_index:end_index + 1]
try:
@@ -373,13 +388,12 @@ def generate_samples_from_prompt(neox_args, model, text: Union[List[str], str],
else:
generated_tokens = list()
message = "WARNING: text generation did not start; try different batching or adjust parameters"
is_finished = (end_index < neox_args.seq_length - 1) and end_index > -1
if is_mp_rank_0():
data = {
'context': raw_text,
'text': generated_text,
'length': len(generated_tokens),
'finished': is_finished,
'finished': is_done,
'message': message,
'duration_seconds': float(time.time() - start_time)
}
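
With this change, 'finished' reports the per-item is_done flag yielded by stream_tokens instead of being re-derived from end_index. An illustrative record (all values made up):

data = {
    'context': 'The three laws of robotics are',
    'text': ' rules devised by Isaac Asimov.',
    'length': 8,
    'finished': True,                 # per-item is_done flag from stream_tokens
    'message': None,
    'duration_seconds': 0.73,
}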
@@ -556,7 +570,7 @@ def generate_samples_interactive(neox_args, model, maximum_tokens: int = 64, eos
terminate_runs = broadcast_terminate_signal(terminate_runs)
if terminate_runs == 1:
return
for batch_context_tokens, batch_token_generation_start_index, batch_token_generation_end_index in stream_tokens(
for batch_context_tokens, batch_token_generation_start_index, batch_token_generation_end_index, batch_logits, is_done in stream_tokens(
neox_args=neox_args,
model=model,
context_tokens=[context_tokens],
