Low Rank Adaptation for Waveforms
Designed to be used with Stable Audio Tools
Still highly experimental
Clone this repo, navigate to the root, and run:
$ pip install .
Add a lora section to your model config, e.g.:
"lora": {
"component_whitelist": ["transformer"],
"multiplier": 1.0,
"rank": 16,
"alpha": 16,
"dropout": 0,
"module_dropout": 0,
"lr": 1e-4
}
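For intuition about how these fields interact, here is a minimal sketch that parses such a section and computes the factor by which the low-rank update is scaled. It assumes the common LoRA convention of scaling by alpha / rank, further multiplied by multiplier. Whether loraw applies exactly this formula is an assumption, and effective_scale is a hypothetical helper, not part of loraw.

```python
import json

# The "lora" section from above, embedded in an otherwise empty model config.
config_text = """
{
  "lora": {
    "component_whitelist": ["transformer"],
    "multiplier": 1.0,
    "rank": 16,
    "alpha": 16,
    "dropout": 0,
    "module_dropout": 0,
    "lr": 1e-4
  }
}
"""


def effective_scale(lora_config):
    """Hypothetical helper: multiplier * alpha / rank (common LoRA convention)."""
    return lora_config["multiplier"] * lora_config["alpha"] / lora_config["rank"]


lora_config = json.loads(config_text)["lora"]
scale = effective_scale(lora_config)
print(scale)
```

Under this convention, setting alpha equal to rank leaves the update unscaled, so multiplier alone controls its strength.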
A full example config that works with Stable Audio Open can be found here
Then run the modified train.py as you would in stable-audio-tools, with the following command line arguments as needed:
--use-lora
- Set to true to enable lora usage
- Default: false
--lora-ckpt-path
- A pre-trained lora checkpoint to continue training from
--relora-every
- Enables ReLoRA training if set
- The number of steps between full-rank updates
- Default: 0
--quantize
- CURRENTLY BROKEN
- Set to true to enable 4-bit quantization of base model for QLoRA training
- Default: false
Run the modified run_gradio.py as you would in stable-audio-tools, with the following command line argument:
--lora-ckpt-path
- Your trained lora checkpoint
Create a loraw using the LoRAWrapper class. For example, with a conditional diffusion model where we only want to target the transformer component:
from loraw.network import LoRAWrapper
lora = LoRAWrapper(
    target_model,
    component_whitelist=["transformer"],
    lora_dim=16,
    alpha=16,
    dropout=None,
    multiplier=1.0
)
If using stable-audio-tools, you can create a LoRA based on your model config:
from loraw.network import create_lora_from_config
lora = create_lora_from_config(model_config, target_model)
If you want to load weights into the target model, be sure to do so before activating the lora, since activation alters the model structure and will confuse state_dict copying.
lora.activate()
lora.load_weights(path) and lora.save_weights(path) are for simple file IO. lora.merge_weights(path) can be used to add more checkpoints without overwriting the current state.
With stable-audio-tools, after activation, you can simply call:
lora.prepare_for_training(training_wrapper)
For training to work manually, you need to:
- Set all original weights to requires_grad = False
- Set lora weights to requires_grad = True (easily accessed with lora.residual_modules.parameters())
- Update the optimizer to use the lora parameters (the same parameters as the previous step)
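The manual steps above can be sketched in plain PyTorch. This is only an illustration under stated assumptions: base_model and residual_modules are hypothetical stand-ins for your target model and for lora.residual_modules, and the choice of AdamW with a learning rate of 1e-4 is an example, not something loraw requires.

```python
import torch
from torch import nn

# Stand-ins (assumptions): `base_model` plays the role of the target model,
# `residual_modules` plays the role of lora.residual_modules.
base_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
residual_modules = nn.ModuleDict({
    "down": nn.Linear(8, 2, bias=False),
    "up": nn.Linear(2, 8, bias=False),
})

# 1. Freeze every original weight.
for param in base_model.parameters():
    param.requires_grad = False

# 2. Make the lora weights trainable.
lora_params = list(residual_modules.parameters())
for param in lora_params:
    param.requires_grad = True

# 3. Build the optimizer over the lora parameters only.
optimizer = torch.optim.AdamW(lora_params, lr=1e-4)
```

With a real LoRAWrapper, the list in step 2 would come from lora.residual_modules.parameters() instead of the stand-in ModuleDict.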