diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py
index 1ab1e8238bf..4b670423340 100644
--- a/python/sglang/lang/tracer.py
+++ b/python/sglang/lang/tracer.py
@@ -40,7 +40,7 @@ def extract_prefix_by_tracing(program, backend):
     try:
         with TracingScope(tracer):
             tracer.ret_value = program.func(tracer, **arguments)
-    except (StopTracing, TypeError):
+    except (StopTracing, TypeError, AttributeError):
         # Some exceptions may not be catched
         pass
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index c98d6d5196a..3cc61dd085d 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
 
@@ -38,6 +39,7 @@ def __init__(self, rid):
         self.adjust_input_len = 0
 
         self.prefix_indices = []
+        self.last_node = None
 
         self.normalized_logprob = None
 
@@ -81,27 +83,56 @@ def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
 
 
+@dataclass
 class Batch:
-    def __init__(
-        self,
-        reqs: List[Req],
-        req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: TokenToKVPool,
-        tree_cache: RadixCache,
-    ):
-        self.reqs = reqs
-        self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
-        self.tree_cache = tree_cache
-
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in reqs
+    reqs: List[Req]
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+    tree_cache: RadixCache
+
+    # batched arguments to model runner
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
+    prefix_lens: torch.Tensor = None
+    position_ids_offsets: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None
+    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_end: torch.Tensor = None
+    return_normalized_logprob: bool = False
+
+    # for multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_offsets: List[int] = None
+
+    # other arguments for control
+    output_ids: torch.Tensor = None
+    extend_num_tokens: int = None
+
+    # batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    frequency_penalties: torch.Tensor = None
+    presence_penalties: torch.Tensor = None
+    logit_bias: torch.Tensor = None
+
+    @classmethod
+    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+        return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
+
+        return cls(
+            reqs=reqs,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool=token_to_kv_pool,
+            tree_cache=tree_cache,
+            return_normalized_logprob=return_normalized_logprob,
         )
 
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
         reqs = self.reqs
@@ -141,7 +172,7 @@ def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor)
 
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            print("Prefill out of memory.")
+            print("Prefill out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
 
@@ -196,7 +227,50 @@ def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor)
         )
         self.logit_bias = logit_bias
 
-    def update_for_decode(self, input_ids=None):
+    def check_decode_mem(self):
+        bs = len(self.reqs)
+        avai_size = self.token_to_kv_pool.available_size()
+        if avai_size >= bs:
+            return True
+
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        if self.token_to_kv_pool.available_size() >= bs:
+            return True
+
+        return False
+
+    def retract_decode(self):
+        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices.sort(
+            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            reverse=True,
+        )
+
+        retracted_reqs = []
+        seq_lens_np = self.seq_lens.cpu().numpy()
+        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
+        while self.token_to_kv_pool.available_size() < len(self.reqs):
+            idx = sorted_indices.pop()
+            req = self.reqs[idx]
+            retracted_reqs.append(req)
+
+            self.tree_cache.dec_ref_counter(req.last_node)
+            req.prefix_indices = None
+            req.last_node = None
+            req.adjust_input_len = 0
+            req.output_ids = []
+            # TODO: apply more fine-grained retraction
+
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_np[idx]
+            ][: seq_lens_np[idx]]
+            self.token_to_kv_pool.free(token_indices)
+
+        self.filter_batch(sorted_indices)
+
+        return retracted_reqs
+
+    def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
@@ -212,13 +286,9 @@ def update_for_decode(self, input_ids=None):
 
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
         if self.out_cache_loc is None:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-            if self.out_cache_loc is None:
-                print("Decode out of memory.")
-                self.tree_cache.pretty_print()
-                exit()
+            print("Decode out of memory. This should never happen.")
+            self.tree_cache.pretty_print()
+            exit()
 
         self.out_cache_cont_start = None
         self.out_cache_cont_end = None
@@ -240,6 +310,9 @@ def filter_batch(self, unfinished_indices: List[int]):
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.return_normalized_logprob = any(
+            req.return_normalized_logprob for req in self.reqs
+        )
 
         for item in [
             "temperatures",
@@ -263,6 +336,9 @@ def merge(self, other):
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.return_normalized_logprob = any(
+            req.return_normalized_logprob for req in self.reqs
+        )
 
         for item in [
             "temperatures",
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 94613ce3761..9de71b60bab 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -45,7 +45,6 @@ def exposed_init_model(
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
-        self.schedule_conservativeness = server_args.schedule_conservativeness
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -114,6 +113,11 @@ def exposed_init_model(
         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(self.tokenizer)
 
+        # Init new token estimation
+        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
+        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
+        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -209,11 +213,6 @@ def handle_generate_request(
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
 
-        # init the regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm_state = 0
-            req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
-
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
@@ -249,13 +248,10 @@ def get_new_fill_batch(self):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        new_ratio = (
-            self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
-        )
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.max_new_tokens() - len(r.output_ids)) * new_ratio
+                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
@@ -311,7 +307,7 @@ def get_new_fill_batch(self):
             f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
         )
 
-        new_batch = Batch(
+        new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
@@ -322,7 +318,16 @@ def get_new_fill_batch(self):
 
     def forward_fill_batch(self, batch: Batch):
         # Build batch tensors
-        batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
+        batch.prepare_for_extend(
+            self.model_config.vocab_size, self.int_token_logit_bias
+        )
+
+        # init the regex fsm before first sampling
+        for req in batch.reqs:
+            if req.sampling_params.regex is not None:
+                req.regex_fsm_state = 0
+                req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
+
         if batch.extend_num_tokens != 0:
             # Forward
             logits, normalized_logprobs = self.model_runner.forward(
@@ -350,9 +355,27 @@ def forward_fill_batch(self, batch: Batch):
         self.handle_finished_requests(batch)
 
     def forward_decode_batch(self, batch: Batch):
+        # check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+
+            retracted_reqs = batch.retract_decode()
+            logger.info(
+                "decode out of memory happened, "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+            self.forward_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_step[0],
+                self.min_new_token_ratio,
+            )
+
         # Update batch tensors
         self.decode_forward_ct += 1
-        batch.update_for_decode()
+        batch.prepare_for_decode()
 
         # Forward
         logits = self.model_runner.forward(batch, ForwardMode.DECODE)
diff --git a/python/sglang/srt/managers/router/scheduler.py b/python/sglang/srt/managers/router/scheduler.py
index 582268f60b8..9affd970f3d 100644
--- a/python/sglang/srt/managers/router/scheduler.py
+++ b/python/sglang/srt/managers/router/scheduler.py
@@ -17,9 +17,6 @@ def __init__(
         self.max_total_num_token = max_total_num_token
         self.tree_cache = tree_cache
 
-    def new_token_estimation_ratio(self):
-        return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
-
     def get_priority_queue(self, forward_queue):
         if self.schedule_heuristic == "lpm":
             # longest prefix match
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e0d1c236db3..967f41ce46f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -112,7 +112,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
             "--schedule-conservativeness",
             type=float,
             default=ServerArgs.schedule_conservativeness,
-            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
+            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
             "--random-seed",
diff --git a/test/srt/model/test_llama_extend.py b/test/srt/model/test_llama_extend.py
index fdd7bbb13cd..b01549878dc 100644
--- a/test/srt/model/test_llama_extend.py
+++ b/test/srt/model/test_llama_extend.py
@@ -34,8 +34,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
         reqs.append(req)
 
     # Prefill
-    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
-    batch.init_extend_batch(model.model_config.vocab_size(), None)
+    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
+    batch.prepare_for_extend(model.model_config.vocab_size, None)
     logits, _ = model.forward(batch, ForwardMode.EXTEND)
     next_token_ids, next_token_probs = batch.sample(logits)
     print("extend logits (first)", logits)
@@ -47,8 +47,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
         req.prefix_indices = model.req_to_token_pool.req_to_token[
             batch.req_pool_indices[i], :cut_num
         ]
-    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
-    batch.init_extend_batch(model.model_config.vocab_size(), None)
+    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
+    batch.prepare_for_extend(model.model_config.vocab_size, None)
     logits, _ = model.forward(batch, ForwardMode.EXTEND)
     next_token_ids, next_token_probs = batch.sample(logits)
 
@@ -59,7 +59,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
 
     # Decode
     for i in range(6):
-        batch.update_for_decode(next_token_ids.cpu().numpy())
+        batch.prepare_for_decode(next_token_ids.cpu().numpy())
         logits = model.forward(batch, ForwardMode.DECODE)
         next_token_ids, next_token_probs = batch.sample(logits)
 
diff --git a/test/srt/test_robust.py b/test/srt/test_robust.py
new file mode 100644
index 00000000000..5b479318f58
--- /dev/null
+++ b/test/srt/test_robust.py
@@ -0,0 +1,132 @@
+import argparse
+import random
+import string
+
+from sglang.test.test_utils import (
+    add_common_sglang_args_and_parse,
+    select_sglang_backend,
+)
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+import sglang as sgl
+
+TOKENIZER = None
+RANDOM_PREFILL_LEN = None
+RANDOM_DECODE_LEN = None
+
+
+def gen_prompt(token_num):
+    if RANDOM_PREFILL_LEN:
+        token_num = random.randint(1, token_num)
+
+    cha_set = string.ascii_letters + string.digits
+    ret = "".join(random.choices(cha_set, k=token_num))
+    while len(TOKENIZER(ret).input_ids) < token_num:
+        ret += random.choice(cha_set)
+
+    return ret
+
+
+def robust_test_dfs(s, d, args, leaf_states):
+    if d == 0:
+        s += "END"
+        leaf_states.append(s)
+        return
+
+    s += gen_prompt(args.len_prefill)
+    forks = s.fork(args.num_fork)
+    for fork_s in forks:
+        fork_s += gen_prompt(args.len_prefill)
+        new_tokens = (
+            args.len_decode
+            if not RANDOM_DECODE_LEN
+            else random.randint(1, args.len_decode)
+        )
+        fork_s += sgl.gen(
+            max_tokens=new_tokens,
+            ignore_eos=True,
+        )
+
+    for fork_s in forks:
+        robust_test_dfs(fork_s, d - 1, args, leaf_states)
+
+
+def robust_test_bfs(s, args, leaf_states):
+    old_forks = [s]
+    new_forks = []
+    for _ in range(args.depth):
+        for old_fork in old_forks:
+            old_fork += gen_prompt(args.len_prefill)
+            forks = old_fork.fork(args.num_fork)
+            for fork_s in forks:
+                fork_s += gen_prompt(args.len_prefill)
+                new_tokens = (
+                    args.len_decode
+                    if not RANDOM_DECODE_LEN
+                    else random.randint(1, args.len_decode)
+                )
+                fork_s += sgl.gen(
+                    max_tokens=new_tokens,
+                    ignore_eos=True,
+                )
+            new_forks.extend(forks)
+
+        old_forks = new_forks
+        new_forks = []
+
+    for old_fork in old_forks:
+        old_fork += "END"
+        leaf_states.append(old_fork)
+
+
+@sgl.function
+def robust_test(s, args):
+    leaf_states = []
+    if args.mode == "bfs":
+        robust_test_bfs(s, args, leaf_states)
+    else:
+        robust_test_dfs(s, args.depth, args, leaf_states)
+    return leaf_states
+
+
+def main(args):
+    backend = select_sglang_backend(args)
+
+    arguments = [{"args": args} for _ in range(args.num_req)]
+
+    states = robust_test.run_batch(
+        arguments, temperature=0, backend=backend, num_threads=args.parallel
+    )
+
+    with open(f"tmp_robust_{args.mode}.txt", "w") as f:
+        for state in states:
+            leaf_states = state.ret_value
+            for leaf_state in leaf_states:
+                assert leaf_state.text()[-3:] == "END"
+                f.write(leaf_state.text()[:-3] + "\n")
+
+
+if __name__ == "__main__":
+    # fmt: off
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-req", type=int, default=2)
+    parser.add_argument("--depth", type=int, default=3)
+    parser.add_argument("--num-fork", type=int, default=2)
+    parser.add_argument("--len-prefill", type=int, default=128)
+    parser.add_argument("--len-decode", type=int, default=128)
+    parser.add_argument("--random-prefill-len", action="store_true")
+    parser.add_argument("--random-decode-len", action="store_true")
+    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
+    parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
+    parser.add_argument("--trust-remote-code", action="store_true")
+    parser.add_argument("--seed", type=int, default=42)
+    args = add_common_sglang_args_and_parse(parser)
+    # fmt: on
+
+    RANDOM_PREFILL_LEN = args.random_prefill_len
+    RANDOM_DECODE_LEN = args.random_decode_len
+    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+
+    random.seed(args.seed)
+
+    main(args)