Merge pull request EleutherAI#5 from lintangsutawika/stability_multitask
Changes made while trying to run finetuning on the cluster
lintangsutawika committed Sep 23, 2022
2 parents a16bcda + 007d6d0 commit f29538e
Showing 4 changed files with 32 additions and 15 deletions.
30 changes: 20 additions & 10 deletions eval_tasks/eval_adapter.py
@@ -45,6 +45,10 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
self.cache_hook = base.CacheHook(None)
self.model = model
self.neox_args = neox_args
# HACK: disable packing at inference time
if self.neox_args.train_mtf:
self.was_packing = True
self.neox_args.train_mtf = False
self.tokenizer = neox_args.tokenizer
self._device = torch.device(f"cuda:{neox_args.local_rank}")
self._eot_token_id = neox_args.tokenizer.eod_id
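
Note on the block above: it records the packing flag in `was_packing` before forcing `train_mtf` off for evaluation. A minimal sketch of that save-and-restore pattern follows; the restore method is a hypothetical cleanup step, not part of this commit.

# Sketch only. `neox_args.train_mtf` and `was_packing` mirror the diff;
# the restore hook is an assumption, not shown in this commit.
class PackingToggleSketch:
    def __init__(self, neox_args):
        self.neox_args = neox_args
        self.was_packing = False
        if self.neox_args.train_mtf:
            self.was_packing = True
            self.neox_args.train_mtf = False  # disable packing at inference time

    def restore_packing(self):
        # hypothetical: put the flag back once evaluation finishes
        if self.was_packing:
            self.neox_args.train_mtf = True
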
@@ -558,26 +562,31 @@ def _collate(x):
(targetlen,) = target.shape

target_padding_length = (
target_padding_length if target_padding_length is not None else targetlen
max(target_padding_length, targetlen) if target_padding_length is not None else targetlen
)

inps.append(inp.unsqueeze(0))
targets.append(target)
contlens.append(cont)
inplens.append(inplen)
targetlens.append(targetlen)

padded_targets = []
for target, targetlen in zip(targets, targetlens):
# pad to length
target = torch.cat(
[
target, # [seq]
torch.zeros(padding_length - targetlen, dtype=torch.long).to(
torch.zeros(target_padding_length - targetlen + 1, dtype=torch.long).to(
target.device
), # [padding_length - seq]
],
dim=0,
)
).unsqueeze(0)

inps.append(inp.unsqueeze(0))
targets.append(target.unsqueeze(0))
contlens.append(cont)
inplens.append(inplen)
targetlens.append(targetlen)
padded_targets.append(target)

logits = self._model_call(torch.cat(inps, dim=0), targets=torch.cat(targets, dim=0))
logits = self._model_call(torch.cat(inps, dim=0), targets=torch.cat(padded_targets, dim=0))
res_len += len(chunk)

if logits is not None:
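
A rough standalone sketch of what the reworked `_collate` padding does: record each target's length, take the running maximum as `target_padding_length`, then right-pad every target with zeros to that length (plus one, as in the diff) before concatenating into `padded_targets`. The helper name and the example tensors below are illustrative, not from the repo.

import torch

def pad_and_stack_targets(targets):
    # Illustrative only: pad variable-length 1-D targets to the batch max (+1) and stack.
    target_padding_length = max(t.shape[0] for t in targets)
    padded_targets = []
    for target in targets:
        pad = torch.zeros(
            target_padding_length - target.shape[0] + 1, dtype=torch.long
        ).to(target.device)
        padded_targets.append(torch.cat([target, pad], dim=0).unsqueeze(0))  # [1, seq]
    return torch.cat(padded_targets, dim=0)  # [batch, target_padding_length + 1]

batch = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
print(pad_and_stack_targets(batch).shape)  # torch.Size([2, 4])
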
@@ -649,10 +658,11 @@ def run_eval_harness(
num_fewshot=0,
bootstrap_iters=2,
):
batch_size=2 # TODO(Hailey): don't merge this change into main. hack to stop OOM errors
batch_size=1 # TODO(Hailey): don't merge this change into main. hack to stop OOM errors
print_rank_0("Running evaluation harness...")
if neox_args.model_arch == "t5":
adapter = Seq2SeqEvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
print("using packing:",neox_args.train_mtf)
else:
adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
return adapter.run_eval(
3 changes: 3 additions & 0 deletions megatron/model/positional_embeddings.py
@@ -71,10 +71,13 @@ def apply_rotary_pos_emb(q, cos, sin, offset: int = 0):
def apply_rotary_pos_emb_torch(
q, cos, sin, offset: int = 0
): # jitting fails with bf16
og_shape = cos.shape
cos, sin = (
cos[offset : q.shape[0] + offset, ...],
sin[offset : q.shape[0] + offset, ...],
)
if q.shape[0] != cos.shape[0] or rotate_half(q).shape[0] != sin.shape[0]:
print(offset, og_shape, q.shape, cos.shape, sin.shape, rotate_half(q).shape)
return (q * cos) + (rotate_half(q) * sin)


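
For context on the debug lines added above: `apply_rotary_pos_emb_torch` slices the cached cos/sin tables by `offset`, and the new print fires when the slices no longer cover the query's sequence length. A toy reconstruction of that slicing follows; the shapes and the error handling are simplified assumptions, not the repo's code.

import torch

def rotate_half(x):
    # standard rotary helper: split the last dim in half, swap and negate
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_sketch(q, cos, sin, offset=0):
    # slice the cached tables to the query's positions, as in the diff
    cos = cos[offset : q.shape[0] + offset, ...]
    sin = sin[offset : q.shape[0] + offset, ...]
    if q.shape[0] != cos.shape[0]:
        # the commit only prints the shapes here; raising keeps the toy example loud
        raise ValueError(f"tables cover {cos.shape[0]} positions, query has {q.shape[0]}")
    return (q * cos) + (rotate_half(q) * sin)

seq, dim = 4, 8
q = torch.randn(seq, 1, 1, dim)
cos = torch.randn(seq + 2, 1, 1, dim)  # table longer than the query: slice succeeds
sin = torch.randn(seq + 2, 1, 1, dim)
print(apply_rotary_sketch(q, cos, sin, offset=0).shape)  # torch.Size([4, 1, 1, 8])
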
5 changes: 4 additions & 1 deletion megatron/model/transformer.py
@@ -505,10 +505,13 @@ def forward(self, hidden_states, attention_mask, encoder_hidden_states=None, lay
if exists(layer_past) and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
# TODO(Hailey): make applying rotary embs more efficient. the 2 calls to rotary_emb capture an edge case in evals
cos, sin = self.rotary_emb(value_layer, seq_len=query_layer.shape[0] + offset)

query_layer = apply_rotary_fn(
query_rot, cos, sin, offset=offset
)
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
key_layer = apply_rotary_fn(
key_rot, cos, sin, offset=offset
)
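
The duplicated `rotary_emb` call above handles an evaluation edge case in which the query is a different length from the key sequence: cos/sin for the query are built over `query_layer.shape[0] + offset` positions, while the key still uses the full `seq_len`. A hedged sketch of that idea with a minimal stand-in rotary module (not the repo's exact class) follows.

import torch

class RotaryTablesSketch(torch.nn.Module):
    # minimal stand-in for a rotary embedding module: returns cos/sin tables
    # covering `seq_len` positions; the real class also caches its tables
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

rotary = RotaryTablesSketch(dim=8)
offset = 0                 # length of any cached past (zero in this toy case)
query_len, key_len = 2, 5  # query shorter than the key sequence (eval edge case)

cos_q, sin_q = rotary(seq_len=query_len + offset)  # what the diff builds for the query
cos_k, sin_k = rotary(seq_len=key_len + offset)    # what it builds for the key (full seq_len)
print(cos_q.shape, cos_k.shape)  # torch.Size([2, 8]) torch.Size([5, 8])
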
9 changes: 5 additions & 4 deletions megatron/training.py
@@ -224,7 +224,7 @@ def _get_batch_encdec(neox_args, keys, data, datatype):
tokens_dec = tokens_dec_[:, :-1].contiguous()

batch_size, src_length = tokens_enc.size()
batch_size, target_length = tokens_dec_.size()
batch_size, target_length = tokens_dec.size()

if neox_args.packing:
segment_ids_enc = data_b['input_segment_ids'].long()
@@ -254,6 +254,7 @@ def _get_batch_encdec(neox_args, keys, data, datatype):
data=tokens_enc,
)

enc_dec_mask = get_full_mask(target_length, src_length, device=tokens_enc.device)
enc_mask = get_full_mask(src_length, src_length, device=tokens_enc.device)

# Get the decoder self-attn mask and position ids.
@@ -264,7 +265,7 @@ def _get_batch_encdec(neox_args, keys, data, datatype):
segment_ids=None
)

return tokens_enc, tokens_dec, labels, loss_mask, enc_mask, attention_mask, \
return tokens_enc, tokens_dec, labels, loss_mask, enc_mask, enc_dec_mask, attention_mask, \
position_ids_enc, position_ids_dec,


@@ -319,12 +320,12 @@ def get_batch_encdec_pipe(data, neox_args):
(labels, loss_mask)

else:
tokens_enc, tokens_dec, labels, loss_mask, encoder_attn_mask, attention_mask, \
tokens_enc, tokens_dec, labels, loss_mask, encoder_attn_mask, enc_dec_mask, attention_mask, \
position_ids_enc, position_ids_dec = _get_batch_encdec(
neox_args, keys, data, datatype
)

return (tokens_enc, tokens_dec, position_ids_enc, position_ids_dec, encoder_attn_mask, attention_mask),\
return (tokens_enc, tokens_dec, position_ids_enc, position_ids_dec, encoder_attn_mask, enc_dec_mask, attention_mask),\
(labels, loss_mask)


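
The new `enc_dec_mask` above is a dense decoder-to-encoder cross-attention mask over (target_length, src_length) positions, built with the same `get_full_mask` helper already used for `enc_mask`, and is now threaded through both return tuples. Below is a toy equivalent of what such a helper presumably returns; its real shape and dtype are not shown in this commit and are assumptions here.

import torch

def full_mask_sketch(rows, cols, device=None):
    # hypothetical stand-in for get_full_mask: every query position may attend
    # to every key position; the real helper's exact shape/dtype are assumptions
    return torch.ones(1, rows, cols, dtype=torch.bool, device=device)

src_length, target_length = 6, 4
enc_mask = full_mask_sketch(src_length, src_length)         # encoder self-attention
enc_dec_mask = full_mask_sketch(target_length, src_length)  # decoder-over-encoder cross-attention
print(enc_mask.shape, enc_dec_mask.shape)  # torch.Size([1, 6, 6]) torch.Size([1, 4, 6])
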
