Skip to content

Commit

Permalink
Add DS inference
Browse files Browse the repository at this point in the history
  • Loading branch information
yang committed Jan 25, 2024
1 parent a743484 commit 6e3f22b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def __init__(
else:
from torch import distributed as dist

if self.num_experts > dist.get_world_size():
if neox_args.ds_inference or self.num_experts > dist.get_world_size():
moe_mp_size = 1
else:
moe_mp_size = dist.get_world_size() // self.num_experts
Expand Down
10 changes: 10 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,16 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json
"""

ds_inference: bool = False
"""
Use DeepSpeed inference.
"""

moe_type: str = "standard"
"""
Specify the type of MoE layer. We have two types of MoE layer: standard and residual.
"""

eval_tasks: list = None
"""
Tasks to evaluate on using lm_eval_harness
Expand Down
21 changes: 21 additions & 0 deletions megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,30 @@ def setup_for_inference_or_eval(use_cache=True, overwrite_values=None, input_arg
print_rank_0("Finished loading model")

model.module.inference_mode(use_cache=use_cache)

if neox_args.ds_inference:
model = ds_inference(model, neox_args)
print("> DeepSpeed Inference engine initialized")

return model, neox_args


def ds_inference(model, neox_args):
    """Wrap ``model`` in a DeepSpeed inference engine and return the bare module.

    Initializes ``deepspeed.init_inference`` with tensor/model parallelism taken
    from ``neox_args`` and fp16 weights, with kernel injection enabled. MoE
    settings (expert count and layer type) are forwarded as well.

    Args:
        model: the (possibly wrapped) model to hand to DeepSpeed.
        neox_args: NeoX argument namespace providing ``model_parallel_size``,
            ``num_experts`` and ``moe_type``.

    Returns:
        The inference engine's underlying module (``engine.module``).
    """
    # Imported lazily so DeepSpeed is only required when ds_inference is used.
    import deepspeed

    # NOTE(review): assumes module-level `mpu` and `torch` are in scope in this
    # file — confirm against the surrounding imports of megatron/utils.py.
    ds_kwargs = dict(
        model=model,
        mp_size=neox_args.model_parallel_size,
        tensor_parallel={"mpu": mpu},
        dtype=torch.half,
        replace_with_kernel_inject=True,
        moe_experts=[neox_args.num_experts],
        moe_type=neox_args.moe_type,
    )
    inference_engine = deepspeed.init_inference(**ds_kwargs)
    return inference_engine.module


class CharCounter:
"""
Wraps the data_iterator to count the number of characters in a batch
Expand Down

0 comments on commit 6e3f22b

Please sign in to comment.