Skip to content

Commit

Permalink
Added errors for download failure and also a --force-redownload argument (#820)
Browse files Browse the repository at this point in the history

* Added errors for download failure and also a --force-redownload argument

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
  • Loading branch information
3 people authored Mar 14, 2023
1 parent 2da9a73 commit dd5a53d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
14 changes: 4 additions & 10 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = cbed1b5
Default = 09fc35d

current git hash of repository

Expand Down Expand Up @@ -1830,7 +1830,9 @@ Args for deepspeed runner (deepspeed.launcher.runner).
Default = None

Either "tune", "run", or `None`.




- **no_ssh_check**: bool

Default = False
Expand All @@ -1845,11 +1847,3 @@ Args for deepspeed runner (deepspeed.launcher.runner).

Adds a `--comment` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometimes necessary for cluster rules, or so I've heard.



- **no_ssh_check**: bool

Default = False

If `True` and running with multiple nodes, then DeepSpeed doesn't conduct a check to ensure the head node is reachable with ssh.

8 changes: 8 additions & 0 deletions prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def get_args():
parser.add_argument(
"-m", "--merge-file", default=None, help=f"Tokenizer merge file (if required)"
)
parser.add_argument(
"-f",
"--force-redownload",
dest="force_redownload",
default=False,
action="store_true",
)
return parser.parse_args()


Expand All @@ -65,4 +72,5 @@ def get_args():
data_dir=args.data_dir,
vocab_file=args.vocab_file,
merge_file=args.merge_file,
force_redownload=args.force_redownload,
)
25 changes: 20 additions & 5 deletions tools/corpora.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
merge_file=None,
vocab_file=None,
data_dir=None,
force_redownload=None,
num_workers=None,
):
if tokenizer_type is None:
Expand All @@ -49,6 +50,8 @@ def __init__(
data_dir = os.environ.get("DATA_DIR", "./data")
if merge_file is None:
merge_file = f"{data_dir}/gpt2-merges.txt"
if force_redownload is None:
force_redownload = False
if vocab_file is None:
if tokenizer_type == "GPT2BPETokenizer":
vocab_file = f"{data_dir}/gpt2-vocab.json"
Expand All @@ -64,6 +67,7 @@ def __init__(
self._merge_file = merge_file
self._vocab_file = vocab_file
self._data_dir = data_dir
self._force_redownload = force_redownload
self._num_workers = num_workers

@property
Expand Down Expand Up @@ -121,9 +125,14 @@ def download(self):
"""downloads dataset"""
os.makedirs(os.path.join(self.base_dir, self.name), exist_ok=True)
for url in self.urls:
os.system(
f"wget {url} -O {os.path.join(self.base_dir, self.name, os.path.basename(url))}"
)
try:
os_cmd = f"wget {url} -O {os.path.join(self.base_dir, self.name, os.path.basename(url))}"
if os.system(os_cmd) != 0:
raise Exception(
f"Cannot download file at URL {url}: server may be down"
)
except Exception as e:
raise Exception(f"Download error: {e}")

def tokenize(self):
"""tokenizes dataset"""
Expand Down Expand Up @@ -151,9 +160,13 @@ def tokenize(self):
os.system(cmd)

def prepare(self):
    """Ensure the raw dataset is on disk, then tokenize it.

    If ``self._force_redownload`` is truthy the data is downloaded
    unconditionally; otherwise it is downloaded only when
    ``self.exists()`` reports it missing. Tokenization always runs
    afterwards.

    NOTE(review): this span was diff residue in SOURCE (old and new
    hunk lines interleaved); this body implements the post-commit
    behavior: force-redownload bypasses the existence check.
    """
    # Re-fetch when forced, or when the data has never been downloaded.
    if self._force_redownload or not self.exists():
        self.download()
    self.tokenize()


class Enron(DataDownloader):
Expand Down Expand Up @@ -325,6 +338,7 @@ def prepare_dataset(
data_dir: str = None,
vocab_file: str = None,
merge_file: str = None,
force_redownload: bool = None,
num_workers: int = None,
):
"""
Expand All @@ -349,6 +363,7 @@ def prepare_dataset(
vocab_file=vocab_file,
merge_file=merge_file,
data_dir=data_dir,
force_redownload=force_redownload,
num_workers=num_workers,
)
d.prepare()

0 comments on commit dd5a53d

Please sign in to comment.