Support decode token logprobs #130
Merged (6 commits, Feb 6, 2024)

Changes from 1 commit: Support decode token logprobs
comaniac committed Feb 6, 2024
commit 2c0c92a1b52a91fcd395449b3c7a5931f5fc8a42
69 changes: 37 additions & 32 deletions python/sglang/srt/layers/logits_processor.py
@@ -14,18 +14,21 @@ def __init__(self, config):
self.tp_size = get_tensor_model_parallel_world_size()

def forward(self, input_ids, hidden_states, weight, input_metadata):
last_index = None
if input_metadata.forward_mode != ForwardMode.DECODE:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)

if not input_metadata.return_logprob:
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)
last_hidden = hidden_states[last_index]
hidden_states = None

@@ -35,38 +38,40 @@ def forward(self, input_ids, hidden_states, weight, input_metadata):
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, (None, None)
else:
assert input_metadata.forward_mode != ForwardMode.DECODE
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
- 1
)

logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
logits = tensor_model_parallel_all_gather(logits)
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
if input_metadata.forward_mode == ForwardMode.DECODE:
# When decoding, logprobs shape is (batch, vocab size), and
# we expect the caller to directly get the logprobs via logprobs[next_token_id].
last_logits = logits
logprobs = all_logprobs
normalized_logprobs = None
else:
# During prefill, logprobs shape is (batch, seq_len), where each value
# is already the logprob of the selected token. However, since we do not
# know the first sampled token ID yet, we always pad 0. Thus,
# the logprob of the first decoding token has to be computed by the caller
# using last_logits.
logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
last_logits = logits[last_index]

last_logits = logits[last_index]
return last_logits, (logprobs, normalized_logprobs)
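
As the comments above explain, the prefill-path logprobs tensor pads a placeholder at each sequence's last position, so the logprob of the first generated token must be recovered from last_logits by the caller. A minimal, self-contained sketch of that caller-side step (tensor sizes and the sampled token IDs are made up; only the log(softmax(...) + 1e-6) formulation mirrors the code above):

import torch

vocab_size = 8                                   # toy vocabulary size
last_logits = torch.randn(2, vocab_size)         # stand-in for the returned last_logits (batch of 2)
next_token_ids = torch.argmax(last_logits, -1)   # stand-in for the sampled next tokens

# Same formulation as in the diff: log(softmax(logits) + 1e-6).
next_token_logprobs = torch.log(torch.softmax(last_logits.float(), dim=-1) + 1e-6)
first_decode_logprobs = next_token_logprobs[torch.arange(2), next_token_ids]
print(first_decode_logprobs.tolist())            # one logprob per request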


1 change: 1 addition & 0 deletions python/sglang/srt/managers/router/infer_batch.py
@@ -48,6 +48,7 @@ def __init__(self, rid, input_text, input_ids):
self.last_node = None

self.logprob = None
self.token_logprob = None
self.normalized_logprob = None

# For constrained decoding
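
For context on the new field: the router changes later in this diff populate req.token_logprob with one (token_text, token_id, logprob) tuple per token, using None for the first prompt token. A purely illustrative example of its contents after prefill plus one decode step (texts, IDs, and values are invented):

token_logprob = [
    ("Hello", 9906, None),    # first prompt token: no preceding context, so no logprob
    (",", 11, -2.31),         # remaining prompt tokens: logprob given the preceding prefix
    (" world", 1917, -0.87),  # generated token appended by the decode path
]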
42 changes: 34 additions & 8 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -393,16 +393,15 @@ def forward_fill_batch(self, batch: Batch):
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_logprob
)
# print("extend logits", logits)
if logprobs is not None:
logprobs = logprobs.cpu().tolist()
normalized_logprobs = normalized_logprobs.cpu().tolist()

next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
logprobs = normalized_logprobs = None
logits = logprobs = normalized_logprobs = None

# Check finish condition
reqs = batch.reqs
@@ -414,6 +413,19 @@ def forward_fill_batch(self, batch: Batch):
if logprobs is not None:
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
req.normalized_logprob = normalized_logprobs[i]

# Cannot directly take logprobs[pt + req.extend_input_len - 1]
# as it is always the logprob of token 0. This is because at the
# time when logprobs is computed, the next token ID is not yet sampled.
assert logits is not None
token_ids = req.input_ids + [next_token_ids[i]]

next_token_logprobs = torch.log(torch.softmax(logits[i].float(), dim=-1) + 1e-6)
next_token_logprobs = next_token_logprobs.cpu().tolist()
token_texts = req.tokenizer.convert_ids_to_tokens(token_ids)
token_texts = [t.decode() if isinstance(t, bytes) else t for t in token_texts]
token_logprobs = [None] + req.logprob + [next_token_logprobs[next_token_ids[i]]]
req.token_logprob = list(zip(token_texts, token_ids, token_logprobs))
pt += req.extend_input_len

self.handle_finished_requests(batch)
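
A self-contained sketch of the token_logprob assembly above with toy data, showing why the sampled token's logprob is taken from logits rather than from the padded prompt logprobs (all texts, IDs, and tensors here are illustrative, not real model output):

import torch

input_token_texts = ["The", " cat"]
input_token_ids = [464, 3797]
prompt_logprobs = [-1.2]          # logprob of " cat" given "The"; length = prompt_len - 1
next_token_id = 318               # hypothetical sampled token (" is")
logits_row = torch.randn(32000)   # toy last-position logits for this request

# Same conversion as in the diff: logits -> logprobs over the vocabulary.
next_token_logprobs = torch.log(torch.softmax(logits_row.float(), dim=-1) + 1e-6).tolist()

token_texts = input_token_texts + [" is"]
token_ids = input_token_ids + [next_token_id]
token_logprobs = [None] + prompt_logprobs + [next_token_logprobs[next_token_id]]
token_logprob = list(zip(token_texts, token_ids, token_logprobs))
# -> [('The', 464, None), (' cat', 3797, -1.2), (' is', 318, <float>)]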
@@ -463,15 +475,28 @@ def forward_decode_batch(self, batch: Batch):
batch.prepare_for_decode()

# Forward
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, next_token_probs = batch.sample(logits)
logits, (logprobs, _) = self.model_runner.forward(
batch,
ForwardMode.DECODE,
batch.return_logprob,
)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()

# Check finish condition
reqs = batch.reqs
for i in range(len(reqs)):
reqs[i].output_ids.append(next_token_ids[i])
reqs[i].check_finished()
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
req.output_ids.append(next_tok_id)
req.check_finished()

if logprobs is not None:
next_tok_text = req.tokenizer.convert_ids_to_tokens([next_tok_id])[0]
next_tok_text = (
next_tok_text.decode() if isinstance(next_tok_text, bytes) else next_tok_text
)
req.token_logprob.append(
(next_tok_text, next_tok_id, logprobs[i][next_tok_id].tolist())
)

self.handle_finished_requests(batch)
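
In decode mode the logprobs tensor returned above has shape (batch, vocab_size), so each request picks out the logprob of its own sampled token, as the loop does. A toy sketch of that indexing (sizes and token IDs are invented):

import torch

batch_size, vocab_size = 3, 10
logprobs = torch.log(torch.softmax(torch.randn(batch_size, vocab_size), dim=-1) + 1e-6)
next_token_ids = [4, 1, 7]                 # one hypothetical sampled ID per request

step_logprobs = [logprobs[i][tok].item() for i, tok in enumerate(next_token_ids)]
print(step_logprobs)                       # appended as the logprob of each new token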

@@ -513,6 +538,7 @@ def handle_finished_requests(self, batch: Batch):
}
if req.return_logprob:
meta_info["prompt_logprob"] = req.logprob
meta_info["token_logprob"] = req.token_logprob
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
output_meta_info.append(meta_info)
output_finished.append(req.finished)
11 changes: 5 additions & 6 deletions python/sglang/srt/managers/router/model_runner.py
@@ -392,6 +392,7 @@ def forward_decode(
out_cache_loc,
out_cache_cont_start,
out_cache_cont_end,
return_logprob,
):
input_metadata = InputMetadata.create(
self,
@@ -404,10 +405,9 @@
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
return_logprob=return_logprob,
)
return self.model.forward(input_ids, input_metadata.positions, input_metadata)[
0
]
return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@torch.inference_mode()
def forward_extend_multi_modal(
@@ -455,8 +455,8 @@ def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False)
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
"return_logprob": return_logprob,
}
kwargs["return_logprob"] = return_logprob
return self.forward_extend_multi_modal(**kwargs)
else:
kwargs = {
@@ -466,17 +466,16 @@ def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False)
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
"return_logprob": return_logprob,
}

if forward_mode == ForwardMode.DECODE:
kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
return self.forward_decode(**kwargs)
elif forward_mode == ForwardMode.EXTEND:
kwargs["return_logprob"] = return_logprob
return self.forward_extend(**kwargs)
elif forward_mode == ForwardMode.PREFILL:
kwargs["return_logprob"] = return_logprob
return self.forward_prefill(**kwargs)
else:
raise ValueError(f"Invalid forward mode: {forward_mode}")
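
A toy illustration of the changed return contract: forward_decode previously kept only element [0] of the model output (the logits), whereas it now returns the full (logits, (logprobs, normalized_logprobs)) tuple so the router can unpack per-step logprobs. Everything below is a placeholder, not the real model output:

model_output = ("logits", ("logprobs", None))   # placeholder for model.forward(...)

old_result = model_output[0]                    # old behaviour: logits only
logits, (logprobs, _normalized) = model_output  # new behaviour: caller unpacks both
assert old_result == logits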
63 changes: 55 additions & 8 deletions python/sglang/srt/server.py
@@ -7,6 +7,7 @@
import threading
import time
from typing import List, Optional, Union
import logging

# Fix a Python bug
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -42,6 +43,7 @@
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
LogProbs,
UsageInfo,
)
from sglang.srt.managers.router.manager import start_router_process
@@ -51,6 +53,7 @@

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

app = FastAPI()
tokenizer_manager = None
@@ -104,6 +107,17 @@ async def stream_results():

@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
def make_openai_style_logprobs(token_logprobs):
ret_logprobs = LogProbs()
for token_text, _, token_logprob in token_logprobs:
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(token_logprob)

# Not supported yet.
ret_logprobs.top_logprobs.append({})
ret_logprobs.text_offset.append(-1)
return ret_logprobs

request_json = await raw_request.json()
request = CompletionRequest(**request_json)
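
To make the mapping in make_openai_style_logprobs above concrete, here is a stand-alone sketch with a stand-in dataclass that mirrors the field names used there (sglang's real LogProbs model lives in its OpenAI protocol module; the values are invented):

from dataclasses import dataclass, field
from typing import Dict, List, Optional

@dataclass
class LogProbsLike:                       # stand-in for sglang's LogProbs model
    text_offset: List[int] = field(default_factory=list)
    token_logprobs: List[Optional[float]] = field(default_factory=list)
    tokens: List[str] = field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = field(default_factory=list)

token_logprob = [("Hello", 9906, None), (",", 11, -2.31), (" world", 1917, -0.87)]
ret = LogProbsLike()
for token_text, _, logprob in token_logprob:
    ret.tokens.append(token_text)
    ret.token_logprobs.append(logprob)
    ret.top_logprobs.append({})           # top-k alternatives: not supported yet
    ret.text_offset.append(-1)            # character offsets: not tracked yet
print(ret.tokens, ret.token_logprobs)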

@@ -120,6 +134,7 @@ async def v1_completions(raw_request: Request):
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
},
return_logprob=request.logprobs is not None,
stream=request.stream,
)
adapted_request.post_init()
@@ -128,17 +143,34 @@

async def gnerate_stream_resp():
stream_buffer = ""
n_prev_token = 0
async for content in stream_generator(adapted_request):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]

if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
else:
# Skip prompt tokens if echo is disabled.
n_prev_token = prompt_tokens

if request.logprobs is not None:
logprobs = make_openai_style_logprobs(
content["meta_info"]["token_logprob"][n_prev_token:]
)
n_prev_token = len(content["meta_info"]["token_logprob"])
else:
logprobs = None

delta = text[len(stream_buffer) :]
stream_buffer = text
stream_buffer = content["text"]
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=None,
logprobs=logprobs,
finish_reason=None,
)
chunk = CompletionStreamResponse(
@@ -152,23 +184,36 @@ async def gnerate_stream_resp():
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
yield f"data: {chunk.json(ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
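
The n_prev_token bookkeeping in the stream above slices each chunk's cumulative token_logprob list so that only newly generated entries are emitted, skipping the prompt entries entirely when echo is off. A toy, stand-alone walk-through of that slicing (chunk contents are invented):

prompt_tokens = 2
echo = False
chunks = [  # cumulative token_logprob lists, as carried in each chunk's meta_info
    [("A", 1, None), ("B", 2, -0.5), ("C", 3, -1.0)],
    [("A", 1, None), ("B", 2, -0.5), ("C", 3, -1.0), ("D", 4, -0.2)],
]

n_prev_token = 0
for i, token_logprob in enumerate(chunks):
    if i == 0 and not echo:
        n_prev_token = prompt_tokens       # drop prompt entries when echo is disabled
    new_entries = token_logprob[n_prev_token:]
    n_prev_token = len(token_logprob)
    print(new_entries)                     # [("C", 3, -1.0)] then [("D", 4, -0.2)]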

# Non-streaming response.
ret = await generate_request(adapted_request)

prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
text = ret["text"]
token_logprob_pos = prompt_tokens
if request.echo:
token_logprob_pos = 0
text = request.prompt + text
else:
token_logprob_pos = prompt_tokens

logprobs = (
make_openai_style_logprobs(ret["meta_info"]["token_logprob"][token_logprob_pos:])
if request.logprobs is not None
else None
)
choice_data = CompletionResponseChoice(
index=0,
text=ret["text"],
logprobs=None,
text=text,
logprobs=logprobs,
finish_reason=None, # TODO(comaniac): Add finish reason.
)

prompt_tokens = ret["meta_info"]["prompt_tokens"]
completion_tokens = ret["meta_info"]["completion_tokens"]
response = CompletionResponse(
id=ret["meta_info"]["id"],
model=request.model,
@@ -204,7 +249,9 @@ async def v1_chat_completions(raw_request: Request):
if not isinstance(m.content, str):
raise HTTPException(
status_code=503,
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
detail="Structured content requests not supported with "
"HuggingFace Chat Templates. "
"Make sure the server specifies a sglang chat template.",
)
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True