2024 Release #96

Merged
82 commits merged on Jan 18, 2024
Changes shown from 1 commit of the 82 commits
1543548
Fix tokenizer, dropout, bias for LoRA
danielhanchen Jan 6, 2024
14bee60
Update loader.py
danielhanchen Jan 6, 2024
82db08e
Merge branch 'main' into nightly
danielhanchen Jan 7, 2024
d78643d
Fix LoRA downcasting
danielhanchen Jan 7, 2024
b6a9841
Update _utils.py
danielhanchen Jan 7, 2024
1d0e1dc
Merge branch 'main' into nightly
danielhanchen Jan 8, 2024
3c41946
Saving to GGUF
danielhanchen Jan 8, 2024
25b15bc
fix
danielhanchen Jan 8, 2024
4bd8f7b
colab_quantize_to_gguf
danielhanchen Jan 8, 2024
b3e0e0f
move save modules
danielhanchen Jan 8, 2024
bb93e31
save module
danielhanchen Jan 8, 2024
29aaf66
Update __init__.py
danielhanchen Jan 8, 2024
605290f
Update save.py
danielhanchen Jan 8, 2024
4937926
Merge branch 'main' into nightly
danielhanchen Jan 9, 2024
f342605
Temp downgrade due to TRL issue
danielhanchen Jan 9, 2024
b332a8c
Merge branch 'main' into nightly
danielhanchen Jan 9, 2024
5dca87c
Merge branch 'main' into nightly
danielhanchen Jan 10, 2024
38f7c04
Fix up bugs
danielhanchen Jan 10, 2024
74037bb
Merge branch 'main' into nightly
danielhanchen Jan 11, 2024
7af483c
Faster saving + other changes
danielhanchen Jan 16, 2024
5ad792a
Update llama.py
danielhanchen Jan 16, 2024
2fc754f
Saving modules
danielhanchen Jan 17, 2024
48461a6
spelling
danielhanchen Jan 17, 2024
ed3e1db
Update llama.py
danielhanchen Jan 17, 2024
07202ff
Update save.py
danielhanchen Jan 17, 2024
7c1c87f
Update save.py
danielhanchen Jan 17, 2024
44430d1
Update loader.py
danielhanchen Jan 17, 2024
2ceed5d
Update llama.py
danielhanchen Jan 17, 2024
e932893
patch saving
danielhanchen Jan 17, 2024
08227b1
Update save.py
danielhanchen Jan 17, 2024
9cbf93d
Update save.py
danielhanchen Jan 17, 2024
16b6b6c
Update save.py
danielhanchen Jan 17, 2024
8610a17
patch saving
danielhanchen Jan 17, 2024
14a2185
Update save.py
danielhanchen Jan 17, 2024
6dfbff5
Update save.py
danielhanchen Jan 17, 2024
12740f5
Update save.py
danielhanchen Jan 17, 2024
cfd7f4c
Update save.py
danielhanchen Jan 17, 2024
4fc5743
Update save.py
danielhanchen Jan 17, 2024
eeda38d
Update save.py
danielhanchen Jan 17, 2024
3cf6b09
Update save.py
danielhanchen Jan 17, 2024
5dd4037
Update save.py
danielhanchen Jan 17, 2024
ef80889
Update save.py
danielhanchen Jan 17, 2024
7d6a5f6
Update save.py
danielhanchen Jan 17, 2024
4bbd370
Update save.py
danielhanchen Jan 17, 2024
0d9ea04
Update save.py
danielhanchen Jan 17, 2024
3a42afa
Update save.py
danielhanchen Jan 17, 2024
1e70ce1
Update save.py
danielhanchen Jan 17, 2024
ac04280
Update save.py
danielhanchen Jan 17, 2024
0a7e12e
original_model
danielhanchen Jan 18, 2024
d746560
Update save.py
danielhanchen Jan 18, 2024
3b0f92e
Update save.py
danielhanchen Jan 18, 2024
a6f1a29
Update save.py
danielhanchen Jan 18, 2024
4ef3fad
Update save.py
danielhanchen Jan 18, 2024
eab82e4
Update save.py
danielhanchen Jan 18, 2024
3801dc7
Update save.py
danielhanchen Jan 18, 2024
fdfd769
Update save.py
danielhanchen Jan 18, 2024
f2884c1
Update save.py
danielhanchen Jan 18, 2024
cc0ff3f
Update save.py
danielhanchen Jan 18, 2024
c549550
Update save.py
danielhanchen Jan 18, 2024
f755247
Update save.py
danielhanchen Jan 18, 2024
2fced37
Update save.py
danielhanchen Jan 18, 2024
ee7154f
Update save.py
danielhanchen Jan 18, 2024
5aaffb7
Update save.py
danielhanchen Jan 18, 2024
61e4c1c
Update save.py
danielhanchen Jan 18, 2024
9f11efb
Update save.py
danielhanchen Jan 18, 2024
43a025a
Update save.py
danielhanchen Jan 18, 2024
b90c385
Update save.py
danielhanchen Jan 18, 2024
573427b
Update save.py
danielhanchen Jan 18, 2024
2672abf
Update save.py
danielhanchen Jan 18, 2024
68670c6
Update save.py
danielhanchen Jan 18, 2024
2231bdd
Update save.py
danielhanchen Jan 18, 2024
074e79a
Update save.py
danielhanchen Jan 18, 2024
90a2f9b
saving to RAM leakage?
danielhanchen Jan 18, 2024
8e140af
Update save.py
danielhanchen Jan 18, 2024
d3b4a24
new_save_directory
danielhanchen Jan 18, 2024
f535904
Update save.py
danielhanchen Jan 18, 2024
386655e
Update save.py
danielhanchen Jan 18, 2024
e5e91f4
Update save.py
danielhanchen Jan 18, 2024
4a91ede
Update save.py
danielhanchen Jan 18, 2024
c53dee9
Update pyproject.toml
danielhanchen Jan 18, 2024
7196ef0
Update pyproject.toml
danielhanchen Jan 18, 2024
b2cebb8
Update pyproject.toml
danielhanchen Jan 18, 2024
Faster saving + other changes
danielhanchen committed Jan 16, 2024
commit 7af483c394cd7d5b068493513270fc59a055417a
16 changes: 13 additions & 3 deletions pyproject.toml
@@ -37,10 +37,10 @@ huggingface = [
"datasets",
"sentencepiece",
"accelerate",
"trl",
"trl>=0.7.9",
"peft",
"packaging",
"ninja",
"tqdm",
"psutil",
]
cu118only = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
@@ -93,30 +93,40 @@ colab = [
]
colab_ampere = [
"unsloth[cu121]",
"packaging",
"ninja",
"flash-attn",
]
cu118_ampere = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118only]",
"packaging",
"ninja",
"flash-attn",
]
cu121_ampere = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121only]",
"packaging",
"ninja",
"flash-attn",
]
cu118_ampere_torch211 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu118only_torch211]",
"packaging",
"ninja",
"flash-attn",
]
cu121_ampere_torch211 = [
"unsloth[huggingface]",
"bitsandbytes",
"unsloth[cu121only_torch211]",
"packaging",
"ninja",
"flash-attn",
]

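
The pyproject.toml hunks above raise the TRL floor to trl>=0.7.9 (after the temporary downgrade earlier in this PR) and add packaging, ninja, and flash-attn to the Ampere extras. As a hypothetical sanity check that an environment satisfies the new pin (only the package names and the 0.7.9 bound come from the diff; the check itself is illustrative):

from importlib.metadata import version
from packaging.version import Version  # "packaging" is itself one of the listed dependencies

# Verify the installed trl meets the new ">=0.7.9" lower bound from pyproject.toml
installed = Version(version("trl"))
assert installed >= Version("0.7.9"), f"trl {installed} is older than the required 0.7.9"
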
5 changes: 3 additions & 2 deletions unsloth/kernels/rms_layernorm.py
@@ -41,12 +41,13 @@ def _rms_layernorm_forward(
r += row_idx * r_row_stride

X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)

row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
inv_var = 1 / tl.sqrt(row_var + eps)
inv_var = 1.0 / tl.sqrt(row_var + eps)
tl.store(r, inv_var)
normed = X_row * inv_var
normed = normed.to(W_row.dtype) # Exact copy from HF
output = normed * W_row
tl.store(Y + col_offsets, output, mask = mask)
pass
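
For reference, the kernel change above now loads W in its original dtype, keeps float32 only for the statistics on X, and casts the normalized row back to W's dtype before scaling, matching the "Exact copy from HF" comment. A minimal PyTorch sketch of that dtype ordering, with illustrative names:

import torch

def rms_layernorm_reference(X: torch.Tensor, W: torch.Tensor, eps: float) -> torch.Tensor:
    # Statistics in float32; the weight stays in its original (e.g. bfloat16) dtype
    X_fp32 = X.to(torch.float32)
    inv_var = torch.rsqrt(X_fp32.square().mean(-1, keepdim = True) + eps)
    normed = (X_fp32 * inv_var).to(W.dtype)  # cast back before multiplying by the weight
    return normed * W
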
12 changes: 7 additions & 5 deletions unsloth/kernels/swiglu.py
@@ -25,10 +25,11 @@ def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
mask = offsets < n_elements

e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

# f = e * sigmoid(e)
f_row = e_row / (1 + tl.exp(-e_row))
f_row = f_row.to(g_row.dtype) # Exact copy from HF
# h = f * g
h_row = f_row * g_row

@@ -53,12 +54,13 @@ def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements

DW_row = tl.load(DW + offsets, mask = mask, other = 0).to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0).to(tl.float32)
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
e_row = tl.load(e + offsets, mask = mask, other = 0)#.to(tl.float32)
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

# f = e * sigmoid(e)
se_row = 1 / (1 + tl.exp(-e_row))
se_row = 1 / (1 + tl.exp(-e_row.to(tl.float32)))
se_row = se_row.to(e_row.dtype) # Exact copy from HF
# f = e * se
f_row = e_row * se_row
# h = f * g
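
The swiglu.py changes follow the same pattern: the updated _fg_kernel keeps g in its original dtype, evaluates f = e * sigmoid(e) in float32, and casts f back to g's dtype before the product. A rough PyTorch equivalent of that forward path, with illustrative names:

import torch
import torch.nn.functional as F

def swiglu_forward_reference(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # f = e * sigmoid(e) computed in float32, then cast to g's dtype for the elementwise product
    f = F.silu(e.to(torch.float32))
    return f.to(g.dtype) * g
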
110 changes: 86 additions & 24 deletions unsloth/models/llama.py
@@ -56,6 +56,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from transformers import set_seed as transformers_set_seed
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
from peft import PeftModel, PeftConfig


def original_apply_qkv(self, X):
@@ -156,6 +157,28 @@ def LlamaAttention_fast_forward_inference(
pass


torch_silu = torch.nn.functional.silu
def fast_mlp_inference(self, X):
gate = self.gate_proj(X)
up = self.up_proj(X)
gate = torch_silu(gate, inplace = True)
gate *= up
X = self.down_proj(gate)
return X
pass


def fast_rms_layernorm_inference(self, X):
old_dtype = X.dtype
X = X.to(torch.float32)
variance = X.square().mean(-1, keepdim = True)
variance += self.variance_epsilon
X *= variance.rsqrt_()
X = X.to(old_dtype)
X *= self.weight
return X
pass


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L320
def LlamaAttention_fast_forward(
self,
@@ -287,28 +310,51 @@ def LlamaDecoderLayer_fast_forward(
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states

hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
bsz, q_len, hd = hidden_states.size()

if (self.training):
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
else:
# Self Attention
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states += residual

# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states)
hidden_states = fast_mlp_inference(self.mlp, hidden_states)
hidden_states += residual
pass

outputs = (hidden_states,)

@@ -414,8 +460,7 @@ def LlamaModel_fast_forward(
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window = None if not hasattr(self.config, "sliding_window") else \
self.config.sliding_window,
sliding_window = getattr(self.config, "sliding_window", None),
)
pass

@@ -479,7 +524,11 @@ def custom_forward(*inputs):
all_self_attns += (layer_outputs[1],)
pass

hidden_states = fast_rms_layernorm(self.norm, hidden_states)
if (self.training):
hidden_states = fast_rms_layernorm(self.norm, hidden_states)
else:
hidden_states = fast_rms_layernorm_inference(self.norm, hidden_states)
pass

# add hidden states from the last decoder layer
if output_hidden_states:
@@ -513,7 +562,7 @@ def LlamaForCausalLM_fast_forward(
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:

if causal_mask is None:
if self.training and causal_mask is None:
causal_mask = xformers.attn_bias.LowerTriangularMask()

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -665,6 +714,7 @@ def from_pretrained(
bnb_4bit_quant_type = "nf4",
bnb_4bit_compute_dtype = dtype,
)
pass

# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/12
# RoPE Scaling's max_position_embeddings must be updated
@@ -721,6 +771,7 @@ def from_pretrained(
name = name[:len(name) - len("-bnb-4bit")]
model.config.update({"_name_or_path" : name})
pass

# Log Unsloth version for future fastpaths for inference
model.config.update({"unsloth_version" : __version__})

@@ -828,6 +879,17 @@ def get_peft_model(
)
model = _get_peft_model(model, lora_config)


# Fix up config for transformers uploading PEFT
name = model.peft_config["default"].base_model_name_or_path
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
name = name[:len(name) - len("-bnb-4bit")]
model.peft_config["default"].base_model_name_or_path = name
pass
# Add revision to enable future fast inference paths
model.peft_config["default"].revision = f"unsloth"


# Do patching
n_mlp = 0
n_qkv = 0
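
Putting the llama.py additions together, the non-training branch of LlamaDecoderLayer_fast_forward is a pre-norm block with float32 RMSNorm statistics, an eager SwiGLU MLP, and in-place residual adds. The sketch below is a simplified reading of that path rather than the patched implementation: the attention call is collapsed to a single hidden-states argument, and the layer is assumed to expose the usual Hugging Face LlamaDecoderLayer submodules.

import torch
import torch.nn.functional as F

@torch.inference_mode()
def decoder_layer_inference_sketch(layer, hidden_states):
    def rms_norm(norm_module, X):
        # Mirrors fast_rms_layernorm_inference: float32 statistics, cast back, scale by the weight
        old_dtype = X.dtype
        X = X.to(torch.float32)
        X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + norm_module.variance_epsilon)
        return X.to(old_dtype) * norm_module.weight

    # Self-attention block with an in-place residual add
    residual = hidden_states
    hidden_states = rms_norm(layer.input_layernorm, hidden_states)
    hidden_states = layer.self_attn(hidden_states = hidden_states)[0]
    hidden_states += residual

    # MLP block: down_proj(silu(gate_proj(x)) * up_proj(x)), as in fast_mlp_inference
    residual = hidden_states
    hidden_states = rms_norm(layer.post_attention_layernorm, hidden_states)
    mlp = layer.mlp
    hidden_states = mlp.down_proj(F.silu(mlp.gate_proj(hidden_states)) * mlp.up_proj(hidden_states))
    hidden_states += residual
    return hidden_states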
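
The get_peft_model hunk also rewrites the adapter's base_model_name_or_path so an uploaded LoRA points at the original, non-pre-quantized checkpoint, and sets the revision to "unsloth". The string handling in isolation, shown with a hypothetical model name:

def strip_bnb_suffix(name: str) -> str:
    # "unsloth/<model>-bnb-4bit" -> "unsloth/<model>"; any other name passes through unchanged
    if name.startswith("unsloth/") and name.endswith("-bnb-4bit"):
        name = name[: len(name) - len("-bnb-4bit")]
    return name

print(strip_bnb_suffix("unsloth/mistral-7b-bnb-4bit"))  # hypothetical name -> "unsloth/mistral-7b"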