forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
executor.py
547 lines (447 loc) · 18.7 KB
/
executor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
import asyncio
import time
from pathlib import Path
from queue import Queue
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
import torch
from janus import Queue as AsyncQueue
from transformers import AutoTokenizer
import tensorrt_llm.bindings as tllm
from tensorrt_llm._utils import mpi_broadcast, mpi_rank, mpi_world_size
from tensorrt_llm.hlapi.mpi_session import MpiSession, NodeSession, SocketClient
from tensorrt_llm.hlapi.tokenizer import TokenizerBase
from tensorrt_llm.hlapi.utils import GenerationOutput, print_traceback_on_error
from tensorrt_llm.logger import logger
def has_event_loop() -> bool:
    """Report whether the caller is running inside an asyncio event loop.

    Returns:
        True if ``asyncio.get_running_loop()`` succeeds, False if it raises
        ``RuntimeError``.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() signals "no running loop" with RuntimeError.
        return False
    else:
        return True
class GenerationRequest:
    """Python-side description of one generation request.

    Holds the prompt token ids plus sampling options and converts them into
    a ``tllm.InferenceRequest`` on demand.
    """

    def __init__(self,
                 req_id: int,
                 ids: torch.Tensor,
                 end_id: int,
                 pad_id: int,
                 streaming: bool = True,
                 **kwargs):
        """Store the request parameters.

        Args:
            req_id: Identifier handed to ``tllm.InferenceRequest``.
            ids: Prompt token ids; cast to int32 when the request is built.
            end_id: Default ``end_id`` used when not overridden in ``kwargs``.
            pad_id: Default ``pad_id`` used when not overridden in ``kwargs``.
            streaming: Copied to ``InferenceRequest.is_streaming``.
            **kwargs: Optional per-request properties (see
                ``get_inference_request`` for the names that are read).
        """
        self._id = req_id
        self.ids = ids
        self.end_id = end_id
        self.pad_id = pad_id
        self.streaming = streaming
        self.kwargs = kwargs
        # Not populated anywhere in this class.
        self.prompt = None

    def get_inference_request(self) -> tllm.InferenceRequest:
        """Build the ``tllm.InferenceRequest`` for this request.

        Each known property is taken from ``self.kwargs`` when present,
        otherwise from its default; properties with neither are left unset.
        Every value is wrapped as ``torch.tensor([value], dtype=...)``.

        Returns:
            The populated inference request.
        """
        request = tllm.InferenceRequest(self._id)
        request.input_ids = self.ids.to(dtype=torch.int32)
        request.is_streaming = self.streaming

        # (property name, tensor dtype, default used when absent from kwargs)
        # NOTE: the max_new_tokens default is itself a list, so the resulting
        # tensor has one more dimension than the scalar-valued properties.
        # NOTE(review): runtime_top_k is encoded as float32 here - confirm
        # that this is the dtype the runtime expects.
        property_table = (
            ("max_new_tokens", torch.int32, [8]),
            ("end_id", torch.int32, self.end_id),
            ("pad_id", torch.int32, self.pad_id),
            ("min_length", torch.int32, None),
            ("temperature", torch.float32, None),
            ("runtime_top_k", torch.float32, None),
            ("runtime_top_p", torch.float32, None),
            ("random_seed", torch.int64, None),
        )
        for prop_name, prop_dtype, fallback in property_table:
            if prop_name not in self.kwargs and fallback is None:
                # Nothing supplied and no default: leave the property unset.
                continue
            prop_value = self.kwargs.get(prop_name, fallback)
            setattr(request, prop_name,
                    torch.tensor([prop_value], dtype=prop_dtype))

        return request
class GenerationResult(GenerationOutput):
    """Waitable handle for the output of one submitted GenerationRequest.

    Messages are pushed with ``enqueue`` as ``(payload, finished)`` tuples
    and consumed one at a time with ``wait_step`` / ``await_step``.  The
    object is also a sync and async iterator that yields itself after each
    consumed step until ``done`` is set.
    """

    def __init__(self,
                 generation_request: GenerationRequest,
                 tokenizer: Optional[TokenizerBase] = None) -> None:
        """Create the message queue and initial state.

        Args:
            generation_request: The request this result belongs to.
            tokenizer: Used by the ``text`` property to decode ``token_ids``.
        """
        self.running = True
        self.done = False
        self.generation_request = generation_request
        self.tokenizer = tokenizer

        if not has_event_loop():
            # No running loop: a plain thread-style queue, no async view.
            self._base_queue = Queue()
            self.queue = self._base_queue
            self.aqueue = None
        else:
            # Running loop: one janus queue exposing sync and async views.
            self._base_queue = AsyncQueue()
            self.queue = self._base_queue.sync_q
            self.aqueue = self._base_queue.async_q

        self.generation: Optional[torch.Tensor] = (
            generation_request.ids if generation_request.streaming else None)

        # TODO: fill the following fields from GenerationOutput
        self.token_ids = []
        self.logprobs = []

    def enqueue(self, msg: Tuple[Union[str, Dict[str, torch.Tensor]], bool]):
        """Push a ``(payload, finished)`` message onto the sync queue."""
        self.queue.put(msg)

    def handle_generation_msg(self, msg: Union[str, Dict[str, torch.Tensor]]):
        """Update ``token_ids`` from one message payload.

        Raises:
            RuntimeError: If ``msg`` is a string (treated as an error text).
        """
        if isinstance(msg, str):
            raise RuntimeError(msg)

        # TODO[chunweiy]: Unify the msg format for parallel and non-parallel mode
        if isinstance(msg, dict):
            self.token_ids = msg["output_ids"][0][0]
            return

        # this is for parallel mode
        assert isinstance(msg, list)
        self.token_ids = msg[0]

    @staticmethod
    def process_generation(msg: dict):
        """Return the first entry of ``msg["output_ids"]``."""
        # TODO: add other fields if needed
        return msg["output_ids"][0]

    def wait_step(self, timeout: Optional[float] = None):
        """Block for the next message, then apply it.

        Args:
            timeout: Forwarded to the queue's ``get``.
        """
        payload, finished = self.queue.get(timeout=timeout)
        self.done = finished
        self.handle_generation_msg(payload)

    async def await_step(self):
        """Await the next message on the async queue, then apply it."""
        assert self.aqueue is not None
        payload, finished = await self.aqueue.get()
        self.done = finished
        self.handle_generation_msg(payload)

    @property
    def text(self) -> str:
        """Decode the current ``token_ids`` with the tokenizer."""
        return self.tokenizer.decode(self.token_ids)

    def wait_completion(self,
                        timeout: Optional[float] = None) -> "GenerationResult":
        """Consume steps until ``done`` is set; ``timeout`` applies per step."""
        while True:
            if self.done:
                return self
            self.wait_step(timeout)

    async def await_completion(self) -> "GenerationResult":
        """Async counterpart of ``wait_completion``."""
        while True:
            if self.done:
                return self
            await self.await_step()

    def __iter__(self):
        return self

    def __next__(self):
        """Consume one step and yield self; stop once ``done`` is set."""
        if not self.done:
            self.wait_step()
            return self
        raise StopIteration

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Async counterpart of ``__next__``."""
        if not self.done:
            await self.await_step()
            return self
        raise StopAsyncIteration
class GenerationExecutor:
    """Single-process front end around ``tllm.GptManager``.

    Requests are queued in ``_requests`` by ``submit`` and handed to the
    manager through the ``fetch_requests`` callback; responses come back
    through ``handle_response`` and are routed to the matching
    ``GenerationResult`` in ``_results``.
    """

    # Passed to GptManager as the terminate request id; never handed out by
    # get_next_request_id.
    TERMINATE_REQUEST_ID = 0

    def __init__(
        self,
        engine_dir: Path,
        tokenizer: Union[str, Path, TokenizerBase],
        max_beam_width: int = 1,
        executor_type: tllm.TrtGptModelType = tllm.TrtGptModelType.
        InflightBatching,
        executor_policy: tllm.SchedulerPolicy = tllm.SchedulerPolicy.
        GUARANTEED_NO_EVICT,
        executor_config: tllm.TrtGptModelOptionalParams = tllm.
        TrtGptModelOptionalParams(),
    ) -> None:
        """Load the tokenizer if needed and start the ``GptManager``.

        Args:
            engine_dir: Engine directory forwarded to ``tllm.GptManager``.
            tokenizer: A ``TokenizerBase`` instance, or a name/path that is
                loaded with ``AutoTokenizer.from_pretrained``.
            max_beam_width: Forwarded to ``tllm.GptManager``.
            executor_type: Forwarded to ``tllm.GptManager``.
            executor_policy: Forwarded to ``tllm.GptManager``.
            executor_config: Forwarded to ``tllm.GptManager``.  NOTE: the
                default instance is created once at definition time and is
                shared by every call that omits this argument.
        """
        self.active_requests = 0

        self.tokenizer = tokenizer
        if not isinstance(tokenizer, TokenizerBase):
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer,
                legacy=False,
                padding_side='left',
                truncation_side='left',
                trust_remote_code=True,
                use_fast=True)

        # NOTE: underscore variables are used for communication with the C++ runtime
        self._requests: List[tllm.InferenceRequest] = []
        self._results: Dict[int, GenerationResult] = {}
        self._cancelled_ids: Set[int] = set()
        self._completed: Queue = Queue()

        if has_event_loop():
            self._stats = AsyncQueue()
            self.stats_queue = self._stats.sync_q
            self.stats_aqueue = self._stats.async_q
        else:
            self._stats = Queue()
            self.stats_queue = self._stats
            self.stats_aqueue = None

        self.engine = tllm.GptManager(engine_dir, executor_type, max_beam_width,
                                      executor_policy, self.fetch_requests,
                                      self.handle_response,
                                      self.get_cancelled_ids, self.handle_stats,
                                      executor_config,
                                      GenerationExecutor.TERMINATE_REQUEST_ID)

        self._next_request_id = GenerationExecutor.TERMINATE_REQUEST_ID + 1

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """
        Low-level API to the executor. Return a "future" GenerationResult which can be waited.
        """
        inference_request = request.get_inference_request()

        result = GenerationResult(request, self.tokenizer)
        # Register the result before queueing the request so that a response
        # callback can always find it.
        self._results[inference_request.request_id] = result

        self.active_requests += 1
        self._requests.append(inference_request)

        return result

    def get_next_request_id(self) -> int:
        """Return a fresh request id and advance the internal counter.

        The counter wraps around before exceeding the uint64 range and skips
        ``TERMINATE_REQUEST_ID`` so that a user request can never be mistaken
        for the terminate signal.
        """
        # underlying type is uint64
        uint64_max = 2**64 - 1
        request_id = self._next_request_id
        next_id = (request_id + 1) % uint64_max
        if next_id == GenerationExecutor.TERMINATE_REQUEST_ID:
            next_id += 1
        self._next_request_id = next_id
        return request_id

    def generate_async(
        self, prompt: Union[str, List[str]], streaming: bool,
        max_new_tokens: Union[int, List[int]]
    ) -> Union[GenerationResult, List[GenerationResult]]:
        """Tokenize and submit one or more prompts without waiting.

        Args:
            prompt: A single prompt or a list of prompts.
            streaming: Forwarded to each ``GenerationRequest``.
            max_new_tokens: Per-prompt token budget.  Must be an int for a
                single prompt; for a list of prompts it may be a list of the
                same length or a single int applied to every prompt.

        Returns:
            One ``GenerationResult`` for a single prompt, otherwise a list in
            prompt order.
        """
        unbatched = isinstance(prompt, str)
        if unbatched:
            assert isinstance(max_new_tokens, int)
            prompt = [prompt]
            max_new_tokens = [max_new_tokens]
        elif isinstance(max_new_tokens, int):
            # A scalar budget with a batch of prompts applies to all of them.
            max_new_tokens = [max_new_tokens] * len(prompt)

        assert isinstance(self.tokenizer, TokenizerBase)

        def get_ids(prompt: str) -> torch.Tensor:
            return self.tokenizer.encode(prompt,
                                         return_tensors="pt",
                                         return_attention_mask=False)

        # Fall back to EOS both when the tokenizer has no pad_token_id
        # attribute and when the attribute exists but is None.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        results = [
            self.submit(
                GenerationRequest(req_id=self.get_next_request_id(),
                                  ids=get_ids(p),
                                  streaming=streaming,
                                  max_new_tokens=[m],
                                  pad_id=pad_id,
                                  end_id=self.tokenizer.eos_token_id))
            for p, m in zip(prompt, max_new_tokens)
        ]
        if unbatched:
            results = results[0]
        return results

    def generate(
        self, prompt: Union[str, List[str]], max_new_tokens: Union[int,
                                                                   List[int]]
    ) -> Union[GenerationResult, List[GenerationResult]]:
        """Submit prompts in non-streaming mode and block until all finish.

        Returns:
            The same shape as ``generate_async``: a single result for a
            single prompt, otherwise a list.
        """
        results = self.generate_async(prompt, False, max_new_tokens)
        # A single prompt yields a bare GenerationResult; wrap it so both
        # shapes share the wait loop.
        result_list = [results] if isinstance(results,
                                              GenerationResult) else results
        for result in result_list:
            result.wait_completion()
        return results

    def get_stats(self):
        """Block until a stats entry is available and return it."""
        return self.stats_queue.get()

    async def aget_stats(self):
        """Await the next stats entry; requires construction inside a loop."""
        assert self.stats_aqueue is not None
        return await self.stats_aqueue.get()

    def wait_first_completed(
        self, futures: List[GenerationResult]
    ) -> Generator[GenerationResult, None, None]:
        """Yield the given results as they complete.

        Results whose ``done`` flag is already set are yielded first; the
        rest are yielded in the order their ids arrive on ``_completed``.
        Completed ids that are not in ``futures`` are consumed and dropped.
        """
        wait_set = {f.generation_request._id for f in futures}

        # clear already-finished requests
        for f in futures:
            req_id = f.generation_request._id
            # The membership test keeps a duplicated future from raising
            # KeyError on the second removal.
            if f.done and req_id in wait_set:
                wait_set.remove(req_id)
                yield f

        # wait remaining active requests
        while wait_set:
            req_id = self._completed.get()
            if req_id in wait_set:
                wait_set.remove(req_id)
                yield self._results[req_id]

    # Callbacks for BatchManager
    def fetch_requests(self, max_num_sequences) -> List[tllm.InferenceRequest]:
        """Hand up to ``max_num_sequences`` queued requests to the manager.

        Requests are taken from the front of the queue so that they are
        served in submission order.
        """
        fetched = []
        for _ in range(max_num_sequences):
            if not self._requests:
                break
            fetched.append(self._requests.pop(0))
        return fetched

    def handle_response(self, req_id: int, tensors: List[tllm.NamedTensor],
                        finished: bool, err: str) -> None:
        """Route a response (or error text) to the matching result."""
        if err:
            payload = err
        else:
            payload = {
                t.name: t.tensor
                for t in tensors if t.tensor is not None
            }
        self._results[req_id].enqueue((payload, finished))
        if finished:
            self._completed.put(req_id)

    def get_cancelled_ids(self) -> Set[int]:
        """Return the set of cancelled request ids."""
        return self._cancelled_ids

    def handle_stats(self, stats: str):
        """Store a stats entry, dropping old entries while the queue is full."""
        while self.stats_queue.full():
            self.stats_queue.get()

        self.stats_queue.put(stats)
class ParallelGenerationExecutor(GenerationExecutor):
    ''' GenerationExecutor with MPI enabled.

    The same class plays two roles, chosen by ``mpi_world_size()``:

    * world size 1 (``on_PMP``): creates an ``MpiSession`` and forwards
      requests to the workers; responses arrive through
      ``_async_listener_calllback``.
    * world size > 1 (``on_MPI``): runs inside a worker, owns a
      ``tllm.GptManager`` and keeps the ranks in step with ``mpi_broadcast``.
    '''

    def __init__(
        self,
        tp_size: int,
        engine_dir: Path,
        tokenizer: Union[str, Path, TokenizerBase],
        max_beam_width: int = 1,
        executor_type: tllm.TrtGptModelType = tllm.TrtGptModelType.
        InflightFusedBatching,
        executor_policy: tllm.SchedulerPolicy = tllm.SchedulerPolicy.
        GUARANTEED_NO_EVICT,
        kvcache_free_gpu_memory_fraction: Optional[float] = None,
        socket_client: Optional[SocketClient] = None,
        # TODO: support serialization
        # executor_config: tllm.TrtGptModelOptionalParams = tllm.TrtGptModelOptionalParams(),
    ) -> None:
        """Set up either the launching side or a worker-side executor.

        Args:
            tp_size: Number of workers requested from ``MpiSession``.
            engine_dir: Engine directory forwarded to ``tllm.GptManager``.
            tokenizer: A ``TokenizerBase`` instance, or a name/path that is
                loaded with ``AutoTokenizer.from_pretrained``.
            max_beam_width: Forwarded to ``tllm.GptManager``.
            executor_type: Forwarded to ``tllm.GptManager``.
            executor_policy: Forwarded to ``tllm.GptManager``.
            kvcache_free_gpu_memory_fraction: When given, written to
                ``kv_cache_config.free_gpu_memory_fraction`` on the worker.
            socket_client: Used by rank 0 of the workers to send responses;
                on the launching side it is replaced by the session's client.
        """
        # Define this first so shutdown() / __del__ are safe even when
        # construction fails part-way, and on worker ranks, which never
        # create a session.
        self.mpi_session = None

        assert kvcache_free_gpu_memory_fraction is None or isinstance(
            kvcache_free_gpu_memory_fraction, float)

        self.on_PMP = mpi_world_size() == 1
        self.on_MPI = mpi_world_size() > 1

        self._terminated = False
        self._terminated_sync = False

        self.active_requests = 0

        self.tokenizer = tokenizer
        if not isinstance(tokenizer, TokenizerBase):
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer,
                legacy=False,
                padding_side='left',
                truncation_side='left',
                trust_remote_code=True,
                use_fast=True)

        # NOTE: underscore variables are used for communication with the C++ runtime
        self._requests: list[tllm.InferenceRequest] = []
        self._results: dict[int, GenerationResult] = {}
        self._cancelled_ids: set[int] = set()
        self._completed: Queue = Queue()

        if has_event_loop():
            self._stats = AsyncQueue()
            self.stats_queue = self._stats.sync_q
            self.stats_aqueue = self._stats.async_q
        else:
            self._stats = Queue()
            self.stats_queue = self._stats
            self.stats_aqueue = None

        self._next_request_id = GenerationExecutor.TERMINATE_REQUEST_ID + 1
        self.socket_client = socket_client

        if self.on_PMP:
            # initialize the executor on each MPI node
            assert isinstance(self.tokenizer,
                              TokenizerBase), "tokenizer not initialized"

            self.mpi_session = MpiSession(
                n_workers=tp_size,
                async_callback=self._async_listener_calllback)
            self.socket_client = self.mpi_session.get_socket_client()

            self.mpi_session.submit_sync(
                ParallelGenerationExecutor._node_init_executor_task, engine_dir,
                self.tokenizer, max_beam_width, executor_type, executor_policy,
                kvcache_free_gpu_memory_fraction, self.socket_client)
        else:
            executor_config = tllm.TrtGptModelOptionalParams()
            if kvcache_free_gpu_memory_fraction is not None:
                executor_config.kv_cache_config.free_gpu_memory_fraction = kvcache_free_gpu_memory_fraction

            self.engine = tllm.GptManager(
                engine_dir, executor_type, max_beam_width, executor_policy,
                self.fetch_requests_on_mpi_node,
                self.handle_response_on_mpi_node, self.get_cancelled_ids,
                self.handle_stats, executor_config,
                GenerationExecutor.TERMINATE_REQUEST_ID)

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """Register a result locally and send the request to every worker."""
        # submit on the PMP
        inference_request = request.get_inference_request()
        result = GenerationResult(request, self.tokenizer)
        self._results[inference_request.request_id] = result

        self.active_requests += 1

        self.mpi_session.submit_sync(
            ParallelGenerationExecutor._node_add_request_task,
            inference_request)

        return result

    # NOTE: @staticmethod is kept outermost on the node tasks so that
    # print_traceback_on_error wraps the plain function; a staticmethod
    # object is only callable on Python >= 3.10.
    @staticmethod
    @print_traceback_on_error
    def _node_add_request_task(inference_request):
        """Worker task: append a request to the local executor's queue."""
        executor: GenerationExecutor = NodeSession.state
        assert isinstance(executor,
                          GenerationExecutor), 'executor not initialized'
        executor._requests.append(inference_request)

    @staticmethod
    @print_traceback_on_error
    def _node_init_executor_task(
        engine_dir: Path,
        tokenizer: TokenizerBase,
        max_beam_width: int,
        executor_type: tllm.TrtGptModelType,
        executor_policy: tllm.SchedulerPolicy,
        kvcache_free_gpu_memory_fraction: Optional[float],
        socket_client: Optional[SocketClient],
        # executor_config: tllm.TrtGptModelOptionalParams
    ):
        ''' Create a local GenerationExecutor instance for each MPI process. '''
        assert not NodeSession.is_initialized(), 'executor already initialized'

        logger.info(f'Initializing executor on MPI node #{mpi_rank()}')

        tp_size = mpi_world_size()
        NodeSession.state = ParallelGenerationExecutor(
            tp_size,
            engine_dir,
            tokenizer,
            max_beam_width,
            executor_type,
            executor_policy,
            kvcache_free_gpu_memory_fraction=kvcache_free_gpu_memory_fraction,
            socket_client=socket_client)

    # Callbacks for BatchManager

    @print_traceback_on_error
    def fetch_requests_on_mpi_node(
            self, max_num_sequences) -> List[tllm.InferenceRequest]:
        """Worker-side fetch callback that keeps all ranks in lock-step.

        Rank 0 decides the batch size and the termination flag; both are
        broadcast so that every rank fetches the same number of requests.
        Requests are taken from the front of the local queue so that ranks
        whose queues hold different numbers of requests at fetch time still
        pick the same ones.
        """
        # Once termination is known locally, skip the broadcast: non-zero
        # ranks as soon as _terminated is set, rank 0 only after it has
        # broadcast the flag once (_terminated_sync).
        if mpi_rank() != 0 or self._terminated_sync:
            if self._terminated:
                return []

        terminated = mpi_broadcast(self._terminated, 0)
        if terminated:
            logger.warning(f'#node{mpi_rank()} to terminate')
            self._terminated_sync = True
            self._terminated = True
            return []

        batch_size = 0
        fetched = []
        if mpi_rank() == 0:
            batch_size = min(len(self._requests), max_num_sequences)
        batch_size = mpi_broadcast(batch_size, 0)

        for _ in range(batch_size):
            # the MPIPoolExecutor will always submit the same input to every worker, sometimes they arrive at slightly different time
            while len(self._requests) == 0:
                time.sleep(0.05)
            fetched.append(self._requests.pop(0))

        return fetched

    def handle_response_on_mpi_node(self, req_id: int,
                                    tensors: List[tllm.NamedTensor],
                                    finished: bool, err: str) -> None:
        """Worker-side response callback; only rank 0 reports back."""
        if mpi_rank() != 0:
            return

        tensor_dic = {t.name: t.tensor for t in tensors if t.tensor is not None}
        output = GenerationResult.process_generation(
            tensor_dic) if not err else err

        # Tensors are converted to plain lists before being sent; error text
        # is sent as-is.
        self.socket_client.send(
            dict(
                req_id=req_id,
                output=output if isinstance(output, str) else output.tolist(),
                finished=finished,
            ))

    def _async_listener_calllback(self, data: Dict[str, Any]):
        """Launcher-side callback: route a worker message to its result."""
        req_id = data['req_id']
        output = data['output']
        finished = data['finished']
        self._results[req_id].enqueue((output, finished))
        if finished:
            self._completed.put(req_id)

    @staticmethod
    @print_traceback_on_error
    def _node_quit_task():
        """Worker task: flag termination on rank 0, then stop the engine."""
        executor: GenerationExecutor = NodeSession.state
        assert isinstance(executor,
                          GenerationExecutor), 'executor not initialized'
        if mpi_rank() == 0:
            executor._terminated = True

        # NOTE(review): presumably gives the fetch callback time to broadcast
        # the termination flag before the engine stops - confirm.
        time.sleep(1)
        executor.engine.shutdown()
        NodeSession.state = None

    def _shutdown_mpi_nodes(self):
        """Run the quit task on every worker."""
        self.mpi_session.submit_sync(ParallelGenerationExecutor._node_quit_task)

    def shutdown(self):
        """Stop the workers and the MPI session (launching side only).

        Safe to call more than once and on a partially constructed instance.
        """
        # getattr keeps this safe when __init__ raised before the
        # attributes were assigned (shutdown is also reached from __del__).
        session = getattr(self, "mpi_session", None)
        if not getattr(self, "on_PMP", False) or session is None:
            return
        try:
            self._shutdown_mpi_nodes()
            session.shutdown()
        finally:
            # Clear the handle even on failure so __del__ does not retry.
            self.mpi_session = None

    def __del__(self):
        self.shutdown()