梯度视角下的LoRA:简介、分析、猜测及推广
By 苏剑林 | 2023-04-17 | 68336位读者 |随着ChatGPT及其平替的火热,各种参数高效(Parameter-Efficient)的微调方法也“水涨船高”,其中最流行的方案之一就是本文的主角LoRA了,它出自论文《LoRA: Low-Rank Adaptation of Large Language Models》。LoRA方法上比较简单直接,而且也有不少现成实现,不管是理解还是使用都很容易上手,所以本身也没太多值得细写的地方了。
然而,直接实现LoRA需要修改网络结构,这略微麻烦了些,同时LoRA给笔者的感觉是很像之前的优化器AdaFactor,所以笔者的问题是:能否从优化器角度来分析和实现LoRA呢?本文就围绕此主题展开讨论。
方法简介 #
以往的一些结果(比如《Exploring Aniversal Intrinsic Task Subspace via Prompt Tuning》)显示,尽管预训练模型的参数量很大,但每个下游任务对应的本征维度(Intrinsic Dimension)并不大,换句话说,理论上我们可以微调非常小的参数量,就能在下游任务取得不错的效果。
LoRA借鉴了上述结果,提出对于预训练的参数矩阵$W_0\in\mathbb{R}^{n\times m}$,我们不去直接微调$W_0$,而是对增量做低秩分解假设:
\begin{equation}W = W_0 + A B,\qquad A\in\mathbb{R}^{n\times r},B\in\mathbb{R}^{r\times m}\end{equation}
其中$A,B$之一用全零初始化,$W_0$固定不变,优化器只优化$A,B$。由于本征维度很小的结论,所以$r$我们可以取得很小,常见的是$r=8$,极端情况下我们甚至可以取$1$。所以说,LoRA是一种参数高效的微调方法,至少被优化的参数量大大降低了。
用MathJax直接画了个示意图:
$$\style{display: inline-block; width: 24ex; padding: 10ex 0; border: 1px solid #6C8EBF; background-color: #DAE8FC}{W_0\in\mathbb{R}^{n\times m}} \quad + \quad \style{display: inline-block; width: 8ex; padding: 10ex 0; border: 1px solid #D79B00; background-color: #FFE6CC}{A\in\mathbb{R}^{n\times r}}\quad\times\quad \style{display: inline-block; width: 24ex; padding: 3ex 0; border: 1px solid #D79B00; background-color: #FFE6CC}{B\in\mathbb{R}^{r\times m}}$$
梯度分析 #
正如《Ladder Side-Tuning:预训练模型的“过墙梯”》所提到的,很多参数高效的微调实际上只是降低了显存需求,并没有降低计算量。那么LoRA是否例外呢?它在显存和计算量方面的效率如何呢?下面我们来分析一下。
首先,我们知道训练模型所消耗的显存来源包括模型参数、模型梯度、模型激活值、优化器状态四部份,LoRA通过低秩分解降低了模型参数量,那么梯度和优化器状态也会随之降低,因此节省的显存是很明显的。那它能否节省计算量呢?
这取决于LoRA的实现方式,不同的实现方式计算梯度的复杂度不一样。LoRA的两种等效实现如下:
\begin{align}Y =&\, XW = X(W_0 + AB) \label{eq:lora-1}\\[5pt]
Y =&\, XW_0 + XAB = XW_0 + ZB \label{eq:lora-2}\end{align}
其中$X\in\mathbb{R}^{b\times n}$是模型输入,$Z=XA\in\mathbb{R}^{b\times r}$是中间输出。针对实现$\eqref{eq:lora-1}$,我们有
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W} B^{\top} = \left(X^{\top}\frac{\partial \mathcal{L}}{\partial Y}\right) B^{\top},\quad \frac{\partial \mathcal{L}}{\partial B} = A^{\top}\frac{\partial \mathcal{L}}{\partial W} = A^{\top}\left(X^{\top}\frac{\partial \mathcal{L}}{\partial Y}\right)\label{eq:grad-1}\end{equation}
$\mathcal{L}$是损失函数。很明显,这种实现导致的后果是需要算完整梯度$\frac{\partial \mathcal{L}}{\partial W}\in\mathbb{R}^{n\times m}$,然后才能算$A,B$的梯度,这意味着它比不LoRA还慢,也费显存。对于实现$\eqref{eq:lora-2}$,我们则有
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = X^{\top}\frac{\partial \mathcal{L}}{\partial Z} = X^{\top}\left(\frac{\partial \mathcal{L}}{\partial Y} B^{\top}\right),\quad \frac{\partial \mathcal{L}}{\partial B} = Z^{\top}\frac{\partial \mathcal{L}}{\partial Y} = (XA)^{\top}\frac{\partial \mathcal{L}}{\partial Y}\label{eq:grad-2}\end{equation}
此时的$Z,\frac{\partial \mathcal{L}}{\partial Z}\in\mathbb{R}^{b\times r}$,相比完整的梯度显然省了不少,计算复杂度也明显降低。所以,LoRA想要节省显存和计算最大化,关键是按照$\eqref{eq:lora-2}$而不是$\eqref{eq:lora-1}$来实现。
(注:关于矩阵计算梯度,我们可以根据链式法则和输出形状来“凑”,比如$\frac{\partial \mathcal{L}}{\partial A}$,根据链式法则我们知道它必然是$\frac{\partial \mathcal{L}}{\partial W}$和$B$以某种方式相乘,我们约定$\frac{\partial \mathcal{L}}{\partial A}$的形状跟$A$一致,即$n\times r$,想要用$\frac{\partial \mathcal{L}}{\partial W}$和$B$凑出一个$n\times r$的结果来,那就只有$\frac{\partial \mathcal{L}}{\partial W} B^{\top}$了。)
其他原因 #
除了低秩分解带来的好处外,如下几点也是LoRA能节省显存和提速的原因:
1、只更新了部分参数:比如LoRA原论文就选择只更新Self Attention的参数,实际使用时我们还可以选择只更新部分层的参数;
2、减少了通信时间:由于更新的参数量变少了,所以(尤其是多卡训练时)要传输的数据量也变少了,从而减少了传输时间;
3、采用了各种低精度加速技术,如FP16、FP8或者INT8量化等。
当然,这三部分原因确实能加快训练速度,但它们并不是LoRA所独有的,事实上几乎都有参数高效方法都具有这些特点。LoRA的突出优点是它的低秩分解很直观,在不少场景下跟全量微调的效果一致,以及在预测阶段可以直接把$W_0,A,B$合并成单个矩阵从而不增加推理成本。
优化视角 #
梯度$\eqref{eq:grad-1}$还告诉了我们如何从优化器角度来实现LoRA。优化器可以直接获取到全量梯度$\frac{\partial \mathcal{L}}{\partial W}$,然后我们只需要按照公式$\eqref{eq:grad-1}$对梯度进行投影,就得到$A,B$的梯度,接着就可以按照常规的优化器实现$A,B$的更新了。
假如优化器是SGD,那么就是
\begin{equation}\begin{aligned}
A_{t+1} =&\, A_t - \eta\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top},\quad B_{t+1} = B_t - \eta A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\\[5pt]
W_{t+1} =&\, W_0 + A_{t+1} B_{t+1} = W_t + (A_{t+1} B_{t+1} - A_t B_t)
\end{aligned}\end{equation}
如果是Adam之类的带滑动变量的优化器,则只需要滑动投影后的梯度,因此是降低了优化器的参数量,节省了一定的显存。模型越大,这部分参数所占的显存比例也就越大。
LoRA约定$A$或$B$之一使用全零初始化,这是为了保证初始状态模型跟预训练一致,但同时也带来了不对称问题(一个全零,一个非全零)。事实上,$A,B$都使用非全零初始化也是可以的,只需要事先将预训练权重减去$A_0 B_0$就行了,或者等价地说,将$W$参数化为
\begin{equation}W = W_0 - A_0 B_0 + A B\end{equation}
这样同时保持了初始状态一致,同时允许$A,B$都用非全零初始化,增强了对称性。
随机投影 #
如果我们将SGD场景下的更新量$A_{t+1} B_{t+1} - A_t B_t$展开,结果将是
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} B_t + A_t A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right) + \eta^2 \frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\end{equation}
假设$\eta^2$项是可以忽略的高阶项,那么就剩下
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} B_t + A_t A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right)\end{equation}
从这个角度来看,相比全量微调的SGD,LoRA就是用括号中的结果替代了全量的梯度$\frac{\partial \mathcal{L}}{\partial W_t}$。
简单起见,接下来我们只关心$r=1$的情形,留意到在上式中,$t$时刻的投影向量$A_t,B_t$是依赖于$t$的,如果我们将它们换成不依赖于$t$的随机向量(每步训练都重新随机生成),那么会发生什么呢?我们考虑$u,v\sim\mathcal{N}(0,1)$,其中$u\in\mathbb{R}^{m\times 1}, v\in\mathbb{R}^{1\times n}$,那么更新量就变为
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} v^{\top} v + u u^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right)\end{equation}
可以证明的是
\begin{equation}\mathbb{E}_{u\sim \mathcal{N}(0,1)}[u u^{\top}] = I_{n\times n},\quad \mathbb{E}_{v\sim \mathcal{N}(0,1)}[v^{\top} v] = I_{m\times m}\end{equation}
这里的$I_{n\times n},I_{m\times m}$分别指$n\times n,m\times m$的单位矩阵。因此,跟“零阶梯度”类似,在平均意义下,这种每步都重新初始化的LoRA事实上等价于满秩的SGD。然而,真要按照这个方式实现的话,其速度甚至可能比满秩的SGD都要慢,所以它的目的不是提速,而是希望能缓解灾难遗忘问题——通过对单个(batch)样本使用低秩矩阵(而不是满秩)更新量的方式,减少对整个模型权重的影响。当然,这只是猜测,实际效果如何,笔者还没有实验过。
一个变体 #
同样还是先只考虑$r=1$的情形,LoRA相当于假设了$\Delta w_{i,j} = u_i v_j$,我们能不能做其他低秩分解假设呢?比如$\Delta w_{i,j} = u_i + v_j$?写成矩阵形式就是
\begin{equation}W = W_0 + A \mathbb{1}_{1\times m} + \mathbb{1}_{n\times 1} B,\qquad A\in\mathbb{R}^{n\times 1},B\in\mathbb{R}^{1\times m}\end{equation}
其中$\mathbb{1}_{1\times m},\mathbb{1}_{n\times 1}$分别指$1\times m,n\times 1$的全1矩阵。容易求出它的梯度是:
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W} \mathbb{1}_{m\times 1},\quad \frac{\partial \mathcal{L}}{\partial B} = \mathbb{1}_{1\times n}\frac{\partial \mathcal{L}}{\partial W}\end{equation}
其实就是原本梯度的行求和与列求和。相比原版LoRA,这个加性分解有两个优点:1、加比乘计算量更低,梯度形式也更简单;2、$AB$的秩一定是1,但是$A \mathbb{1}_{1\times m} + \mathbb{1}_{n\times 1} B$的秩可能是2,如果秩代表了模型能力的话,那也就是说同样的参数量,加性的表达能力可能还更强。至于具体效果如何,后面笔者用到LoRA的时候,再做对比实验吧。
那么,加性分解能不能推广到$r > 1$的情形呢?自然是可以的,但稍微有些技巧。这里约定$m,n$都能被$r$整除,那么我们只需要将参数化方式改为
\begin{equation}W = W_0 + A I_{r(1\times m/r)} + I_{r(n/r\times 1)} B,\qquad A\in\mathbb{R}^{n\times r},B\in\mathbb{R}^{r\times m}\end{equation}
这里的$I_{r(1\times m/r)}$、$I_{r(n/r\times 1)}$分别指$1\times m/r$、$n/r\times 1$的分块矩阵,每一块则是$r\times r$的单位阵。这个形式说白了,就是分别将$A$、$B$看成是$n/r\times 1$、$1\times m/r$的分块矩阵,然后套用$r=1$的思路来操作。
文章小结 #
本文介绍了从梯度角度来理解LoRA,除了基本的介绍外,还包含了笔者的一些猜测和推广,供读者参考。
转载到请包括本文地址:https://kexue.fm/archives/9590
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Apr. 17, 2023). 《梯度视角下的LoRA:简介、分析、猜测及推广 》[Blog post]. Retrieved from https://kexue.fm/archives/9590
@online{kexuefm-9590,
title={梯度视角下的LoRA:简介、分析、猜测及推广},
author={苏剑林},
year={2023},
month={Apr},
url={\url{https://kexue.fm/archives/9590}},
}
May 15th, 2023
苏老师您好,我想问一个问题:
原文中:“每步训练都重新随机生成”,“在平均意义下,这种每步都重新初始化的LoRA事实上等价于满秩的SGD。”
那么,如果我使用W0的模型在同样一批数据完整的微调4次lora,然后 4 个lora 权重相加后,再加上 W0 得到的新模型,理论上 新模型 能否接近全参数微调的 W 模型?
如果能够接近,那么大概多少次才能够接近 全参数微调模型呢? 例如4次能否接近 全参数微调,或者接近多少?
huggingface peft的 LoraModel 有两个好用的函数:
1. add_weighted_adapter: 把2个lora权重相加 A_up_weight+ B_up_weight, A_down_weight + B_down_weight
2. merge_and_unload: 把lora 权重加到原模型的权重上, weight + transpose(up_weight @ down_weight)
我说的等价,不是并联而是串联。也就是说,先做一次LoRA,然后加到$W_0$上,把加后的权重当做$W_0$,再继续做LoRA,再加上去,依次类推。
感谢苏老师的回答,我还有个疑问:
如果是串联: LoRA1 训练4个epoch 后, 加到 \(W_0\) 后。LoRA2训练4个epoch,那么整个模型算不算训练了8个epoch,会不会容易过拟合呢?
我理解的串联不会过拟合:
在 \(W_0\)模型进行第一个LoRA ,训练到loss收敛后, 加上 \(W_0\) 生成 \(W_1\) 模型
在 \(W_1\)模型进行第二个LoRA ,训练到loss收敛后, 加上 \(W_1\) 生成 \(W_2\) 模型
在 \(W_(n-1)\)模型进行第n个LoRA ,训练到loss收敛后, 加上 \(W_(n-1)\) 生成 \(W_n\) 模型
当 \(W_(n-1)\) 模型进行 LoRA时, loss 不会下降,那就是达到了最大值。
这个时候停止就不会过拟合,如果继续训练则可能会过拟合。
串联是想通过多次LoRA达到全量微调的效果,目标场景是只有LoRA的算力但又想达到全量微调的效果。
那么什么场景下想要全量微调?肯定是数据足够多,LoRA已经无法发挥数据全部能力的时候啊,所以这种场景下多次LoRA串联不会过拟合。如果你的数据有限,单次LoRA效果已经很好,那么你多次LoRA自然会过拟合。
感谢苏老师的解答,学习到了很多东西。
June 9th, 2023
针对原文"如果是Adam之类的带滑动变量的优化器,则只需要滑动投影后的梯度,因此是降低了优化器的参数量,节省了一定的显存。模型越大,这部分参数所占的显存比例也就越大。",我有几个不明白的点:a. 如果在FP32情况下,模型参数所占显存大小为m,adam优化器所占显存为2m,梯度所占显存为m,只考虑这些的话节省优化器的参数量,这部分参数所站的显存比例最高也是50%;b.按照链式法则,需要计算lora中U或者V的梯度,也必须有W的梯度,导致在lora参数与loss之间参数的梯度必须得到计算,就是非lora(冻结)的参数也需要计算梯度,也就是模型所有的梯度是没法节省显存的,如果不考虑模型冻结参数进行8比特量化的话,那么针对模型本身信息所占的显存,lora节省显存的极限是50%
小模型情况下,模型的激活值所占的显存比优化器多得多,当模型越大,激活值所占的比例反而有所降低,所以在大模型中节省优化器的参数量,相对来说会比小模型省更多的显存。
另外,梯度也是能省显存的,参考@ReMeL|comment-21444。
June 30th, 2023
有个问题请教下,如果lora不全零初始化,那么保存才来的lora权重是不是也要把U0和V0保存下来,这样后面和原始模型权重merge的时候需要使用到吧
对。也可以在加载预训练权重的时候,就先对预训练权重进行相应的修改,但这样在训练完成之后就要立刻进行merge。
August 9th, 2023
不对称有什么影响?
个人审美上的问题。
February 28th, 2024
[...]LoRA(Low-Rank Adaptation)是当前LLM的参数高效微调手段之一,此前我们在《梯度视角下的LoRA:简介、分析、猜测及推广》也有过简单讨论。这篇文章我们来学习LoRA的一个新结论:[...]
April 15th, 2024
苏老师您好,文章关于Lora梯度的分析,我有点小疑问,还请赐教。使用Lora技术,模型的参数可以用下式分解,$W = W_0 + \Delta{W}$,假如目标函数 $Y=XW$,文中认为对于$\Delta{W}$的导数是$\frac{\partial{Y}}{\partial{\Delta{W}}}=\frac{\partial{Y}}{\partial{W}}\frac{\partial{W}}{\partial{\Delta{W}}}$,从这个视角下看,对于$\Delta{W}$的计算,是需要使用全量的模型参数的,我的疑问是,是否可以利用分配律$Y = XW = X(W_0 + \Delta{W}) = XW_0 + X\Delta{W}$来计算$\Delta{W}$的梯度呢?这个情况下,就只需要计算$\frac{\partial{Y}}{\partial{\Delta{W}}}= X$来得到$\Delta{W}$的梯度了
梯度公式是绕不开的(换言之梯度的计算量很难有明显降低),而且除非$Y$是一个标量,否则也不成立$\frac{\partial{Y}}{\partial{\Delta{W}}}= X$。本文已经说了,LoRA的提速主要来源于它没有对所有参数都进行LoRA。
苏老师,请问论文这一段说的不是楼上的意思吗?When using LoRA, We do not need to calculate the gradient of the pre-trained, frozen weights.Consider $h_i = W_ix_{i-1}$and its LoRA counterpart$h_i = W_ix_{i-1}+BAx_{i-1}$.....
LoRA的低秩分解对于求梯度来说,确实可以省一点显存,但不多。我主要想表达的是,LoRA相比全量微调能省显存,主要来源是:1、选择了更少的参数进行微调;2、LoRA带来的Adam优化的m、v显存降低。
其中第1点是所有微调方法都可以用的,包括Full FT,不是LoRA独有的。LoRA的低秩分解对于求梯度这一步,节省的计算量或者显存并不多。