Skip to content

Commit

Permalink
chore: pre-commit linting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 16, 2023
1 parent 6237305 commit 235ec15
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/axolotl/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,15 @@ def reduce_and_broadcast(fn1, fn2):
if not is_distributed():
return fn2([fn1()])

gathered_values = gather_scalar_from_all_ranks(fn1, world_size=dist.get_world_size())
gathered_values = gather_scalar_from_all_ranks(
fn1, world_size=dist.get_world_size()
)

# Use compute_and_broadcast to compute the reduced value on the main process
# and then broadcast it to all ranks
return compute_and_broadcast(lambda: fn2(gathered_values))


def broadcast_dict(vals: dict):
if not is_distributed():
return vals
Expand Down
7 changes: 6 additions & 1 deletion src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_distributed, is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.distributed import (
is_distributed,
is_main_process,
reduce_and_broadcast,
zero_first,
)
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

LOG = logging.getLogger("axolotl")
Expand Down

0 comments on commit 235ec15

Please sign in to comment.