Fix W&B callback for distributed training (#5223)
* fix wandb callback for distributed training

* fix

* close out

Co-authored-by: Dirk Groeneveld <[email protected]>
epwalsh and dirkgr committed May 26, 2021
1 parent 59df2ad commit 2d8f390
Showing 3 changed files with 25 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - When `PretrainedTransformerIndexer` folds long sequences, it no longer loses the information from token type ids.
 - Fixed documentation for `GradientDescentTrainer.cuda_device`.
+- Fixed `wandb` callback to work in distributed training.


 ## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22
2 changes: 1 addition & 1 deletion allennlp/training/callbacks/log_writer.py
@@ -148,7 +148,7 @@ def on_batch(
         batch_grad_norm: Optional[float] = None,
         **kwargs,
     ) -> None:
-        if not is_training and not is_primary:
+        if not is_training or not is_primary:
             return None
         assert self.trainer is not None

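The one-character change above is the distributed-training fix in `log_writer.py`: with the old `and`, a non-primary worker that was still training (`is_training=True`, `is_primary=False`) slipped past the early return and tried to write logs from every process; with `or`, batch logging happens only on the primary process while training. A minimal, standalone sketch of the corrected guard (the `should_log_batch` helper is hypothetical, used only to illustrate the boolean fix):

```python
# Hypothetical helper that mirrors the guard in on_batch(); not AllenNLP code.
def should_log_batch(is_training: bool, is_primary: bool) -> bool:
    # Old condition: `not is_training and not is_primary` only skipped logging
    # when a process was BOTH not training AND non-primary, so non-primary
    # workers still logged during training.
    # New condition: skip unless we are both training and the primary process.
    if not is_training or not is_primary:
        return False
    return True


if __name__ == "__main__":
    # The case the old guard let through: a non-primary worker mid-training.
    assert should_log_batch(is_training=True, is_primary=False) is False
    # The only combination that should log batch-level metrics.
    assert should_log_batch(is_training=True, is_primary=True) is True
```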
34 changes: 23 additions & 11 deletions allennlp/training/callbacks/wandb.py
@@ -88,11 +88,7 @@ def __init__(

         self._watch_model = watch_model
         self._files_to_save = files_to_save
-
-        import wandb
-
-        self.wandb = wandb
-        self.wandb.init(
+        self._wandb_kwargs: Dict[str, Any] = dict(
             dir=os.path.abspath(serialization_dir),
             project=project,
             entity=entity,
@@ -105,9 +101,6 @@
             **(wandb_kwargs or {}),
         )

-        for fpath in self._files_to_save:
-            self.wandb.save(os.path.join(serialization_dir, fpath), base_path=serialization_dir)
-
     @overrides
     def log_scalars(
         self,
@@ -122,7 +115,7 @@ def log_tensors(
         self, tensors: Dict[str, torch.Tensor], log_prefix: str = "", epoch: Optional[int] = None
     ) -> None:
         self._log(
-            {k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()},
+            {k: self.wandb.Histogram(v.cpu().data.numpy().flatten()) for k, v in tensors.items()},  # type: ignore
             log_prefix=log_prefix,
             epoch=epoch,
         )
@@ -134,12 +127,31 @@ def _log(
             dict_to_log = {f"{log_prefix}/{k}": v for k, v in dict_to_log.items()}
         if epoch is not None:
             dict_to_log["epoch"] = epoch
-        self.wandb.log(dict_to_log, step=self.trainer._batch_num_total)  # type: ignore[union-attr]
+        self.wandb.log(dict_to_log, step=self.trainer._batch_num_total)  # type: ignore

     @overrides
     def on_start(
         self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
     ) -> None:
         super().on_start(trainer, is_primary=is_primary, **kwargs)
+
+        if not is_primary:
+            return None
+
+        import wandb
+
+        self.wandb = wandb
+        self.wandb.init(**self._wandb_kwargs)
+
+        for fpath in self._files_to_save:
+            self.wandb.save(  # type: ignore
+                os.path.join(self.serialization_dir, fpath), base_path=self.serialization_dir
+            )
+
         if self._watch_model:
-            self.wandb.watch(self.trainer.model)  # type: ignore[union-attr]
+            self.wandb.watch(self.trainer.model)  # type: ignore
+
+    @overrides
+    def close(self) -> None:
+        super().close()
+        self.wandb.finish()  # type: ignore
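Taken together, the `wandb.py` changes move everything that touches the `wandb` client out of `__init__` (which runs in every worker when the callback is constructed) and into `on_start`, which now returns early on non-primary processes: the `wandb.init()` arguments are stashed in `self._wandb_kwargs`, only the primary process imports `wandb`, initializes the run, saves files, and watches the model, and a new `close()` hook calls `wandb.finish()` to end the run cleanly. A minimal sketch of the same pattern outside AllenNLP, assuming a generic callback shape (the `DistributedSafeWandbLogger` class and its hook names are illustrative; only the `wandb.init`/`watch`/`log`/`finish`/`save` calls are the real W&B API):

```python
# Sketch of primary-only, deferred W&B initialization; not AllenNLP's API.
import os
from typing import Any, Dict, Optional


class DistributedSafeWandbLogger:
    def __init__(self, serialization_dir: str, project: Optional[str] = None, **wandb_kwargs: Any) -> None:
        # Only remember the arguments here; do not touch wandb yet, because
        # __init__ may run in every distributed worker.
        self.serialization_dir = serialization_dir
        self._wandb_kwargs: Dict[str, Any] = dict(
            dir=os.path.abspath(serialization_dir),
            project=project,
            **wandb_kwargs,
        )
        self.wandb = None

    def on_start(self, model: Any, is_primary: bool = True) -> None:
        if not is_primary:
            return  # non-primary workers never create a W&B run
        import wandb

        self.wandb = wandb
        self.wandb.init(**self._wandb_kwargs)
        self.wandb.watch(model)

    def log(self, metrics: Dict[str, float], step: int) -> None:
        if self.wandb is None:  # covers non-primary workers automatically
            return
        self.wandb.log(metrics, step=step)

    def close(self) -> None:
        if self.wandb is not None:
            self.wandb.finish()
```

With this shape, constructing the logger in every worker is harmless: only the primary rank ever holds a live `wandb` module, so a single run is created per training job.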
