Skip to content

Commit

Permalink
Support Faster JSON decoding for llava (#137)
Browse files (browse the repository at this point in the history)
When sending fast-forwarded reqs to model_rpc, re-calculate `pad_input_ids`
  • Loading branch information
hnyls2002 committed Feb 3, 2024
1 parent 45d6592 commit bb3a3b6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
16 changes: 11 additions & 5 deletions python/sglang/srt/managers/router/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, rid, input_text, input_ids):
self.pixel_values = None
self.image_size = None
self.image_offset = 0
self.pad_value = None

self.sampling_params = None
self.return_logprob = False
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(self, rid, input_text, input_ids):
def max_new_tokens(self):
    # Accessor for the generation cap carried on this request's sampling
    # parameters. NOTE(review): the diff context suggests this sits on the
    # Req class and is likely decorated with @property just above this
    # hunk — decorator not visible here, confirm against the full file.
    return self.sampling_params.max_new_tokens

def tokenize_fast_forward(self, fast_forward_str, next_state):
def fast_forward_and_retokenize(self, fast_forward_str, next_state):
old_output_str = self.tokenizer.decode(self.output_ids)
# FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space.
Expand All @@ -75,9 +76,14 @@ def tokenize_fast_forward(self, fast_forward_str, next_state):
+ fast_forward_str
)
new_input_ids = self.tokenizer.encode(new_input_string)
fast_forward_tokens_len = (
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
)
if self.pixel_values is not None:
# NOTE: This is a hack because the old input_ids contains the image padding
fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
else:
fast_forward_tokens_len = (
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
)

# print("=" * 100)
# print(f"Catch fast forward:\n{fast_forward_str}")
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
Expand Down Expand Up @@ -351,7 +357,7 @@ def check_for_fast_forward(self):
self.tree_cache.dec_ref_counter(req.last_node)

# fast forward
req.tokenize_fast_forward(fast_forward_str, next_state)
req.fast_forward_and_retokenize(fast_forward_str, next_state)

fast_forward_reqs.append(req)
filter_indices.remove(i)
Expand Down
22 changes: 19 additions & 3 deletions python/sglang/srt/managers/router/model_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def exposed_init_model(
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len,
self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token,
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
Expand Down Expand Up @@ -233,15 +235,15 @@ def handle_generate_request(
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
pad_value = [
req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, pad_value, req.pixel_values.shape, req.image_size
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
Expand Down Expand Up @@ -438,6 +440,20 @@ def forward_decode_batch(self, batch: Batch):
if not self.no_regex_fast_forward:
# check for fast forward
fast_forward_reqs = batch.check_for_fast_forward()

# check for image fast forward
for req in fast_forward_reqs:
if req.pixel_values is not None:
(
req.input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.input_ids,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)

self.forward_queue.extend(fast_forward_reqs)
if batch.is_empty():
return
Expand Down

0 comments on commit bb3a3b6

Please sign in to comment.