相关文章推荐
淡定的鸭蛋  ·  Azure App service get ...·  8 月前    · 
有胆有识的鸵鸟  ·  js数组删除最后两位 - CSDN文库·  9 月前    · 
耍酷的脸盆  ·  iOS开发系列--音频播放、录音、视频播放、 ...·  11 月前    · 
宽容的黑框眼镜  ·  cmip6数据处理:delta方法是每天的数 ...·  1 年前    · 
高大的电影票  ·  搬家第二天-28.Wincc ...·  1 年前    · 
Code  ›  通过一元线性回归模型理解梯度下降法开发者社区
https://cloud.tencent.com/developer/article/1367938
活泼的奔马
9 月前
Awesome_Tang

通过一元线性回归模型理解梯度下降法

前往小程序,Get 更优 阅读体验!
立即前往
腾讯云
开发者社区
文档 建议反馈 控制台
首页
学习
活动
专区
工具
TVP
最新优惠活动
文章/答案/技术大牛
发布
首页
学习
活动
专区
工具
TVP 最新优惠活动
返回腾讯云官网
Awesome_Tang
首页
学习
活动
专区
工具
TVP 最新优惠活动
返回腾讯云官网
社区首页 > 专栏 > 通过一元线性回归模型理解梯度下降法

通过一元线性回归模型理解梯度下降法

作者头像
Awesome_Tang
发布 于 2018-12-04 15:39:11
1.2K 1
发布 于 2018-12-04 15:39:11
举报
文章被收录于专栏: FSociety

关于线性回归相信各位都不会陌生,当我们有一组数据(譬如房价和面积),我们输入到excel,spss等软件,我们很快就会得到一个拟合函数:

h_\theta(x)=\theta_0+\theta_1x
h_\theta(x)=\theta_0+\theta_1x

但我们有没有去想过,这个函数是如何得到的? 如果数学底子还不错的同学应该知道,当维数不多的时候,是可以通过正规方程法求得的,但如果维数过多的话,像图像识别/自然语言处理等领域,正规方程法就没法满足需求了,这时候便需要 梯度下降法 来实现了。

梯度下降法

首先我们需要知道一个概念

  • 损失函数(loss function)
J(\theta_0,\theta_1)
J(\theta_0,\theta_1)

损失函数是用来测量你的预测值

f(x)
f(x)

与实际值之间的不一致程度,我们需要做的就是找到一组

\theta_0,\theta_1
\theta_0,\theta_1

使得

J(\theta_0,\theta_1)
J(\theta_0,\theta_1)

最小,这组

\theta_0,\theta_1
\theta_0,\theta_1

便叫做 全局最优解 。 我们需要定义一个损失函数,在线性回归问题中我们一般选择 平方误差代价函数 :

J(\theta_0,\theta_1)= \frac{1}{2m}\sum_{i=1}^{m}(h_\theta(x_i)-y_i)^2
J(\theta_0,\theta_1)= \frac{1}{2m}\sum_{i=1}^{m}(h_\theta(x_i)-y_i)^2

我们的目标是

minimizeJ(\theta_0,\theta_1)
minimizeJ(\theta_0,\theta_1)

如果不好理解的话我们通过图形来理解:

图2

假设上图是我们的

J(\theta_o,\theta_1)
J(\theta_o,\theta_1)

,那我们需要找到的就是左边箭头指向的那个点,这个点对应的

\theta_0,\theta_1
\theta_0,\theta_1

便是我们找的全局最优解,当然对于其他模型可能会存在局部最优解,譬如右边箭头指向的点,但是对于线性模型,只会存在全局最优解,真正的图像模型如下图所示,是个碗状的,我们要做的是找到碗底,这样是不是很好理解了。 那么如何到达最底呢,我们再看一张图。 我们需要从绿点到达红点,我们需要确定的有两件事情

  • 朝哪个方向走;
  • 走多远。

第一个问题,我们需要回忆下高中的数学知识—— 导数 ,在二维空间里面,导数是能代表函数上升下降快慢及方向的,这个各位在脑子里面想一个就明白,函数上升,导数为正,上升越快,导数越大,下降反之。扩展到多维空间,便是偏导数(

\frac{\partial}{\partial\theta_0 }J(\theta_0,\theta_1),\frac{\partial}{\partial\theta_1}J(\theta_0,\theta_1)
\frac{\partial}{\partial\theta_0 }J(\theta_0,\theta_1),\frac{\partial}{\partial\theta_1}J(\theta_0,\theta_1)

)。 第二个问题,走多远或者说步长,这里便需要我们自己定义,在梯度下降法中叫做 学习率

(\alpha),
(\alpha),

。 接下来放公式:

\theta_0:=\theta_0-\alpha\frac{\partial}{\partial\theta_0 }J(\theta_0,\theta_1)
\theta_0:=\theta_0-\alpha\frac{\partial}{\partial\theta_0 }J(\theta_0,\theta_1)
\theta_1:=\theta_1-\alpha\frac{\partial}{\partial\theta_1}J(\theta_0,\theta_1)
\theta_1:=\theta_1-\alpha\frac{\partial}{\partial\theta_1}J(\theta_0,\theta_1)

这边就不推导了,偏导数自己也快忘记的差不多了,直接放结果:

\theta_0:=\theta_0-\alpha\frac{1}{m}\sum_{i=1}^{m}(h_\theta(x_i)-y_i)
\theta_0:=\theta_0-\alpha\frac{1}{m}\sum_{i=1}^{m}(h_\theta(x_i)-y_i)
\theta_1:=\theta_1-\alpha\frac{1}{m}\sum_{i=1}^{m}(h_\theta(x_i)-y_i)x_i
\theta_1:=\theta_1-\alpha\frac{1}{m}\sum_{i=1}^{m}(h_\theta(x_i)-y_i)x_i

接下来迭代去更新

\theta_0,\theta_1
\theta_0,\theta_1

直至收敛就好了。

python实现

我们通过

y = 2x+1
y = 2x+1

生成一些随机点,注意

y = 2x+1
y = 2x+1

并不是我们的最优解:

代码语言: javascript
复制
# 以y= 2x+1为原型生成一个散点图
# 此时最优解并不是y = 2x+1
X0 = np.ones((100, 1))
X1 = np.random.random(100).reshape(100,1)
X = np.hstack((X0,X1))
y = np.zeros(100).reshape(100,1)
for i , x in enumerate(X1):
    val = x*2+1+random.uniform(-0.2,0.2)
    y[i] = val
plt.figure(figsize=(8,6))
plt.scatter(X1,y,color='g')
plt.plot(X1,X1*2+1,color='r',linewidth=2.5,linestyle='-')
plt.show()

out 迭代部分:

代码语言: javascript
复制
# 梯度下降法求最优解
def gradientDescent(X,Y,times = 1000, alpha=0.01):
    alpha:学习率,默认0.01
    times:迭代次数,默认1000次
    m = len(y)
    theta = np.array([1,1]).reshape(2, 1)
    loss = {}
    for i in range(times):
        diff = np.dot(X,theta)- y
        cost = (diff**2).sum()/(2.0*m)
        loss[i] = cost
        theta = theta - alpha*(np.dot(np.transpose(X), diff)/m)
    plt.figure(figsize=(8,6))
    plt.scatter(loss.keys(),loss.values(),color='r')
    plt.show()
    return theta
theta = gradientDescent(X,Y)

默认设置的迭代1000次,学习率为0.01,最后结果如下:

  • 损失函数
 
推荐文章
淡定的鸭蛋  ·  Azure App service get 503 The service is unavailable. - Microsoft Q&A
8 月前
有胆有识的鸵鸟  ·  js数组删除最后两位 - CSDN文库
9 月前
耍酷的脸盆  ·  iOS开发系列--音频播放、录音、视频播放、拍照、视频录制 - KenshinCui - 博客园
11 月前
宽容的黑框眼镜  ·  cmip6数据处理:delta方法是每天的数据都进行了降尺度偏差订正,还是月尺度的数据进行了偏差订正? - 海盐summer 的回答 - 知乎
1 年前
高大的电影票  ·  搬家第二天-28.Wincc V7.3使用Microsoft Hierarchical Flexgrid控件显示SQL Server数据表并导出到excel文件图表 - 来自金沙江的小鱼 - 博客园
1 年前
今天看啥   ·   Py中国   ·   codingpro   ·   小百科   ·   link之家   ·   卧龙AI搜索
删除内容请联系邮箱 2879853325@qq.com
Code - 代码工具平台
© 2024 ~ 沪ICP备11025650号