Fix pp generation (#505)
* remove layernorm test in `test_fused_kernels.py`

* don't pass around layer_pasts / presents; instead, cache them in the class state (see the sketch after this list)

* remove fused kernels layernorm test

* fix text generation

* fix pipe + model forward

* fix error if config files already exist

* remove assertion that pp <= 1, cleanup

* wip push changes

* fix generation for pp>1 & mp>1

* rename `get_key_value` to `use_cache`

* `cache` -> `use_cache`

* never pass in a full size attn mask
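
To make the bullets above concrete, here is a minimal sketch of the caching pattern they describe: key/value tensors kept as module state behind a `use_cache` flag instead of being threaded through `forward()` as `layer_past` / `presents`. The class and method names below are illustrative assumptions, not the actual GPT-NeoX classes.

```python
# Illustrative sketch only -- this class is an assumption, not GPT-NeoX code.
import torch
import torch.nn as nn


class CachedAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.use_cache = False
        self.kv_cache = None  # holds (key, value) from earlier generation steps

    def inference_mode(self, use_cache: bool):
        # toggle caching and drop any stale cache between generations
        self.use_cache = use_cache
        self.kv_cache = None

    def forward(self, key: torch.Tensor, value: torch.Tensor):
        # key / value shaped [seq, batch, dim]; new tokens are appended to the cache
        if self.use_cache:
            if self.kv_cache is not None:
                past_key, past_value = self.kv_cache
                key = torch.cat([past_key, key], dim=0)
                value = torch.cat([past_value, value], dim=0)
            self.kv_cache = (key, value)
        # nothing needs to be handed back to the caller for re-injection next step
        return key, value
```

Keeping the cache local to each layer means pipeline stages no longer have to pass past key/values between one another, which is presumably what makes generation workable for pp > 1.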
sdtblck committed Feb 9, 2022
1 parent 2155b91 commit 7aed133
Showing 14 changed files with 533 additions and 549 deletions.
44 changes: 32 additions & 12 deletions eval_tasks/eval_adapter.py
@@ -40,6 +40,7 @@ class EvalHarnessAdapter(GPT2LM):
neox_args: a NeoXArgs object containing the model configuration.
batch_size (optional): An argument to override the batch size, which defaults to batch size per gpu * dp world size.
"""

def __init__(self, model, forward_step_fn, neox_args, batch_size=None):

self.device = torch.device(f"cuda:{neox_args.local_rank}")
@@ -66,7 +67,9 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
self.dp_group = mpu.get_data_parallel_group()
self.is_mp_rank_0 = mpu.get_model_parallel_rank() == 0

self.batch_size = batch_size or (neox_args.batch_size * self.dp_world_size) # default batch size to bs per gpu * dp size
self.batch_size = batch_size or (
neox_args.batch_size * self.dp_world_size
) # default batch size to bs per gpu * dp size

# some utility functions:
# we need to patch tokenizer methods, because lm_eval uses them internally:
@@ -81,7 +84,6 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
model=model,
maximum_tokens=self.max_gen_toks,
temperature=0.0,
broadcast_generated_tokens=True,
)

def greedy_until(self, requests):
@@ -95,7 +97,7 @@ def greedy_until(self, requests):
:param requests: Dictionary of requests containing the context (prompt) and 'until' - a token or
list of stop tokens.
"""
self.model.module.inference_mode(cache=True) # tell model to cache kv pairs
self.model.module.inference_mode(use_cache=True) # tell model to cache kv pairs
res = []

def _collate(x):
@@ -137,7 +139,7 @@ def _loglikelihood_tokens(self, requests, disable_tqdm=False):
:param disable_tqdm: If True, disable tqdm progress bar.
"""
self.model.module.inference_mode(
cache=False
use_cache=False
) # tell model to gather parallel outputs, but not cache key-value pairs

disable_tqdm = disable_tqdm if self.is_main else True
@@ -229,8 +231,14 @@ def _collate(x):
else:
logits_sums = torch.zeros(res_len, dtype=torch.float32).cuda()
max_equals = torch.zeros(res_len, dtype=torch.int64).cuda()
torch.distributed.broadcast(tensor=logits_sums, src=src_rank)
torch.distributed.broadcast(tensor=max_equals, src=src_rank)
torch.distributed.broadcast(
tensor=logits_sums,
src=src_rank,
group=mpu.get_pipe_parallel_group(),
)
torch.distributed.broadcast(
tensor=max_equals, src=src_rank, group=mpu.get_pipe_parallel_group()
)
max_equals = [bool(i) for i in max_equals.tolist()]
logits_sums = logits_sums.tolist()
res = list(zip(logits_sums, max_equals))
@@ -250,9 +258,13 @@ def _dp_scatter(self, inps):
# In this case we pad the batch
padded_size = self.dp_world_size - (batch_size % self.dp_world_size)

print_rank_0(f'WARNING: Batch size ({batch_size}) must be divisible by dp world size ({self.dp_world_size}). Padding inputs to {padded_size}.')

inps = torch.cat([inps] + [ inps[0:1, ...] for _ in range(padded_size) ], dim=0) # pad with first inp item
print_rank_0(
f"WARNING: Batch size ({batch_size}) must be divisible by dp world size ({self.dp_world_size}). Padding inputs to {padded_size}."
)

inps = torch.cat(
[inps] + [inps[0:1, ...] for _ in range(padded_size)], dim=0
) # pad with first inp item
padded = True

assert (
@@ -277,7 +289,7 @@ def _dp_gather(self, logits):
tensor_list, logits, group=mpu.get_data_parallel_group()
)
logits = torch.cat(tensor_list, dim=0)
return logits
return logits

def _model_call(self, inps):
batch_size = inps.shape[0]
@@ -346,8 +358,16 @@ def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2):


def run_eval_harness(
model, forward_step_fn, neox_args, batch_size=None, eval_tasks=None, num_fewshot=0, bootstrap_iters=2
model,
forward_step_fn,
neox_args,
batch_size=None,
eval_tasks=None,
num_fewshot=0,
bootstrap_iters=2,
):
print_rank_0("Running evaluation harness...")
adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
return adapter.run_eval(eval_tasks=eval_tasks, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters)
return adapter.run_eval(
eval_tasks=eval_tasks, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters
)
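
One detail of the `_dp_scatter` hunk above worth spelling out: when the batch size is not divisible by the data-parallel world size, the batch is padded with copies of its first item so every rank receives an equal slice. A standalone sketch of that padding arithmetic (the helper name is made up for illustration):

```python
# Standalone sketch of the batch padding used in _dp_scatter; the helper name is assumed.
import torch


def pad_batch_to_dp_world_size(inps: torch.Tensor, dp_world_size: int) -> torch.Tensor:
    """Pad the batch dimension with copies of the first item until it divides evenly."""
    batch_size = inps.shape[0]
    remainder = batch_size % dp_world_size
    if remainder == 0:
        return inps
    pad = dp_world_size - remainder  # e.g. batch 10, dp world size 4 -> pad by 2
    return torch.cat([inps] + [inps[0:1, ...] for _ in range(pad)], dim=0)
```

The `padded` flag set in the original code presumably lets the caller drop the extra rows again after gathering, so the padding only costs throughput, not correctness.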
26 changes: 19 additions & 7 deletions evaluate.py
@@ -20,25 +20,37 @@

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))

sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
from megatron.training import forward_step
from megatron.utils import setup_for_inference_or_eval
from eval_tasks import run_eval_harness
from pprint import pprint
from datetime import datetime
import json


def main():
model, neox_args = setup_for_inference_or_eval(inference=False, get_key_value=False)
results = run_eval_harness(model, forward_step, neox_args, eval_tasks=neox_args.eval_tasks, bootstrap_iters=10000)
model, neox_args = setup_for_inference_or_eval(use_cache=False)
results = run_eval_harness(
model,
forward_step,
neox_args,
eval_tasks=neox_args.eval_tasks,
bootstrap_iters=10000,
)
if neox_args.rank == 0:
pprint(results)
results_path = f'eval_results_{datetime.now().strftime("%m-%d-%Y-%H-%M-%S")}.json'
results_path = (
f'eval_results_{datetime.now().strftime("%m-%d-%Y-%H-%M-%S")}.json'
)
if neox_args.eval_results_prefix:
results_path = f"{neox_args.eval_results_prefix}_{results_path}"
with open(results_path, 'w') as f:
with open(results_path, "w") as f:
json.dump(results, f, indent=4)


if __name__ == "__main__":
main()
main()
7 changes: 3 additions & 4 deletions generate.py
@@ -30,10 +30,9 @@ def main():
"""
Generate text/sample model
"""
model, neox_args = setup_for_inference_or_eval()
assert (
neox_args.pipe_parallel_size <= 1
), "Pipe parallel size must be <= 1 in generation"
model, neox_args = setup_for_inference_or_eval(use_cache=True)
if neox_args.recompute:
model.module.inference_mode(use_cache=False) # don't use kv cache if recomputing
if neox_args.text_gen_type == "unconditional":
print_rank_0(
f"Generating samples unconditionally and saving results to {neox_args.sample_output_file}"
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
@@ -194,7 +194,7 @@ def save_ds_checkpoint(iteration, model, neox_args):
# save config files
if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
configs_directory = os.path.join(neox_args.save, tag, "configs")
os.makedirs(configs_directory)
os.makedirs(configs_directory, exist_ok=True)
for config_filename, config_data in neox_args.config_files.items():
with open(os.path.join(configs_directory, config_filename), "w") as f:
f.write(config_data)
59 changes: 32 additions & 27 deletions megatron/model/gmlp.py
@@ -9,25 +9,28 @@

from megatron import mpu


class TinyAttention(nn.Module):
def __init__(self, neox_args, d_attn, d_ff, mask_fn):
super().__init__()
self.proj_qkv = nn.Linear(d_ff * 2, 3 * d_attn)
self.scale = d_attn ** -0.5
self.scale = d_attn**-0.5
self.proj_ffn = nn.Linear(d_attn, d_ff)
self.softmax = FusedScaleMaskSoftmax(
input_in_fp16=neox_args.precision == "fp16",
input_in_bf16=neox_args.precision == "bfloat16",
fusion_type=get_fusion_type(neox_args),
mask_func=mask_fn,
softmax_in_fp32=neox_args.attention_softmax_in_fp32,
scale=None
)
scale=None,
)

def forward(self, x, attention_mask):
q, k, v = torch.chunk(self.proj_qkv(x), 3, dim=-1)
w = torch.einsum("bnd,bmd->bnm", q, k).unsqueeze(1) * self.scale
a = self.softmax(w, mask=attention_mask[..., :w.size(-2), :w.size(-1)]).squeeze(1)
a = self.softmax(
w, mask=attention_mask[..., : w.size(-2), : w.size(-1)]
).squeeze(1)
x = torch.einsum("bnm,bmd->bnd", a, v)
return self.proj_ffn(x)

@@ -43,13 +46,15 @@ def __init__(self, neox_args, d_ff, d_attn=None, causal=True, mask_fn=None):
self.proj = nn.Linear(neox_args.seq_length, neox_args.seq_length)
if self.use_attn:
assert mask_fn is not None
self.attn = TinyAttention(neox_args=neox_args, d_attn=d_attn, d_ff=d_ff, mask_fn=mask_fn)
self.attn = TinyAttention(
neox_args=neox_args, d_attn=d_attn, d_ff=d_ff, mask_fn=mask_fn
)
nn.init.zeros_(self.proj.weight)
nn.init.constant_(self.proj.bias, 1.)
nn.init.constant_(self.proj.bias, 1.0)

def forward(self, x, attention_mask):
device, n = x.device, x.shape[1]
x = x.transpose(0, 1) # [s, b, d] -> [b, s, d]
x = x.transpose(0, 1) # [s, b, d] -> [b, s, d]

res, gate = x.chunk(2, dim=-1) # split along dim
gate = self.norm(gate)
Expand All @@ -58,7 +63,7 @@ def forward(self, x, attention_mask):
if self.causal:
weight, bias = weight[:n, :n], bias[:n]
mask = torch.ones(weight.shape[:2], device=device).triu_(1).bool()
weight = weight.masked_fill(mask, 0.)
weight = weight.masked_fill(mask, 0.0)

gate = F.linear(gate.transpose(2, 1), weight, self.proj.bias).transpose(2, 1)

Expand All @@ -69,7 +74,15 @@ def forward(self, x, attention_mask):


class GMLPBlock(nn.Module):
def __init__(self, neox_args, init_method, output_layer_init_method, layer_number, ff_mult=4, mask_fn=None):
def __init__(
self,
neox_args,
init_method,
output_layer_init_method,
layer_number,
ff_mult=4,
mask_fn=None,
):
super().__init__()
self.layer_number = layer_number

@@ -82,40 +95,32 @@ def __init__(self, neox_args, init_method, output_layer_init_method, layer_numbe
output_size=ff_dim * 2,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
skip_bias_add=True,
)
self.activation_func = get_activation(neox_args)
ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
if neox_args.attention_config[layer_number] == "amlp":
d_attn = neox_args.gmlp_attn_dim
else:
d_attn = None
self.sgu = SpatialGatingUnit(neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn)
self.sgu = SpatialGatingUnit(
neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn
)
self.output_linear = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=ff_dim,
output_size=neox_args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
skip_bias_add=True,
)

def forward(self, args):
in_inference = len(args) == 4
in_train = len(args) == 2

if in_train:
x, attention_mask = args
elif in_inference:
x, layer_past, presents, attention_mask = args
else:
raise ValueError

assert len(args) == 2, "GMLPBlock expects 2 arguments"
x, attention_mask = args
x = self.norm(x)
x, _ = self.input_linear(x)
x = self.activation_func(x)
x = self.sgu(x, attention_mask)
x, _ = self.output_linear(x)

if in_train:
return x, attention_mask
elif in_inference:
return x, layer_past, presents, attention_mask
return x, attention_mask