quantized llama-pro
KeremTurgutlu committed Apr 16, 2024
1 parent 9655e8a commit bae3681
Showing 4 changed files with 186 additions and 17 deletions.
38 changes: 38 additions & 0 deletions README.md
@@ -133,6 +133,44 @@ LoRA fine-tuning using a custom LoRA module.
+ --train_type hqq_dora \
```

### `--train_type bnb_llama_pro`

4-bit quantized Llama-Pro fine-tuning using the bitsandbytes Linear4bit layer with NF4 quantization.

To create the Llama-Pro expanded weights, run the following command:

```bash
python scripts/block_expansion.py \
--model_name meta-llama/Llama-2-7b-hf \
--output_dir /path/to/llama_pro_weights_directory \
--expansion_rate 0.1
```

```diff
- --train_type full \
+ --train_type bnb_llama_pro \
+ --llama_pro_path /path/to/llama_pro_weights_directory \
```

### `--train_type hqq_llama_pro`

4-bit quantized Llama-Pro fine-tuning using the HQQ library.

To create the Llama-Pro expanded weights, run the following command:

```bash
python scripts/block_expansion.py \
--model_name meta-llama/Llama-2-7b-hf \
--output_dir /path/to/llama_pro_weights_directory \
--expansion_rate 0.1
```

```diff
- --train_type full \
+ --train_type hqq_llama_pro \
+ --llama_pro_path /path/to/llama_pro_weights_directory \
```
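
Note that `block_expansion.py` saves the expanded shards into a subdirectory whose name records the original and expanded layer counts (with the defaults above: `/path/to/llama_pro_weights_directory/meta-llama/Llama-2-7b-hf_blk_exp-32-35`). `--llama_pro_path` should point at that `_blk_exp-*` directory, since the training script parses the layer counts back out of its name.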

## Low Memory Loading

During quantized LoRA training we use custom quantization and loading code to avoid loading the entire model into GPU memory before sharding it across GPUs. This is the default behavior of our training script when any of the training options `"qlora", "custom_qlora", "hqq_lora"` is used. Other training options are already optimized for low-memory loading to the extent possible.
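
The gist of this path, as a minimal sketch rather than the actual training-script code: the model skeleton is created on the meta device and the safetensors shards are streamed in one at a time, so only a single shard's tensors sit in CPU RAM before being quantized and placed. `load_low_memory` and `place_param` below are hypothetical illustrations of the pattern; in the repository the equivalent work is done by the `load_and_quantize` helper used in `train.py`.

```python
import safetensors.torch
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM


def load_low_memory(model_name, shard_files):
    cfg = AutoConfig.from_pretrained(model_name)
    with init_empty_weights():
        # Build the model skeleton on the meta device: no weight memory is allocated yet.
        model = AutoModelForCausalLM.from_config(cfg)
    for shard in shard_files:
        # Only one safetensors shard is materialized in CPU RAM at a time.
        weights = safetensors.torch.load_file(shard)
        for name, param in weights.items():
            place_param(model, name, param)
        del weights  # free the shard before moving on to the next one
    return model


def place_param(model, name, param):
    # Hypothetical stand-in: the real helper quantizes Params4bit/HQQ weights on the
    # fly and places them on CPU or the meta device depending on --low_memory.
    module_path, _, attr = name.rpartition(".")
    module = model.get_submodule(module_path)
    setattr(module, attr, torch.nn.Parameter(param, requires_grad=False))
```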
60 changes: 60 additions & 0 deletions scripts/block_expansion.py
@@ -0,0 +1,60 @@

import argparse
from transformers import AutoConfig
import torch
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
import safetensors
from safetensors.torch import save_file
import os
from pathlib import Path

def main():
    # Set up the argument parser
    parser = argparse.ArgumentParser(description="Expand a model's depth by interleaving new decoder blocks (Llama-Pro)")
    parser.add_argument("--model_name", default='meta-llama/Llama-2-7b-hf', type=str, help="original model path")
    parser.add_argument("--output_dir", default=None, type=str, help="directory to save the expanded model checkpoint")
    parser.add_argument("--expansion_rate", default=0.1, type=float, help="fraction of new layers to add, relative to the original depth")

    # Parse the arguments
    args = parser.parse_args()

    # Resolve the sharded safetensors files of the original checkpoint.
    idx = hub.cached_file(args.model_name, SAFE_WEIGHTS_INDEX_NAME)
    files, _ = hub.get_checkpoint_shard_files(args.model_name, idx)

    cfg = AutoConfig.from_pretrained(args.model_name)
    num_original_layers = cfg.num_hidden_layers
    new_layers = num_original_layers + int(num_original_layers * args.expansion_rate)

    # One new block is inserted after every `split` original blocks.
    split = int(num_original_layers / (new_layers - num_original_layers))

    if args.output_dir is None:
        output_dir = Path(os.environ['HOME'])/'models'/(args.model_name + f'_blk_exp-{num_original_layers}-{new_layers}')
    else:
        output_dir = Path(args.output_dir)/(args.model_name + f'_blk_exp-{num_original_layers}-{new_layers}')
    os.makedirs(output_dir, exist_ok=True)

    for filename in files:
        weights = safetensors.torch.load_file(filename)
        expanded_weights = {}
        for k, v in weights.items():
            if 'layers' in k:
                layer_no = int(k.split('layers.')[1].split('.')[0])
                # shift existing layers
                new_layer_no = layer_no + layer_no // split
                new_k = k.replace(f'layers.{layer_no}', f'layers.{new_layer_no}')
                expanded_weights[new_k] = v
                # add new layers
                if (layer_no + 1) % split == 0:
                    new_layer_no += 1
                    new_k = k.replace(f'layers.{layer_no}', f'layers.{new_layer_no}')
                    if 'down_proj' in k or 'o_proj' in k:
                        # Zero-init the output projections so the new block starts as an identity mapping.
                        expanded_weights[new_k] = torch.zeros_like(v)
                    else:
                        # Copy the remaining weights from the preceding block.
                        expanded_weights[new_k] = v.clone()
            else:
                # Non-layer weights (embeddings, final norm, lm_head) are kept as-is.
                expanded_weights[k] = v
        save_file(expanded_weights, output_dir/Path(filename).name)


if __name__ == "__main__":
    main()
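
For the default Llama-2-7B settings shown in the README (32 layers, `expansion_rate 0.1`), the script adds three blocks. The following snippet simply replays the script's index arithmetic to show where the new blocks land:

```python
# Layer-index bookkeeping from block_expansion.py, for 32 layers and expansion_rate=0.1.
num_original_layers = 32
new_layers = num_original_layers + int(num_original_layers * 0.1)      # 35
split = int(num_original_layers / (new_layers - num_original_layers))  # 10

# Each original layer i is shifted to index i + i // split ...
shifted = {i: i + i // split for i in range(num_original_layers)}
# ... and a new block (zero-initialized o_proj/down_proj, other weights copied
# from the previous block) is inserted after every `split`-th original layer.
inserted = [split + (split + 1) * n for n in range(new_layers - num_original_layers)]

print(inserted)  # [10, 21, 32] -- the layer ids train.py later marks as trainable
```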
44 changes: 44 additions & 0 deletions tests/test_block_expansion.py
@@ -0,0 +1,44 @@
import unittest, tempfile
import torch
import torch.nn as nn
import safetensors
from safetensors.torch import save_file
from pathlib import Path
from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from glob import glob

# python -m unittest tests.test_block_expansion.TestBlockExpansion.test_expanded_weights
class TestBlockExpansion(unittest.TestCase):

    def setUp(self) -> None:
        # Expanded checkpoint produced by scripts/block_expansion.py (32 -> 35 layers).
        self.llama_pro_path = Path("/mnt/vol_b/models/meta-llama/Llama-2-7b-hf_blk_exp-32-35")
        self.filenames = glob(str(self.llama_pro_path/"*.safetensors"))
        num_original_layers, num_expanded_layers = self.llama_pro_path.name.split("blk_exp-")[1].split("-")
        self.num_original_layers, self.num_expanded_layers = int(num_original_layers), int(num_expanded_layers)
        self.split = int(self.num_original_layers / (self.num_expanded_layers - self.num_original_layers))

    def tearDown(self) -> None:
        return super().tearDown()

    def test_expanded_weights(self):
        total_new_layers = self.num_expanded_layers - self.num_original_layers
        new_layer_ids = [self.split + (self.split + 1) * n for n in range(total_new_layers)]

        # Collect the weights of each newly inserted block and of the block just before it.
        verify_weights = {}
        for filename in self.filenames:
            weights = safetensors.torch.load_file(str(filename))
            for k, v in weights.items():
                if any((f"layers.{i}" in k) or (f"layers.{i-1}" in k) for i in new_layer_ids):
                    verify_weights[k] = v

        for k, v in verify_weights.items():
            if any((f"layers.{i}" in k) for i in new_layer_ids):
                if 'down_proj' in k or 'o_proj' in k:
                    # Output projections of the new blocks must be zero-initialized.
                    assert torch.equal(v, torch.zeros_like(v))
                else:
                    # Everything else must be an exact copy of the preceding block's weights.
                    lid = int(k.split("layers.")[1].split(".")[0])
                    assert torch.equal(verify_weights[k.replace(f"layers.{lid}", f"layers.{lid-1}")], v)

61 changes: 44 additions & 17 deletions train.py
@@ -555,16 +555,27 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict):
model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch_dtype)
if args["precision"] == "bf16":
model.to(torch_dtype)
elif args["train_type"] in ["qlora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora"]: # Our custom loading
elif args["train_type"] in ["qlora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]: # Our custom loading
cfg = AutoConfig.from_pretrained(args["model_name"])
cfg.use_cache = False
cfg._attn_implementation = attn_impl
skip_modules = ["lm_head"]

if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]:
llama_pro_path = Path(args["llama_pro_path"])
num_original_layers, num_expanded_layers = llama_pro_path.name.split("blk_exp-")[1].split("-")
num_original_layers, num_expanded_layers = int(num_original_layers), int(num_expanded_layers)
total_new_layers = num_expanded_layers - num_original_layers
split = int(num_original_layers / (num_expanded_layers - num_original_layers))
new_layer_ids = [split+(split+1)*n for n in range(total_new_layers)]
new_layer_names = [f"layers.{i}" for i in new_layer_ids]
skip_modules += [str(lid) for lid in new_layer_ids]
cfg.num_hidden_layers = num_expanded_layers

# load model on meta device without calling init and replace nn.Linear with Linear4bit
with init_empty_weights():
model = AutoModelForCausalLM.from_config(cfg)
if args["train_type"] in ["hqq_lora", "hqq_dora"]:
if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]:
# TODO: Tune BaseQuantizeConfig.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True,
quant_scale=True, offload_meta=True, view_as_float=True)
@@ -577,25 +588,28 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict):
model.is_loaded_in_4bit = True

# Grab the safetensors files that hold the weights
try:
idx = hub.cached_file(args["model_name"], SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(args["model_name"], idx)
except OSError:
if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]:
files = glob(str(llama_pro_path/"*.safetensors"))
else:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(args["model_name"], SAFE_WEIGHTS_NAME))
except OSError as e:
# This means the model probably doesn't have a safetensors file
raise e
idx = hub.cached_file(args["model_name"], SAFE_WEIGHTS_INDEX_NAME)
files, _ = hub.get_checkpoint_shard_files(args["model_name"], idx)
except OSError:
try:
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
files = []
files.append(hub.cached_file(args["model_name"], SAFE_WEIGHTS_NAME))
except OSError as e:
# This means the model probably doesn't have a safetensors file
raise e

# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
def load_and_quantize_parallel(name_param, model, **kwargs):
name, param = name_param
load_and_quantize(model, name, param, **kwargs)

quant_method = "hqq" if args["train_type"] in ["hqq_lora", "hqq_dora"] else "bnb"
quant_method = "hqq" if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"] else "bnb"
param_count = sum((p.numel() for n,p in model.named_parameters()))
if rank == 0 or args['verbose']:
print("Loading model", rank)
@@ -673,14 +687,23 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
if rank == 0 or args['verbose']:
print(f"Rank {rank}: LoRA layers added: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB")

elif args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]:
for n,p in model.named_parameters():
if any([layer_name in n for layer_name in new_layer_names]):
p.requires_grad = True
if args['verbose']:
print("Trainable Llama-Pro layer", n)
else:
p.requires_grad = False

if args["log_to"] == 'wandb':
logger.log({"memory/allocated_after_model_created": torch.cuda.memory_allocated(local_rank)}, rank)
logger.log({"memory/reserved_after_model_creation": torch.cuda.memory_reserved(local_rank)}, rank)


# Wrap model with llama-recipies or custom LoRA policy
my_auto_wrap_policy = get_wrapping_policy(custom_policy=args["train_type"] in ["custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora"],
vanilla_policy=args["train_type"] in ["full"])
vanilla_policy=args["train_type"] in ["full", "bnb_llama_pro", "hqq_llama_pro"])

if rank == 0 or args['verbose']:
print("Wrapping model w/ FSDP", rank)
@@ -932,9 +955,12 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
os.makedirs(args["output_dir"], exist_ok=True)
dist.barrier()
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
if args["train_type"] in ["custom_lora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora"]:
if args["train_type"] in ["custom_lora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]:
cpu_state_dict = {}
trainable_fsdp_modules = [(n,m) for n,m in model.named_modules() if n.endswith(('lora_AB', 'dora_layer', 'magnitude_layer'))]
if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]:
trainable_fsdp_modules = [(n,m) for n,m in model.named_modules() if n.endswith(tuple(new_layer_names))]
else:
trainable_fsdp_modules = [(n,m) for n,m in model.named_modules() if n.endswith(('lora_AB', 'dora_layer', 'magnitude_layer'))]
for prefix, module in trainable_fsdp_modules:
prefix = (prefix.replace("_fsdp_wrapped_module.", "")
.replace("_checkpoint_wrapped_module.", "")
@@ -966,7 +992,8 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
@call_parse()
def main(
world_size: int = -1, # Number of GPUs to use. -1 = all available GPUs.
train_type: Param("", choices=["full", "lora", "qlora", "custom_qlora", "custom_lora", "hqq_lora", "hqq_dora", "bnb_dora"]) = "qlora", # "full", "lora", "qlora", or "custom_qlora"
train_type: Param("", choices=["full", "lora", "qlora", "custom_qlora", "custom_lora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]) = "qlora", # Training type, one of the listed choices
llama_pro_path: str = None, # Path to the block-expanded (Llama-Pro) weights created by scripts/block_expansion.py
batch_size: int = 1, # Batch size per GPU. Effective BS = batch_size * world_size * gradient_accumulation_steps
context_length: int = 512, # Max length of input sequence (in tokens)
gradient_accumulation_steps: int = 1, # How many steps to accumulate gradients over (increases effective batch size)
