add accelerator config (#66)
Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin committed Jul 24, 2023
1 parent fd3c51a commit d9da840
10 changes: 7 additions & 3 deletions uniem/finetuner.py
@@ -2,7 +2,7 @@
 import os
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Iterable, Sequence, Sized, cast
+from typing import Any, Callable, Iterable, Sequence, Sized, cast
 
 import torch
 from accelerate import Accelerator
@@ -235,10 +235,12 @@ def run(
         drop_last: bool = False,
         shuffle: bool = False,
         num_workers: int = 0,
-        # Trainer
-        epochs: int = 3,
+        # Accelerator
         mixed_precision: MixedPrecisionType = MixedPrecisionType.no,
         gradient_accumulation_steps: int = 1,
+        accelerator_kwargs: dict[str, Any] | None = None,
+        # Trainer
+        epochs: int = 3,
         save_on_epoch_end: bool = False,
         num_max_checkpoints: int = 1,
         log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@@ -257,11 +259,13 @@ def run(
             automatic_checkpoint_naming=True,
             total_limit=num_max_checkpoints,
         )
+        accelerator_kwargs = accelerator_kwargs or {}
         accelerator = Accelerator(
             mixed_precision=mixed_precision.value,
             gradient_accumulation_steps=gradient_accumulation_steps,
             project_config=project_config,
             log_with=log_with,
+            **accelerator_kwargs,
         )
         self.accelerator = accelerator
         accelerator.init_trackers('uniem')
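For reference, a minimal usage sketch of the new parameter. The FineTuner.from_pretrained call and the toy dataset are assumptions modeled on the library's typical API, not part of this commit; accelerator_kwargs itself is forwarded verbatim to accelerate.Accelerator, so any Accelerator constructor option (for example split_batches) can be set without uniem wrapping it explicitly.

    from uniem.finetuner import FineTuner

    # Hypothetical toy dataset; any record format FineTuner accepts works here.
    dataset = [{'text': 'A man is eating food.', 'text_pos': 'A man eats something.'}]

    finetuner = FineTuner.from_pretrained('moka-ai/m3e-small', dataset=dataset)
    finetuner.run(
        epochs=3,
        # New in this commit: these kwargs are passed straight through to
        # accelerate.Accelerator(...).
        accelerator_kwargs={'split_batches': True},
    )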
