Skip to content

Commit

Permalink
Allow setting max_memory
Browse files Browse the repository at this point in the history
  • Loading branch information
yukw777 committed Apr 5, 2023
1 parent b9b9a6c commit 42a341d
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.

import logging
import json
from dataclasses import dataclass, field
from typing import Union
from typing import Union, Optional
from functools import partial

import torch
Expand Down Expand Up @@ -43,6 +44,7 @@ class ModelArguments:
model_name_or_path: str
train_in_8bit: bool = field(default=False)
device_map: Union[None, str, dict[str, Union[int, str, torch.device]]] = field(default=None)
max_memory_config: Optional[str] = field(default=None)


@dataclass
Expand Down Expand Up @@ -113,6 +115,9 @@ def train() -> None:
model_args.model_name_or_path,
load_in_8bit=model_args.train_in_8bit,
device_map="auto" if model_args.train_in_8bit and model_args.device_map is None else model_args.device_map,
max_memory={int(k): v for k, v in json.loads(model_args.max_memory_config).items() if k.isnumeric()}
if model_args.max_memory_config is not None
else None,
)
if model_args.train_in_8bit:
logging.warning("Preparing 8bit training")
Expand Down

0 comments on commit 42a341d

Please sign in to comment.