Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
josStorer committed Jul 9, 2023
1 parent e930eb5 commit d8c7045
Showing 1 changed file with 144 additions and 56 deletions.
200 changes: 144 additions & 56 deletions finetune/lora/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,52 +50,84 @@
parser = ArgumentParser()

parser.add_argument("--load_model", default="", type=str) # full path, with .pth
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
parser.add_argument(
"--wandb", default="", type=str
) # wandb project name. if "" then don't use wandb
parser.add_argument("--proj_dir", default="out", type=str)
parser.add_argument("--random_seed", default="-1", type=int)

parser.add_argument("--data_file", default="", type=str)
parser.add_argument("--data_type", default="utf-8", type=str)
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
parser.add_argument(
"--vocab_size", default=0, type=int
) # vocab_size = 0 means auto (for char-level LM and .txt data)

parser.add_argument("--ctx_len", default=1024, type=int)
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"

parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
parser.add_argument(
"--epoch_steps", default=1000, type=int
) # a mini "epoch" has [epoch_steps] steps
parser.add_argument(
"--epoch_count", default=500, type=int
) # train for this many "epochs". will continue afterwards with lr = lr_final
parser.add_argument(
"--epoch_begin", default=0, type=int
) # if you load a model trained for x "epochs", set epoch_begin = x
parser.add_argument(
"--epoch_save", default=5, type=int
) # save the model every [epoch_save] "epochs"

parser.add_argument(
"--micro_bsz", default=12, type=int
) # micro batch size (batch size per GPU)
parser.add_argument("--n_layer", default=6, type=int)
parser.add_argument("--n_embd", default=512, type=int)
parser.add_argument("--dim_att", default=0, type=int)
parser.add_argument("--dim_ffn", default=0, type=int)
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
parser.add_argument(
"--pre_ffn", default=0, type=int
) # replace first att layer by ffn (sometimes better)
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
parser.add_argument(
"--tiny_att_layer", default=-999, type=int
) # tiny attention @ which layer

parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument(
"--lr_init", default=6e-4, type=float
) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
parser.add_argument("--lr_final", default=1e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
parser.add_argument(
"--warmup_steps", default=0, type=int
) # try 50 if you load a model
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument(
"--beta2", default=0.99, type=float
) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)

parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument(
"--grad_cp", default=0, type=int
) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument(
"--my_pile_shift", default=-1, type=int
) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
parser.add_argument(
"--layerwise_lr", default=1, type=int
) # layerwise lr for faster convergence (but slower it/s)
parser.add_argument(
"--ds_bucket_mb", default=200, type=int
) # deepspeed bucket size in MB. 200 seems enough
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)

parser.add_argument("--my_img_version", default=0, type=str)
parser.add_argument("--my_img_size", default=0, type=int)
parser.add_argument("--my_img_bit", default=0, type=int)
parser.add_argument("--my_img_clip", default='x', type=str)
parser.add_argument("--my_img_clip", default="x", type=str)
parser.add_argument("--my_img_clip_scale", default=1, type=float)
parser.add_argument("--my_img_l1_scale", default=0, type=float)
parser.add_argument("--my_img_encoder", default='x', type=str)
parser.add_argument("--my_img_encoder", default="x", type=str)
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
parser.add_argument("--my_sample_len", default=0, type=int)
parser.add_argument("--my_ffn_shift", default=1, type=int)
Expand All @@ -104,7 +136,7 @@
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser.add_argument("--my_testing", default="", type=str)

parser.add_argument("--lora", action="store_true")
parser.add_argument("--lora_load", default="", type=str)
Expand All @@ -122,18 +154,26 @@
import numpy as np
import torch
from torch.utils.data import DataLoader

if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
print(
f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n"
* 3
)
seed_everything(args.random_seed)

np.set_printoptions(precision=4, suppress=True, linewidth=200)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
warnings.filterwarnings(
"ignore", ".*Consider increasing the value of the `num_workers` argument*"
)
warnings.filterwarnings(
"ignore", ".*The progress bar already tracks a metric with the*"
)
# os.environ["WDS_SHOW_SEED"] = "1"

args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
Expand All @@ -158,7 +198,9 @@
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
else:
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
args.run_name = (
f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
)
if not os.path.exists(args.proj_dir):
os.makedirs(args.proj_dir)

Expand Down Expand Up @@ -240,24 +282,40 @@
)
rank_zero_info(str(vars(args)) + "\n")

assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
assert args.data_type in [
"utf-8",
"utf-16le",
"numpy",
"binidx",
"dummy",
"wds_img",
"uint16",
]

if args.lr_final == 0 or args.lr_init == 0:
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
rank_zero_info(
"\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n"
)

assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
os.environ["RWKV_FLOAT_MODE"] = args.precision
if args.precision == "fp32":
for i in range(10):
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
rank_zero_info(
"\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n"
)
if args.precision == "fp16":
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
rank_zero_info(
"\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n"
)

os.environ["RWKV_JIT_ON"] = "1"
if "deepspeed_stage_3" in args.strategy:
os.environ["RWKV_JIT_ON"] = "0"
if args.lora and args.grad_cp == 1:
print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it')
print(
"!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it"
)
os.environ["RWKV_JIT_ON"] = "0"

torch.backends.cudnn.benchmark = True
Expand All @@ -284,41 +342,47 @@
train_data = MyDataset(args)
args.vocab_size = train_data.vocab_size

if args.data_type == 'wds_img':
if args.data_type == "wds_img":
from src.model_img import RWKV_IMG

assert args.lora, "LoRA not yet supported for RWKV_IMG"
model = RWKV_IMG(args)
else:
from src.model import RWKV, LORA_CONFIG, LoraLinear

if args.lora:
assert args.lora_r > 0, "LoRA should have its `r` > 0"
LORA_CONFIG["r"] = args.lora_r
LORA_CONFIG["alpha"] = args.lora_alpha
LORA_CONFIG["dropout"] = args.lora_dropout
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(','))
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(","))
enable_time_finetune = "time" in LORA_CONFIG["parts"]
enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
model = RWKV(args)
# only train lora parameters
if args.lora:
model.requires_grad_(False)
for name, module in model.named_modules():
# have to check param name since it may have been wrapped by torchscript
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
print(f' LoRA training module {name}')
print(f" LoRA training module {name}")
for pname, param in module.named_parameters():
param.requires_grad = 'lora_' in pname
elif enable_ln_finetune and '.ln' in name:
print(f' LoRA additionally training module {name}')
param.requires_grad = "lora_" in pname
elif enable_ln_finetune and ".ln" in name:
print(f" LoRA additionally training module {name}")
for param in module.parameters():
param.requires_grad = True
elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()):
elif enable_time_finetune and any(
n.startswith("time") for n, _ in module.named_parameters()
):
for pname, param in module.named_parameters():
if pname.startswith("time"):
print(f' LoRA additionally training parameter {pname}')
print(f" LoRA additionally training parameter {pname}")
param.requires_grad = True

if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
if (
len(args.load_model) == 0 or args.my_pile_stage == 1
): # shall we build the initial weights?
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name
Expand Down Expand Up @@ -346,27 +410,39 @@
# If using LoRA, the LoRA keys might be missing in the original model
model.load_state_dict(load_dict, strict=(not args.lora))
if os.path.isfile(args.lora_load):
model.load_state_dict(torch.load(args.lora_load, map_location="cpu"),
strict=False)
model.load_state_dict(
torch.load(args.lora_load, map_location="cpu"), strict=False
)

trainer: Trainer = Trainer.from_argparse_args(
args,
callbacks=[train_callback(args)],
)

if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8):
if 'I_KNOW_WHAT_IM_DOING' in os.environ:

if (
args.lr_init > 1e-4
or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8
):
if "I_KNOW_WHAT_IM_DOING" in os.environ:
if trainer.global_rank == 0:
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(
f" WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
)
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
else:
if trainer.global_rank == 0:
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
print(f' Unless you are sure this is what you want, adjust them accordingly')
print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")')
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print(
f" ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)"
)
print(
f" Unless you are sure this is what you want, adjust them accordingly"
)
print(
f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")'
)
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
exit(0)

if trainer.global_rank == 0:
Expand All @@ -379,10 +455,22 @@
print(f"{str(shape[0]).ljust(5)} {n}")

if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = (
args.ds_bucket_mb * 1000 * 1000
)
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = (
args.ds_bucket_mb * 1000 * 1000
)

# must set shuffle=False, persistent_workers=False (because worker is in another thread)
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
data_loader = DataLoader(
train_data,
shuffle=False,
pin_memory=True,
batch_size=args.micro_bsz,
num_workers=1,
persistent_workers=False,
drop_last=True,
)

trainer.fit(model, data_loader)

0 comments on commit d8c7045

Please sign in to comment.