Skip to content

Commit

Permalink
Merge pull request eric-mitchell#22 from eric-mitchell/fsdp_port_disc…
Browse files Browse the repository at this point in the history
…overy

Updated launch logic to auto-discover a free FSDP port if none is specified
  • Loading branch information
eric-mitchell committed Jul 23, 2023
2 parents b4c2ccd + 1ea1c06 commit 49066e6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ eval_batch_size: 16
debug: false

# the port to use for FSDP; if null, a free port is chosen automatically at launch
fsdp_port: 12355
fsdp_port: null

# which dataset(s) to train on; can pass a list like datasets=[hh,shp]
datasets:
Expand Down
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
import transformers
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed
from utils import get_local_dir, get_local_run_dir, disable_dropout, init_distributed, get_open_port
import os
import hydra
import torch.multiprocessing as mp
Expand Down Expand Up @@ -61,6 +61,11 @@ def main(config: DictConfig):
print('Setting eval_every to', config.eval_every - config.eval_every % config.batch_size)
config.eval_every = config.eval_every - config.eval_every % config.batch_size

if 'FSDP' in config.trainer and config.fsdp_port is None:
free_port = get_open_port()
print('no FSDP port specified; using open port for FSDP:', free_port)
config.fsdp_port = free_port

print(OmegaConf.to_yaml(config))

config_path = os.path.join(config.local_run_dir, 'config.yaml')
Expand Down
6 changes: 6 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from typing import Dict, Union, Type, List


def get_open_port() -> int:
    """Return the number of a TCP port that the OS reports as currently free.

    A temporary IPv4 TCP socket is bound to port 0, which asks the operating
    system to pick an unused port. The chosen port number is read back and the
    socket is closed before returning.

    Returns:
        The port number the OS assigned to the temporary socket.

    Note:
        The socket is closed when this function returns, so the port is only
        known to have been free at the time of the call; another process could
        claim it before the caller binds to it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # bind to all interfaces and use an OS provided port
        return s.getsockname()[1]  # return only the port number


def get_remote_file(remote_path, local_path=None):
hostname, path = remote_path.split(':')
local_hostname = socket.gethostname()
Expand Down

0 comments on commit 49066e6

Please sign in to comment.