龙空技术网

什么是小批量梯度下降、批量梯度下降和随机梯度下降

小黑黑讲AI 121

前言:

现时各位老铁们对“梯度下降算法简述”大概比较重视,兄弟们都需要了解一些“梯度下降算法简述”的相关资讯。那么小编同时在网摘上搜集了一些对于“梯度下降算法简述””的相关文章,希望各位老铁们能喜欢,我们一起来了解一下吧!

什么是小批量梯度下降、批量梯度下降和随机梯度下降,它们之间有什么区别呢?

如何使用Pytorch实现一个标准的小批量梯度下降?

这篇文章将帮助你理清上面这两个问题。

梯度下降算法,有三种常见形式,分别是批量梯度下降、随机梯度下降和小批量梯度下降。

下面我们分别讲解这三种梯度下降算法,并详细的介绍小批量梯度下降的实现方式。

批量梯度下降:

在每次迭代中,批量梯度下降都会基于所有的训练样本,计算损失函数的梯度。

因此,我们可以得到一条平滑的收敛曲线。

例如,在训练集中有100个样本,迭代50轮。

那么在每一轮迭代中,都会一起使用这100个样本,计算整个训练集的梯度,并对模型更新。

所以总共会更新50次梯度。

因为每次迭代都会使用整个训练集计算梯度,所以这种方法可以得到准确的梯度方向。

但如果数据集非常大,那么就导致每次迭代都很慢,计算成本就会很高。

随机梯度下降:

会在一轮完整的迭代过程中,遍历整个训练集。

但是每次更新,都只基于一个样本,计算梯度。

这样会得到一条震荡的收敛曲线:

例如,如果训练集有100个样本,迭代50轮,那么每一轮迭代,会遍历这100个样本,每次会计算某一个样本的梯度,然后更新模型参数。

换句话说,100个样本,迭代50轮,那么就会更新100*50=5000次梯度。

因为每次只用一个样本训练,所以迭代速度会非常快。

但更新的方向会不稳定,这也导致随机梯度下降,可能永远都不会收敛。

不过也因为这种震荡属性,使得随机梯度下降,可以跳出局部最优解。这在某些情况下,是非常有用的。

小批量梯度下降:

结合了批量梯度下降和随机梯度下降的优点。

在每次迭代时,都会从训练集中,随机的选择一组小批量的样本,来计算梯度,更新模型。

例如,如果训练集中有100个样本,迭代50轮。

如果设置小批量的数量是20,那么在每一轮迭代中,会有5次小批量迭代。

换句话说,就是将100个样本分成5个小批量,每个小批量20个数据,每次迭代用一个小批量。

因此,按照这样的方式,会对梯度,进行50轮*5个小批量=250次更新。

小批量梯度下降结合了随机梯度下降的高效性和批量梯度下降的稳定性。

它比随机梯度下降有更稳定的收敛,同时又比批量梯度下降计算的更快。

另外,由于小批量的随机性,还能使迭代跳出局部最优解。

因此,小批量梯度下降是最为常见的模型训练方式。

接下来,我们基于一元线性回归问题,来说明如何实现一个标准的,小批量梯度下降算法。

小批量数据的准备

首先来看训练数据的生成。

使用random.seed,设置一个固定的随机种子,可以确保每次运行,都得到相同的数据,方便调试。

然后随机生成100个横坐标x,它的范围在0到2之间。

生成带有噪音的纵坐标y。数据基本分布在y=4+3x的附近。

这样就生成了如图所示的训练数据。

接着,将训练数据x和y转为张量。

使用TensorDataset,将x和y组成训练集dataset,并使用DataLoader,构造随机的小批量数据。

这里设置参数shuffle=True,代表随机打乱数据的顺序。

参数batch_size=16,代表每一个小批量的数据规模是16,也就是每16个数据,作为一组训练数据。

打印dataloader的长度,结果是7,它对应了100/16=6.25。

也就是总共100个样本,每16个数据会作为1个batch,100个样本被分为了7个batch。

这里要注意,最后一个batch只包括4个数据。

例如,运行结果如下:

遍历dataloader,会得到编号为0到6的7个batch。

batch0到batch5,每个batch有16个样本,batch6,有4个样本。

小批量梯度下降算法

设置待迭代的直线参数为w和b。

然后进入模型的迭代循环:

外层循环,代表了整个训练数据集的迭代轮数。这里一共迭代50轮。

内层循环代表了,在一个迭代轮次中,以小批量的方式,使用dataloader对数据进行遍历。

其中batch_idx表示当前遍历的批次。

data和label表示这个批次的训练数据和标记。

对于每次迭代,首先计算当前直线的预测值,保存到h。

然后计算预测值h和真实值y之间的均方误差,保存到loss中。

使用loss.backward()进行反向传播,计算代价loss关于参数w和b的偏导数。

接着进行梯度下降,沿着梯度的反方向,更新w和b的值。

最后清空张量w和b中的梯度信息,为下一次迭代做准备。

另外,每次迭代,都会打印当前迭代的轮数epoch、数据的批次batch_idx和损失值loss。

完成迭代后,打印w和b的值,并绘制直线。

运行程序,就得到了蓝色的线性回归直线。

那么到这里,批量梯度下降、随机随机梯度下降、小批量梯度下降就讲完了,感谢大家的观看,我们下节课再会。

标签: #梯度下降算法简述