get rid of reset_position_ids / reset_attention_mask (never used...)
sdtblck committed Aug 25, 2021
1 parent 157051e commit 0fb2c09
Showing 4 changed files with 10 additions and 55 deletions.
10 changes: 0 additions & 10 deletions megatron/neox_arguments/neox_args.py
@@ -433,16 +433,6 @@ class NeoXArgsOther(NeoXArgsTemplate):
     Probability of producing a short sequence.
     """
 
-    reset_position_ids: bool = False
-    """
-    Reset posistion ids after end-of-document token.
-    """
-
-    reset_attention_mask: bool = False
-    """
-    Reset self attention mask after end-of-document token.
-    """
-
     eod_mask_loss: bool = False
     """
     Mask loss for the end of document tokens.
14 changes: 6 additions & 8 deletions megatron/text_generation_utils.py
@@ -35,7 +35,7 @@ def get_batch(neox_args, context_tokens: torch.Tensor):
     """
     Generate batch from context tokens. Attention mask and position ids are created. Returned tensors will be on CUDA.
-    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
+    neox_args: NeoXArgs.
     context_tokens: torch tensor with dimensions [batch, context_size]
     returns: tuple of torch tensors (tokens, attention_mask, position_ids) on CUDA
@@ -47,8 +47,6 @@ def get_batch(neox_args, context_tokens: torch.Tensor):
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         neox_args.tokenizer.eod,
-        neox_args.reset_position_ids,
-        neox_args.reset_attention_mask,
         neox_args.eod_mask_loss)
     return tokens, attention_mask, position_ids

@@ -166,7 +164,7 @@ def stream_tokens(neox_args, model, context_tokens: List[List[int]], eos_token_i
     """
     iterator producing text completions
-    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
+    neox_args: NeoXArgs.
     model: a Megatron model.
     context_tokens: the prompt to complete; unpadded list of lists of tokens ids
@@ -332,7 +330,7 @@ def generate_samples_from_prompt(neox_args, model, text: Union[List[str], str],
     """
     Generates samples from raw text and returns them in a dictionary.
-    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
+    neox_args: NeoXArgs.
     model: a Megatron model
     text: either a single prompt (str) or a list of prompts (List[str]).
@@ -450,7 +448,7 @@ def generate_samples_input_from_file(neox_args, model, input_file, output_file=N
     Reads prompts from neox_args.sample_input_file and writes completions to neox_args.sample_output_file
-    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
+    neox_args: NeoXArgs.
     model: a Megatron model
     input_file: path to input file. Each line in the input file will be treated as separate prompt. The line break at the end of the line is not included in the prompt.
@@ -514,7 +512,7 @@ def generate_samples_unconditional(neox_args, model, number_of_samples: int = 10
     """
     Generates samples unconditionially (no prompt) and yields them in a dictionary.
-    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
+    neox_args: NeoXArgs.
     model: a Megatron model
     number_of_samples (default 10): number of unconditional samples to be generated
@@ -567,7 +565,7 @@ def generate_samples_interactive(neox_args, model, maximum_tokens: int = 64, eos
     """
     Generates samples unconditionially (no prompt) and yields them in a dictionary.
-    neox_args: NeoXArgs with tokenizer, reset_position_ids, reset_attention_mask and eod_mask_loss
+    neox_args: NeoXArgs.
     model: a Megatron model
     maximum_tokens: maximum number of tokens to be generated
4 changes: 1 addition & 3 deletions megatron/training.py
@@ -141,12 +141,10 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
     labels = tokens_[:, 1:].contiguous()
     tokens = tokens_[:, :-1].contiguous()
 
-    # Get the masks and postition ids.
+    # Get the masks and position ids.
     attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         tokenizer.eod,
-        neox_args.reset_position_ids,
-        neox_args.reset_attention_mask,
         neox_args.eod_mask_loss)
 
     return tokens, labels, loss_mask, attention_mask, position_ids
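For context, the labels/tokens slicing kept in this hunk is the usual next-token shift. A small illustrative snippet (toy values, not repository code) showing how the two slices line up:

import torch

# One packed sequence of length 5; the extra token lets inputs and labels align.
tokens_ = torch.tensor([[10, 11, 12, 13, 14]])

labels = tokens_[:, 1:].contiguous()   # tensor([[11, 12, 13, 14]]) -- targets
tokens = tokens_[:, :-1].contiguous()  # tensor([[10, 11, 12, 13]]) -- model inputs

# Position i of `tokens` is trained to predict labels[:, i], i.e. the next token.
assert labels.shape == tokens.shape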
37 changes: 3 additions & 34 deletions megatron/utils.py
Expand Up @@ -63,22 +63,16 @@ def report_memory(name):

def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss):
eod_mask_loss=False):
"""Build masks and position id for left to right model."""

# Extract batch size and sequence length.
batch_size, seq_length = data.size()

# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
(1, seq_length, seq_length), device=data.device)).view(
1, 1, seq_length, seq_length)

# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
@@ -89,31 +83,6 @@ def get_ltor_masks_and_position_ids(data,
     position_ids = torch.arange(seq_length, dtype=torch.long,
                                 device=data.device)
     position_ids = position_ids.unsqueeze(0).expand_as(data)
-    # We need to clone as the ids will be modifed based on batch index.
-    if reset_position_ids:
-        position_ids = position_ids.clone()
-
-    if reset_position_ids or reset_attention_mask:
-        # Loop through the batches:
-        for b in range(batch_size):
-
-            # Find indecies where EOD token is.
-            eod_index = position_ids[b, data[b] == eod_token]
-            # Detach indecies from positions if going to modify positions.
-            if reset_position_ids:
-                eod_index = eod_index.clone()
-
-            # Loop through EOD indecies:
-            prev_index = 0
-            for j in range(eod_index.size()[0]):
-                i = eod_index[j]
-                # Mask attention loss.
-                if reset_attention_mask:
-                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
-                # Reset positions.
-                if reset_position_ids:
-                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-                prev_index = i + 1
 
     # Convert attention mask to binary:
     attention_mask = (attention_mask < 0.5)
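For reference, a minimal sketch of the simplified helper as it should look after this commit, reconstructed from the hunks above. The parts collapsed in the diff (the eod_mask_loss branch and the final return) are assumed from the surrounding code rather than shown, and the toy usage at the end is purely illustrative.

import torch

def get_ltor_masks_and_position_ids(data, eod_token, eod_mask_loss=False):
    """Build masks and position ids for a left-to-right model (post-commit sketch)."""
    batch_size, seq_length = data.size()  # batch_size is no longer needed by the mask logic

    # Attention mask (lower triangular), now always shared across the whole batch.
    attention_mask = torch.tril(torch.ones(
        (1, seq_length, seq_length), device=data.device)).view(
        1, 1, seq_length, seq_length)

    # Loss mask: optionally zero out the end-of-document tokens.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids: a plain 0..seq_length-1 range broadcast over the batch,
    # no longer reset at document boundaries.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # Convert the attention mask to binary; True marks positions to be masked out.
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids

# Toy usage: two sequences of length 4, pretending token id 0 is the EOD token.
tokens = torch.tensor([[5, 7, 0, 3], [2, 0, 4, 6]])
mask, loss_mask, pos_ids = get_ltor_masks_and_position_ids(tokens, eod_token=0, eod_mask_loss=True)
print(mask.shape)    # torch.Size([1, 1, 4, 4])
print(loss_mask)     # zeros exactly at the EOD positions
print(pos_ids)       # tensor([[0, 1, 2, 3], [0, 1, 2, 3]])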