From 6e3f22bfe5c696c64f1fa30c488503eade9f9e7f Mon Sep 17 00:00:00 2001
From: Yang Zhang
Date: Thu, 25 Jan 2024 02:02:46 +0000
Subject: [PATCH] Add DS inference

Closes https://github.com/EleutherAI/gpt-neox/issues/845
---
 megatron/model/transformer.py        |  2 +-
 megatron/neox_arguments/neox_args.py | 10 ++++++++++
 megatron/utils.py                    | 21 +++++++++++++++++++++
 3 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e881f2229..27e0554a2 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -840,7 +840,7 @@ def __init__(
         else:
             from torch import distributed as dist
 
-            if self.num_experts > dist.get_world_size():
+            if neox_args.ds_inference or self.num_experts > dist.get_world_size():
                 moe_mp_size = 1
             else:
                 moe_mp_size = dist.get_world_size() // self.num_experts
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 67ff29380..db0ec5566 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -1155,6 +1155,16 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
     prefix to which to save evaluation results - final fp will be {eval_results_prefix}_eval_results_yy-mm-dd-HH-MM.json
     """
 
+    ds_inference: bool = False
+    """
+    Use DeepSpeed inference.
+    """
+
+    moe_type: str = "standard"
+    """
+    Specify the type of MoE layer. We have two types of MoE layers: standard and residual.
+    """
+
     eval_tasks: list = None
     """
     Tasks to evaluate on using lm_eval_harness
diff --git a/megatron/utils.py b/megatron/utils.py
index 3769495d9..e98612884 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -452,9 +452,30 @@ def setup_for_inference_or_eval(use_cache=True, overwrite_values=None, input_arg
     print_rank_0("Finished loading model")
 
     model.module.inference_mode(use_cache=use_cache)
+
+    if neox_args.ds_inference:
+        model = ds_inference(model, neox_args)
+        print("> DeepSpeed Inference engine initialized")
+
     return model, neox_args
 
 
+def ds_inference(model, neox_args):
+    import deepspeed
+
+    engine = deepspeed.init_inference(
+        model=model,
+        mp_size=neox_args.model_parallel_size,
+        tensor_parallel={"mpu": mpu},
+        dtype=torch.half,
+        replace_with_kernel_inject=True,
+        moe_experts=[neox_args.num_experts],
+        moe_type=neox_args.moe_type,
+    )
+
+    return engine.module
+
+
 class CharCounter:
     """
     Wraps the data_iterator to count the number of characters in a batch