comaniac committed Feb 6, 2024
1 parent d266837 commit 4b2cd26
Showing 4 changed files with 43 additions and 29 deletions.
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -99,3 +99,7 @@ class BatchStrOut:
 @dataclass
 class FlushCacheReq:
     pass
+
+@dataclass
+class DetokenizeReqInput:
+    input_ids: List[int]
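
The new request type is just a small payload of token ids that the HTTP server hands to the tokenizer manager's new detokenize() coroutine. A minimal sketch of constructing one (the concrete ids below are made up for illustration):

from dataclasses import dataclass
from typing import List

@dataclass
class DetokenizeReqInput:
    input_ids: List[int]

# The server builds the request from whatever token ids it wants rendered as
# text and passes it to TokenizerManager.detokenize() (added later in this commit).
req = DetokenizeReqInput(input_ids=[101, 7592, 2088])  # illustrative ids only
assert req.input_ids == [101, 7592, 2088]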
29 changes: 14 additions & 15 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -402,10 +402,14 @@ def forward_fill_batch(self, batch: Batch):
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = None
+            logits = logprobs = normalized_logprobs = last_logprobs = None
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
+
+        # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
             req.output_ids = [next_token_ids[i]]
@@ -416,11 +420,8 @@ def forward_fill_batch(self, batch: Batch):
                 req.normalized_logprob = normalized_logprobs[i]
 
                 token_ids = req.input_ids + [next_token_ids[i]]
-                next_token_logprobs = last_logprobs[i].cpu().tolist()
-                token_texts = req.tokenizer.convert_ids_to_tokens(token_ids)
-                token_texts = [t.decode() if isinstance(t, bytes) else t for t in token_texts]
-                token_logprobs = [None] + req.logprob + [next_token_logprobs[next_token_ids[i]]]
-                req.token_logprob = list(zip(token_texts, token_ids, token_logprobs))
+                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                req.token_logprob = list(zip(token_ids, token_logprobs))
                 pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
@@ -478,20 +479,18 @@ def forward_decode_batch(self, batch: Batch):
         next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
 
-        # Check finish condition
+        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
         reqs = batch.reqs
+        if last_logprobs is not None:
+            last_logprobs = last_logprobs[torch.arange(len(reqs)), next_token_ids].tolist()
+
+        # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
             req.output_ids.append(next_tok_id)
             req.check_finished()
 
             if last_logprobs is not None:
-                next_tok_text = req.tokenizer.convert_ids_to_tokens([next_tok_id])[0]
-                next_tok_text = (
-                    next_tok_text.decode() if isinstance(next_tok_text, bytes) else next_tok_text
-                )
-                req.token_logprob.append(
-                    (next_tok_text, next_tok_id, last_logprobs[i][next_tok_id].tolist())
-                )
+                req.token_logprob.append((next_tok_id, last_logprobs[i]))
 
         self.handle_finished_requests(batch)
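
The change in both forward_fill_batch and forward_decode_batch is the same optimization: instead of pulling per-request logprob rows (and converting token texts) inside the Python loop, the router gathers exactly one logprob per request with advanced indexing and does a single device-to-host transfer. A standalone sketch of that indexing pattern (the tensor shape and values below are invented for illustration):

import torch

# Pretend these came from the model: log-probabilities over a tiny vocab
# for a batch of 3 requests, plus the token id sampled for each request.
last_logprobs = torch.log_softmax(torch.randn(3, 8), dim=-1)  # [batch, vocab]
next_token_ids = [2, 7, 5]                                    # one sampled id per request

# Advanced indexing pairs row i with column next_token_ids[i],
# so the result is one scalar logprob per request ...
selected = last_logprobs[torch.arange(len(next_token_ids)), next_token_ids]

# ... and a single transfer moves the whole batch to the host at once,
# instead of one copy per request inside the loop.
selected = selected.cpu().tolist()
print(selected)  # e.g. [-2.31, -1.87, -2.05]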

5 changes: 5 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -18,6 +18,7 @@
 )
 from sglang.srt.managers.io_struct import (
     BatchStrOut,
+    DetokenizeReqInput,
     FlushCacheReq,
     GenerateReqInput,
     TokenizedGenerateReqInput,
@@ -234,6 +235,10 @@ async def generate_request(self, obj: GenerateReqInput):
 
             yield output_list
 
+    async def detokenize(self, obj: DetokenizeReqInput):
+        token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
+        return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
+
     async def flush_cache(self):
         flush_cache_req = FlushCacheReq()
         self.send_to_router.send_pyobj(flush_cache_req)
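
Outside of the async plumbing, the new detokenize() path boils down to convert_ids_to_tokens plus a bytes-to-str fixup. A rough, synchronous sketch using a Hugging Face tokenizer directly (the "gpt2" checkpoint is only an example; any tokenizer exposing convert_ids_to_tokens behaves similarly):

from transformers import AutoTokenizer

# Any HF tokenizer works here; "gpt2" is only an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer.encode("hello world")
token_texts = tokenizer.convert_ids_to_tokens(input_ids)

# Some tokenizers return bytes for certain tokens, hence the isinstance check
# in the detokenize() coroutine above.
token_texts = [t.decode() if isinstance(t, bytes) else t for t in token_texts]
print(token_texts)  # e.g. ['hello', 'Ġworld']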
34 changes: 20 additions & 14 deletions python/sglang/srt/server.py
@@ -28,7 +28,7 @@
 )
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -86,6 +86,23 @@ async def stream_generator(obj):
         yield out
 
 
+async def make_openai_style_logprobs(token_logprobs):
+    ret_logprobs = LogProbs()
+
+    # Detokenize
+    token_ids = [tid for tid, _ in token_logprobs]
+    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+
+    for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+        ret_logprobs.tokens.append(token_text)
+        ret_logprobs.token_logprobs.append(token_logprob)
+
+        # Not supported yet.
+        ret_logprobs.top_logprobs.append({})
+        ret_logprobs.text_offset.append(-1)
+    return ret_logprobs
+
+
 @app.post("/generate")
 async def generate_request(obj: GenerateReqInput):
     obj.post_init()
@@ -105,17 +122,6 @@ async def stream_results():
 
 @app.post("/v1/completions")
 async def v1_completions(raw_request: Request):
-    def make_openai_style_logprobs(token_logprobs):
-        ret_logprobs = LogProbs()
-        for token_text, _, token_logprob in token_logprobs:
-            ret_logprobs.tokens.append(token_text)
-            ret_logprobs.token_logprobs.append(token_logprob)
-
-            # Not supported yet.
-            ret_logprobs.top_logprobs.append({})
-            ret_logprobs.text_offset.append(-1)
-        return ret_logprobs
-
     request_json = await raw_request.json()
     request = CompletionRequest(**request_json)
 
@@ -156,7 +162,7 @@ async def gnerate_stream_resp():
                     n_prev_token = prompt_tokens
 
                 if request.logprobs is not None:
-                    logprobs = make_openai_style_logprobs(
+                    logprobs = await make_openai_style_logprobs(
                         content["meta_info"]["token_logprob"][n_prev_token:]
                     )
                     n_prev_token = len(content["meta_info"]["token_logprob"])
@@ -201,7 +207,7 @@ async def gnerate_stream_resp():
             token_logprob_pos = prompt_tokens
 
         logprobs = (
-            make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
+            await make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
            if request.logprobs is not None
             else None
         )
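
After this commit, the router reports token_logprob as (token_id, logprob) pairs instead of (text, token_id, logprob) triples, and the token text is recovered only when a client actually asks for logprobs. A self-contained sketch of the resulting OpenAI-style payload, with a synchronous fake_detokenize standing in for the real (async) tokenizer_manager.detokenize call, purely for illustration:

# fake_detokenize is a hypothetical stand-in; the real path goes through the
# tokenizer manager's detokenize() coroutine shown earlier in this commit.
def fake_detokenize(token_ids):
    return [f"<tok{i}>" for i in token_ids]

def openai_style_logprobs(token_logprobs):
    token_ids = [tid for tid, _ in token_logprobs]
    token_texts = fake_detokenize(token_ids)
    return {
        "tokens": token_texts,
        "token_logprobs": [lp for _, lp in token_logprobs],
        "top_logprobs": [{} for _ in token_logprobs],   # not supported yet
        "text_offset": [-1 for _ in token_logprobs],    # not supported yet
    }

pairs = [(15043, None), (3186, -1.9), (13, -2.4)]  # first prompt token has no logprob
print(openai_style_logprobs(pairs))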
