Skip to content

Commit

Permalink
Support multi-GPU inference for the server.
Browse files Browse the repository at this point in the history
  • Loading branch information
fengyh3 committed Apr 21, 2023
1 parent dd5193c commit 47107a2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llama_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
app = Flask(__name__)
args = None
lm_generation = None
torch.cuda.set_device(0)


def init_model():
Expand Down Expand Up @@ -65,6 +64,7 @@ def init_model():
gpus = ["cuda:" + str(i) for i in range(args.world_size)]
model = tp.tensor_parallel(model, gpus)
else:
torch.cuda.set_device(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Expand Down

0 comments on commit 47107a2

Please sign in to comment.