forked from ModelTC/lightllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
quick_launch_docker.py
executable file
·86 lines (80 loc) · 2.02 KB
/
quick_launch_docker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python3
import argparse
import os
import shlex
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the docker quick-launch script.

    Options are split into two groups: ``container`` (how the docker
    container itself is created) and ``server`` (what is passed to the
    lightllm API server running inside it).
    """
    parser = argparse.ArgumentParser()

    group_container = parser.add_argument_group("container")
    group_container.add_argument(
        "--image",
        type=str,
        default="ghcr.io/modeltc/lightllm:main",
        help="default to ghcr.io/modeltc/lightllm:main",
    )
    group_container.add_argument(
        "--name", type=str, required=False, help="set a name to the container"
    )
    group_container.add_argument(
        "--keep-container",
        "-K",
        action="store_true",
        help="default not to keep the container",
    )
    group_container.add_argument(
        "--shm-size",
        type=str,
        required=False,
        help="default to half of the RAM size",
    )

    group_server = parser.add_argument_group("server")
    group_server.add_argument(
        "-m", "--model", type=str, required=True, help="path to model dir"
    )
    group_server.add_argument("-p", "--port", type=int, default=8080)
    group_server.add_argument(
        "-n", "--num-proc", type=int, default=1, help="number of process/gpus"
    )
    group_server.add_argument(
        "-mt",
        "--max-total-tokens",
        type=int,
        default=4096,
        help="forwarded to the server as --max_total_token_num",
    )
    return parser


def default_shm_size() -> int:
    """Return half of the host's physical memory, in bytes."""
    return os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") // 2


def build_launch_args(args: argparse.Namespace) -> list:
    """Translate parsed options into the ``docker run`` argument vector.

    Args:
        args: Namespace produced by the parser from ``build_parser()``.

    Returns:
        A list of strings, starting with ``"docker"``, suitable for
        ``os.execvp``. Docker options come first, then the image, then the
        command executed inside the container.
    """
    model_path = os.path.abspath(args.model)
    # Only query the host memory size when the user gave no explicit value.
    shm_size = args.shm_size if args.shm_size else default_shm_size()

    launch_args = [
        "docker",
        "run",
        "-it",
        "--gpus",
        "all",
        "-p",
        f"{args.port}:{args.port}",
        "-v",
        # Mount the model at the same path inside the container so that
        # --model_dir below can reuse the host path unchanged.
        f"{model_path}:{model_path}",
        "--shm-size",
        shm_size,
    ]
    if args.name:
        launch_args.extend(["--name", args.name])
    if not args.keep_container:
        launch_args.append("--rm")

    # Everything after the image name is the command run in the container.
    launch_args.append(args.image)
    launch_args.extend(
        [
            "/bin/bash",
            "/lightllm/tools/resolve_ptx_version",
            "python",
            "-m",
            "lightllm.server.api_server",
            "--model_dir",
            model_path,
            "--host",
            "0.0.0.0",
            "--port",
            args.port,
            "--tp",
            args.num_proc,
            # Previously parsed but never forwarded, so the option had no
            # effect on the server.
            "--max_total_token_num",
            args.max_total_tokens,
        ]
    )
    # exec* requires every argument to be a string (ports etc. are ints).
    return [str(arg) for arg in launch_args]


def main(argv=None) -> None:
    """Parse the command line and replace this process with ``docker run``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    launch_args = build_launch_args(args)

    # Quote each argument so the echoed command stays unambiguous (and
    # copy-pasteable) when a path or name contains spaces.
    command_line = " ".join(shlex.quote(arg) for arg in launch_args)
    # exec* does not flush Python's buffers; without flush=True this line is
    # lost whenever stdout is piped or redirected.
    print(f"launching: {command_line}", flush=True)
    os.execvp(launch_args[0], launch_args)


if __name__ == "__main__":
    main()