Skip to content

Commit

Permalink
E2e device cuda (#575)
Browse files Browse the repository at this point in the history
* use torch.cuda.current_device() instead of local_rank

* ignore NVML errors for gpu stats

* llama lora packing e2e tests
  • Loading branch information
winglian committed Sep 15, 2023
1 parent 9218ebe commit 2414673
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install -e .
pip3 install flash-attn
pip3 install -r requirements-tests.txt
- name: Run e2e tests
Expand Down
13 changes: 8 additions & 5 deletions src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pynvml
import torch
from pynvml.nvml import NVMLError


def gpu_memory_usage(device=0):
Expand All @@ -20,11 +21,13 @@ def gpu_memory_usage_smi(device=0):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
try:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
except NVMLError:
return 0.0


def log_gpu_memory_usage(log, msg, device):
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_device():
cfg.device_map = "auto"
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
cfg.device_map = {"": torch.cuda.current_device()}
else:
cfg.device_map = {"": cfg.device}

Expand Down
42 changes: 42 additions & 0 deletions tests/e2e/test_lora_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,45 @@ def test_lora(self):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

def test_lora_packing(self):
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

0 comments on commit 2414673

Please sign in to comment.