-
Notifications
You must be signed in to change notification settings - Fork 228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
全面解析并实现逻辑回归(Python) #33
Comments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
一、逻辑回归模型结构
逻辑回归是一种广义线性的分类模型且其模型结构可以视为单层的神经网络,由一层输入层、一层仅带有一个sigmoid激活函数的神经元的输出层组成,而无隐藏层。其模型的功能可以简化成两步,“通过模型权重[w]对输入特征[x]线性求和+sigmoid激活输出概率”。
具体来说,我们输入数据特征x,乘以一一对应的模型权重w后求和,通过输出层神经元激活函数σ(sigmoid函数)将(wx + b)的计算后非线性转换为0~1区间的概率数值后输出。学习训练(优化模型权重)的过程是通过梯度下降学到合适的模型权重[W],使得模型输出值Y=sigmoid(wx + b)与实际值y的误差最小。
逻辑回归模型本质上属于广义线性分类器(决策边界为线性)。这点可以从逻辑回归模型的决策函数看出,决策函数Y=sigmoid(wx + b),当wx+b>0,Y>0.5;当wx+b<0,Y<0.5,以wx+b这条线可以区分开Y=0或1(如下图),可见决策边界是线性的。
![](https://camo.githubusercontent.com/68a0758cf3b7b250f9f4419d64dada696b42ee1f16ae82fa3bcc546e41a469fb/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f36303631343835633234343930613933343639633861656364346234323062362e706e67)
二、学习目标
逻辑回归是一个经典的分类模型,对于模型预测我们的目标是:预测的概率与实际正负样本的标签是对应的,Sigmoid 函数的输出表示当前样本标签为 1 的概率,y^可以表示为
![](https://camo.githubusercontent.com/51ea4f0dc62c773d282d3f4b7a3c54a89fd3fb18b065cd30dbf7c2cb03d76397/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f66383435383733356131326530396432356632633035616439353761383865312e706e67)
当前样本预测为0的概率可以表示为1-y^
![](https://camo.githubusercontent.com/e1e61561383333d9c061fe0ce5a85158e88a1de10a329481f5a4209df85d6776/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f36333663656433643065366566336366313430393236363030636630336166662e706e67)
对于正样本y=1,我们期望预测概率尽量趋近为1 。对于负样本y=0,期望预测概率尽量都趋近为0。也就是,我们希望预测的概率使得下式的概率最大(最大似然法)
![](https://camo.githubusercontent.com/23ea4924d1f88213147cf1480a3cbc44eb1bfcb52c833dcce8073851272cb829/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f64623365636563613239366334333863376364346163653366393336646432332e706e67)
![](https://camo.githubusercontent.com/f35072822f1bdefc1d5c71ec2459d48b17645fbe0e8c9ee74b0842ef5b3dcc98/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f37366363393933303834653737306530353764653138363536363634393663622e706e67)
我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。则有:
我们希望 log P(y|x) 越大越好,反过来,只要 log P(y|x) 的负值 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x),得到损失函数为:
我们已经推导出了单个样本的损失函数,是如果是计算 m 个样本的平均的损失函数,只要将 m 个 Loss 叠累加取平均就可以了:
这就在最大似然法推导出的lr的学习目标——交叉熵损失(或对数损失函数),也就是让最大化使模型预测概率服从真实值的分布,预测概率的分布离真实分布越近,模型越好。可以关注到一个点,如上式逻辑回归在交叉熵为目标以sigmoid输出的预测概率,概率值只能尽量趋近0或1,同理loss也并不会为0。
三、优化算法
我们以极小交叉熵为学习目标,下面要做的就是,使用优化算法去优化参数以达到这个目标。由于最大似然估计下逻辑回归没有(最优)解析解,我们常用梯度下降算法,经过多次迭代,最终学习到的参数也就是较优的数值解。
![](https://camo.githubusercontent.com/6ec9069697ac64e9c8c4fb6961523756424fd4c57869e7d76557bb0065ed3f40/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f39643235633533623133313662373366336164366232383533386561663137652e706e67)
梯度下降算法可以直观理解成一个下山的方法,将损失函数J(w)比喻成一座山,我们的目标是到达这座山的山脚(即求解出最优模型参数w使得损失函数为最小值)。
下山要做的无非就是“往下坡的方向走,走一步算一步”,而在损失函数这座山上,每一位置的下坡的方向也就是它的负梯度方向(直白点,也就是山的斜向下的方向)。在每往下走一步(步长由α控制)到一个位置的时候,求解当前位置的梯度,向这一步所在位置沿着最陡峭最易下山的位置再走一步。这样一步步地走下去,一直走到觉得我们已经到了山脚。
![](https://camo.githubusercontent.com/9273ac14e2bf8e67edf5a17f6c75acaa7e467a99bc609ed46a95421e4efe0ac6/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f66313265316235346336613862373666333434373963343339613232386339652e706e67)
当然这样走下去,有可能我们不是走到山脚(全局最优,Global cost minimun),而是到了某一个的小山谷(局部最优,Local cost minimun),这也梯度下降算法的可进一步优化的地方。
对应的算法步骤:
另外的,以非极大似然估计角度,去求解逻辑回归(最优)解析解,可见kexue.fm/archives/8578
四、Python实现逻辑回归
本项目的数据集为癌细胞分类数据。基于Python的numpy库实现逻辑回归模型,定义目标函数为交叉熵,使用梯度下降迭代优化模型,并验证分类效果:
![](https://camo.githubusercontent.com/f8c66193694fbed48d3c16462687f0872f67cd3dd898225264fd9c7e896593fa/68747470733a2f2f696d672d626c6f672e6373646e696d672e636e2f696d675f636f6e766572742f36383534383866366361633363396462623033373162623839363461333839622e706e67)
(END)
文章首发公众号“算法进阶”,阅读原文可访问文章相关代码
The text was updated successfully, but these errors were encountered: