WIP: messy draft for llama #220

Closed

wants to merge 77 commits

Commits (77)

0b159f1
stop using the generation builder
thejaminator Apr 26, 2023
f928d60
fix partial issues
thejaminator Apr 26, 2023
203bf5b
test use accelerate
thejaminator Apr 26, 2023
8058a38
test use accelerate
thejaminator Apr 26, 2023
3c0178c
fix devices hopefully
thejaminator Apr 26, 2023
ec5a436
fix partial
thejaminator Apr 26, 2023
4035ac4
print the accelerate device
thejaminator Apr 26, 2023
b22c0dc
give up on accelerate
thejaminator Apr 26, 2023
02c2f11
add fsdp
thejaminator Apr 26, 2023
b384b8a
save changes
thejaminator Apr 26, 2023
bae6e1e
commit wroking
thejaminator Apr 27, 2023
9c0845d
refactor exception
thejaminator Apr 27, 2023
e5c132f
remove unneeded
thejaminator Apr 27, 2023
a8d161f
commit working dumb method
thejaminator Apr 27, 2023
c4a17f6
set format in main process
thejaminator Apr 27, 2023
8b5d0d3
clone tensor
thejaminator Apr 27, 2023
538b119
change fsdp to disable cpu_offload
thejaminator Apr 27, 2023
808b152
set format to torch in test
thejaminator Apr 27, 2023
8ce7326
fix closure bug not sharing memory
thejaminator Apr 27, 2023
15834e8
more logs
thejaminator Apr 27, 2023
99c1cb9
log the output sent back
thejaminator Apr 27, 2023
08f6fbc
more logs
thejaminator Apr 27, 2023
f509688
shift it back to float32
thejaminator Apr 27, 2023
439cf8f
print loaded closure
thejaminator Apr 27, 2023
ae4e052
add logging of sentinel
thejaminator Apr 27, 2023
6ab5e53
fix deadlock maybe?
thejaminator Apr 27, 2023
ae5eb25
add print for breaking
thejaminator Apr 27, 2023
0a0fc40
more prints
thejaminator Apr 27, 2023
6d7fa08
set low min mem for fsdp
thejaminator Apr 27, 2023
820388f
set low min mem for fsdp
thejaminator Apr 27, 2023
bfb8e12
add counter
thejaminator Apr 27, 2023
943ae48
stop destroying the process group
thejaminator Apr 27, 2023
f3aa91c
re log
thejaminator Apr 27, 2023
e813a64
replicate it by 2
thejaminator Apr 27, 2023
6649635
add assertions
thejaminator Apr 27, 2023
61cfce8
add type of exception
thejaminator Apr 27, 2023
44ec152
try increasing timeout
thejaminator Apr 27, 2023
d24ded8
try out not sending the sentinel
thejaminator Apr 27, 2023
7ffcbaf
fix typo
thejaminator Apr 27, 2023
eaf3f42
log more
thejaminator Apr 27, 2023
ea2e2ff
try waiting
thejaminator Apr 27, 2023
6fbacce
add sleep
thejaminator Apr 27, 2023
ea7694e
make it 5
thejaminator Apr 27, 2023
fdd854c
skip destroying group
thejaminator Apr 27, 2023
6b77cbb
try while true
thejaminator Apr 27, 2023
e5960ef
Revert "try while true"
thejaminator Apr 27, 2023
2e823ef
Revert "skip destroying group"
thejaminator Apr 27, 2023
659309b
Revert "make it 5"
thejaminator Apr 27, 2023
9eac600
Revert "add sleep"
thejaminator Apr 27, 2023
f5b8a53
Revert "try waiting"
thejaminator Apr 27, 2023
e644ec9
Revert "log more"
thejaminator Apr 27, 2023
322c9d8
Revert "fix typo"
thejaminator Apr 27, 2023
a2bae35
Revert "try out not sending the sentinel"
thejaminator Apr 27, 2023
34448dc
set num workeres to 8
thejaminator Apr 27, 2023
835562e
add commit
thejaminator Apr 27, 2023
6cf1c76
fsdp_single rename
thejaminator Apr 27, 2023
f81e6bf
more logs
thejaminator Apr 27, 2023
d50cda7
add range
thejaminator Apr 27, 2023
00de896
disable cpu offload
thejaminator Apr 27, 2023
155e91f
set min memory by dividng
thejaminator Apr 27, 2023
9b42104
add more logs for tests
thejaminator Apr 27, 2023
5accebb
rename tests
thejaminator Apr 27, 2023
4fe807d
use a sentinel class
thejaminator Apr 27, 2023
cfa854b
add log for FSDP
thejaminator Apr 27, 2023
b1af49e
save changes
thejaminator Apr 27, 2023
3804a84
add tol and better test
thejaminator Apr 27, 2023
4d20887
fix fsdp?? with imap??
thejaminator Apr 27, 2023
2cd4401
add assert for outputs
thejaminator Apr 27, 2023
72fc632
add comment
thejaminator Apr 27, 2023
1fca929
check it again
thejaminator Apr 27, 2023
6c9920d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2023
78c5847
fix imap
thejaminator Apr 27, 2023
192385b
fix assertion
thejaminator Apr 27, 2023
edd67de
try second
thejaminator Apr 27, 2023
fceb165
try second
thejaminator Apr 27, 2023
2772300
fix intiialization
thejaminator Apr 27, 2023
0fd5411
remove len hack
thejaminator Apr 27, 2023
fix devices hopefully
thejaminator committed Apr 26, 2023
commit 3c0178cd2e2da249797a1046288a50aa5c2c880e
36 changes: 21 additions & 15 deletions elk/extraction/extraction.py
@@ -10,43 +10,35 @@
from warnings import filterwarnings

import torch
import torch.multiprocessing as mp
from datasets import (
Array2D,
Array3D,
Dataset,
DatasetDict,
Features,
SplitDict,
SplitInfo,
Value,
get_dataset_config_info,
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from .dataset_name import (
DatasetDictWithName,
extract_dataset_name_and_config,
)
from .prompt_loading import PromptConfig, load_prompts
from ..multiprocessing import evaluate_with_processes
from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
colorize,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
from ..utils.data_utils import flatten_list
from .dataset_name import (
DatasetDictWithName,
extract_dataset_name_and_config,
)
from .prompt_loading import PromptConfig, load_prompts


@dataclass
@@ -103,6 +95,7 @@ def extract_hiddens_list(
*,
model: PreTrainedModel,
device: str | torch.device,
accelerate_device: torch.device | None,
split_type: Literal["train", "val"],
rank: int = 0,
world_size: int = 1,
@@ -113,6 +106,7 @@
cfg,
model=model,
device=device,
accelerate_device=accelerate_device,
split_type=split_type,
rank=rank,
world_size=world_size,
@@ -126,6 +120,7 @@ def extract_hiddens(
*,
model: PreTrainedModel,
device: str | torch.device,
accelerate_device: torch.device | None,
split_type: Literal["train", "val"],
rank: int = 0,
world_size: int = 1,
@@ -142,7 +137,16 @@
ds_names = p_cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

model = model.to(device)

if not accelerate_device:
# We need to move the model to another gpu
model = model.to(device)
else:
# But in the case of using accelerate, we won't
# move the model to another gpu, we will just
# make sure whatever tensors are created are
# on the correct device
device = accelerate_device

tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
@@ -358,6 +362,8 @@ def extract_hiddens_with_gpus(
model = instantiate_model(
cfg.model, torch_dtype="auto" if first_device != "cpu" else torch.float32
)
# get the device of the model incase we are using accelerate
accelerate_device: torch.device | None = model.device if use_accelerate else None
with ThreadPool(len(devices)) as pool:
for split_name in split_names:
thunks: list[Callable[[], list[dict]]] = []
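To make the intent of the new branch concrete, here is a minimal standalone sketch of the same pattern. It is not the PR's implementation: `pick_device`, the `use_accelerate` flag, the checkpoint name, and the `"cuda:0"` assignment are illustrative stand-ins; only the `model.device if use_accelerate else None` capture and the `model.to(device)` fallback mirror the diff above.

```python
# Minimal sketch (not the PR's code) of the device-selection pattern in this
# commit: keep the model where accelerate placed it, otherwise move it to the
# device assigned to this worker.
import torch
from transformers import AutoModelForCausalLM


def pick_device(
    model,
    assigned_device: str | torch.device,
    accelerate_device: torch.device | None,
) -> torch.device:
    """Return the device that new tensors should be created on."""
    if accelerate_device is None:
        # No accelerate: move the model to this worker's GPU ourselves.
        model.to(assigned_device)
        return torch.device(assigned_device)
    # Accelerate already placed the model; adopt its device so that inputs
    # land on the same GPU as the model's first parameters.
    return accelerate_device


use_accelerate = True
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder checkpoint, not taken from the PR
    device_map="auto" if use_accelerate else None,
    torch_dtype="auto",
)
# Mirrors `accelerate_device = model.device if use_accelerate else None` above.
accelerate_device = model.device if use_accelerate else None
device = pick_device(model, "cuda:0", accelerate_device)
inputs = torch.zeros(1, 8, dtype=torch.long, device=device)
```

The design choice being tested in this commit is simply to stop fighting accelerate's own placement: when `device_map="auto"` has sharded or placed the model, the worker only needs to know which device to create input tensors on, not to move the model again.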