Pytorch深度学习实战2-1:详细推导Xavier参数初始化(附Python实现)

目录

  • 1 参数初始化
  • 2 Xavier参数初始化原理
    • 2.1 前向传播阶段
    • 2.2 反向传播阶段
    • 2.3 可视化思考
  • 3 Python实现

1 参数初始化

参数初始化在深度学习中起着重要的作用。在神经网络中,参数初始化是指为模型中的权重和偏置项设置初始值的过程。合适的参数初始化可以帮助模型更好地学习和收敛到最优解。参数初始化的目标是使模型具有良好的初始状态,以便在训练过程中快速且稳定地收敛。错误的参数初始化可能导致模型无法正常学习,梯度消失或梯度爆炸等问题。

常见的参数初始化方法包括随机初始化、零初始化、正态分布初始化和均匀分布初始化等。这些方法根据不同的分布特性和模型结构选择合适的初始值。在某些情况下,不同层或不同类型的参数可能需要不同的初始化策略。例如使用预训练模型时,可以采用迁移学习的方法,将预训练模型的参数作为初始值,从而加速收敛并提高性能。

除了设置初始值外,参数初始化还可以与其他优化技术相结合,如学习率调整、正则化和批归一化等,以进一步提高模型的性能和稳定性

举例而言,如图所示是在 t a n h ( ⋅ ) \rm{tanh(\cdot)} tanh()下九层神经网络各层激活输出,可以看到在网络深层激活输出逐渐衰减或保持不变

在这里插入图片描述

2 Xavier参数初始化原理

Xavier初始化的核心原理是保证各层网络的前向传播激活值和反向传播梯度值方差保持一致。Xavier初始化基于如下假设:

  • 输入样本独立同分布采样,且各个特征维度方差相等;
  • 激活函数 σ ( ⋅ ) \sigma \left( \cdot \right) σ()对称且近似线性区间满足 σ ( z ) ≈ z ⇔ σ ′ ( z ) ≈ 1 \sigma \left( \boldsymbol{z} \right) \approx \boldsymbol{z}\Leftrightarrow \sigma '\left( \boldsymbol{z} \right) \approx 1 σ(z)zσ(z)1
  • 激活输入 z \boldsymbol{z} z处于激活函数的线性区间

2.1 前向传播阶段

根据

a l = σ ( z l ) = σ ( W l a l − 1 − b l ) \boldsymbol{a}^l=\sigma \left( \boldsymbol{z}^l \right) =\sigma \left( \boldsymbol{W}^l\boldsymbol{a}^{l-1}-\boldsymbol{b}^l \right) al=σ(zl)=σ(Wlal1bl)

可得

v a r [ a l ] ≈ v a r [ z l ] = v a r [ W l a l − 1 − b l ] \mathrm{var}\left[ \boldsymbol{a}^l \right] \approx \mathrm{var}\left[ \boldsymbol{z}^l \right] =\mathrm{var}\left[ \boldsymbol{W}^l\boldsymbol{a}^{l-1}-\boldsymbol{b}^l \right] var[al]var[zl]=var[Wlal1bl]

初始阶段第 l l l层的网络权重 W l \boldsymbol{W}^l Wl的各个元素独立采样自某个分布 P P P,即

[ z 1 l z 2 l ⋮ z n l l ] = [ w 1 , 1 l w 1 , 2 l ⋯ w 1 , n l − 1 l w 2 , 1 l w 2 , 2 l ⋯ w 2 , n l − 1 l ⋮ ⋮ ⋱ ⋮ w n l , 1 l w n l , 2 l ⋯ w n l , n l − 1 l ] [ a 1 l − 1 a 2 l − 1 ⋮ a n l − 1 l − 1 ] ⇒ v a r [ z i l ] = v a r [ ∑ k = 1 n l − 1 w 1 , k l a k l − 1 ] \left[ \begin{array}{c} z_{1}^{l}\\ z_{2}^{l}\\ \vdots\\ z_{n_l}^{l}\\\end{array} \right] =\left[ \begin{matrix} w_{1,1}^{l}& w_{1,2}^{l}& \cdots& w_{1,n_{l-1}}^{l}\\ w_{2,1}^{l}& w_{2,2}^{l}& \cdots& w_{2,n_{l-1}}^{l}\\ \vdots& \vdots& \ddots& \vdots\\ w_{n_l,1}^{l}& w_{n_l,2}^{l}& \cdots& w_{n_l,n_{l-1}}^{l}\\\end{matrix} \right] \left[ \begin{array}{c} a_{1}^{l-1}\\ a_{2}^{l-1}\\ \vdots\\ a_{n_{l-1}}^{l-1}\\\end{array} \right] \Rightarrow \mathrm{var}\left[ z_{i}^{l} \right] =\mathrm{var}\left[ \sum_{k=1}^{n_{l-1}}{w_{1,k}^{l}a_{k}^{l-1}} \right] z1lz2lznll = w1,1lw2,1lwnl,1lw1,2lw2,2lwnl,2lw1,nl1lw2,nl1lwnl,nl1l a1l1a2l1anl1l1 var[zil]=var[k=1nl1w1,klakl1]

考虑到 w i , j l w_{i,j}^{l} wi,jl与前一层激活值 a l − 1 \boldsymbol{a}^{l-1} al1独立,所以

v a r [ z i l ] = v a r [ ∑ k = 1 n l − 1 w i , k l a k l − 1 ] = ∑ k = 1 n l − 1 v a r [ w i , k l a k l − 1 ] = ∑ k = 1 n l − 1 ( v a r [ w i , k l ] v a r [ a k l − 1 ] + v a r [ w i , k l ] E 2 [ a k l − 1 ] + v a r [ a k l − 1 ] E 2 [ w i , k l ] ) \begin{aligned}\mathrm{var}\left[ z_{i}^{l} \right] &=\mathrm{var}\left[ \sum_{k=1}^{n_{l-1}}{w_{i,k}^{l}a_{k}^{l-1}} \right]\\& =\sum_{k=1}^{n_{l-1}}{\mathrm{var}\left[ w_{i,k}^{l}a_{k}^{l-1} \right]}\\&=\sum_{k=1}^{n_{l-1}}{\left( \mathrm{var}\left[ w_{i,k}^{l} \right] \mathrm{var}\left[ a_{k}^{l-1} \right] +\mathrm{var}\left[ w_{i,k}^{l} \right] \mathbb{E} ^2\left[ a_{k}^{l-1} \right] +\mathrm{var}\left[ a_{k}^{l-1} \right] \mathbb{E} ^2\left[ w_{i,k}^{l} \right] \right)}\end{aligned} var[zil]=var[k=1nl1wi,klakl1]=k=1nl1var[wi,klakl1]=k=1nl1(var[wi,kl]var[akl1]+var[wi,kl]E2[akl1]+var[akl1]E2[wi,kl])

根据激活函数对称性,可令 W l \boldsymbol{W}^l Wl a l − 1 \boldsymbol{a}^{l-1} al1均值为0,根据假设中的方差关系

{ ∀ i    v a r [ a i l ] = v a r [ a l ] ∀ i , j    v a r [ w i , j l ] = v a r [ W l ] \begin{cases} \forall i\,\,\mathrm{var}\left[ a_{i}^{l} \right] =\mathrm{var}\left[ \boldsymbol{a}^l \right]\\ \forall i,j\,\,\mathrm{var}\left[ w_{i,j}^{l} \right] =\mathrm{var}\left[ \boldsymbol{W}^l \right]\\\end{cases} {ivar[ail]=var[al]i,jvar[wi,jl]=var[Wl]

上式可简化为 v a r [ z i l ] = n l − 1 v a r [ w i , 1 l ] v a r [ a 1 l − 1 ] \mathrm{var}\left[ z_{i}^{l} \right] =n_{l-1}\mathrm{var}\left[ w_{i,1}^{l} \right] \mathrm{var}\left[ a_{1}^{l-1} \right] var[zil]=nl1var[wi,1l]var[a1l1],改写成矩阵形式

v a r [ a l ] ≈ n l − 1 v a r [ W l ] v a r [ a l − 1 ] \mathrm{var}\left[ \boldsymbol{a}^l \right] \approx n_{l-1}\mathrm{var}\left[ \boldsymbol{W}^l \right] \mathrm{var}\left[ \boldsymbol{a}^{l-1} \right] var[al]nl1var[Wl]var[al1]

结合 a 0 = x \boldsymbol{a}^0=\boldsymbol{x} a0=x可递推得到

v a r [ a l ] ≈ v a r [ x ] ∏ k = 1 l n k − 1 v a r [ W k ] {\mathrm{var}\left[ \boldsymbol{a}^l \right] \approx \mathrm{var}\left[ \boldsymbol{x} \right] \prod_{k=1}^l{n_{k-1}\mathrm{var}\left[ \boldsymbol{W}^k \right]}} var[al]var[x]k=1lnk1var[Wk]

2.2 反向传播阶段

根据 δ l = ( W l + 1 ) T δ l + 1 ⊙ σ ′ ( z l ) \boldsymbol{\delta }^l=\left( \boldsymbol{W}^{l+1} \right) ^T\boldsymbol{\delta }^{l+1}\odot \sigma '\left( \boldsymbol{z}^l \right) δl=(Wl+1)Tδl+1σ(zl)可得

v a r [ δ l ] ≈ n l + 1 v a r [ W l + 1 ] v a r [ δ l + 1 ] \mathrm{var}\left[ \boldsymbol{\delta }^l \right] \approx n_{l+1}\mathrm{var}\left[ \boldsymbol{W}^{l+1} \right] \mathrm{var}\left[ \boldsymbol{\delta }^{l+1} \right] var[δl]nl+1var[Wl+1]var[δl+1]

结合 δ L = ∇ y ~ E ⊙ σ ′ ( z L ) ≈ ∇ y ~ E \boldsymbol{\delta }^L=\nabla _{\boldsymbol{\tilde{y}}}E\odot \sigma '\left( \boldsymbol{z}^L \right) \approx \nabla _{\boldsymbol{\tilde{y}}}E δL=y~Eσ(zL)y~E可递推得到

v a r [ δ l ] ≈ ∇ y ~ E ∏ k = l + 1 L n k v a r [ W k ] {\mathrm{var}\left[ \boldsymbol{\delta }^l \right] \approx \nabla _{\boldsymbol{\tilde{y}}}E\prod_{k=l+1}^L{n_k\mathrm{var}\left[ \boldsymbol{W}^k \right]}} var[δl]y~Ek=l+1Lnkvar[Wk]

为保证前向传播激活和反向传播梯度在网络中顺利流动,应保持各层参数方差相等,即满足

{ n l v a r [ W l ] = 1 n l − 1 v a r [ W l ] = 1 \begin{cases} n_l\mathrm{var}\left[ \boldsymbol{W}^l \right] =1\\ n_{l-1}\mathrm{var}\left[ \boldsymbol{W}^l \right] =1\\\end{cases} nlvar[Wl]=1nl1var[Wl]=1

由于第 l l l层的输入神经元个数 n l − 1 n_{l-1} nl1和输出神经元个数 n l n_l nl一般不相等,故取折中

v a r [ W l ] = 2 n l − 1 + n l \mathrm{var}\left[ \boldsymbol{W}^l \right] =\frac{2}{n_{l-1}+n_l} var[Wl]=nl1+nl2

所以网络连接权采样自服从方差满足上式的分布即可,例如

W ∼ N ( 0 , 2 n l − 1 + n l )    o r W ∼ U ( − 6 n l − 1 + n l , 6 n l − 1 + n l ) \boldsymbol{W}\sim \mathcal{N} \left( 0,\frac{2}{n_{l-1}+n_l} \right) \,\, \mathrm{or} \boldsymbol{W}\sim U\left( -\sqrt{\frac{6}{n_{l-1}+n_l}},\sqrt{\frac{6}{n_{l-1}+n_l}} \right) WN(0,nl1+nl2)orWU(nl1+nl6 ,nl1+nl6 )

2.3 可视化思考

如图所示,经过Xavier初始化后网络各层前向和反向传播时的方差保持一致

在这里插入图片描述

如图所示,经过Xavier初始化后的测试误差通常更小

在这里插入图片描述

Xavier进一步指出:观察层与层之间传播的激活值和梯度有利于理解深层网络的训练复杂度;保持层与层之间激活值和梯度的良好流动对学习效果非常重要。尽管在Xavier初始化做出了比较苛刻的假设,且在工程上很容易被违反,但其在实践中被证明是有效的,已经成为很多深度学习框架的默认初始化方法之一。

3 Python实现

简单实现一下Xavier初始化

python">def initialize_parameters_xavier(layers_dims):
    parameters = {}
    L = len(layers_dims)
    for l in range(1, L):
        mu = 0
        sigma = np.sqrt(2.0 / (layers_dims[l - 1] + layers_dims[l]))
        parameters['W' + str(l)] = np.random.normal(loc=mu, scale=sigma, size=(layers_dims[l], layers_dims[l - 1]))
        parameters['b' + str(l)] = np.zeros((layers_dims[l], 1))
    return parameters

可视化

python">for l in range(1, num_layers):
	A_pre = A
	W = parameters['W' + str(l)]
	b = parameters['b' + str(l)]
	z = np.dot(W, A_pre) + b # z = Wx + b
	
	A = tanh(z)
	
	print(A)
	plt.subplot(1, 8, l)
	plt.hist(A.flatten(), facecolor='g')
	plt.xlim([-2, 2])
	plt.ylim([0, 1000000])
	plt.yticks([])
plt.show()

如下所示
在这里插入图片描述
可以看出各层输出方差基本一致,实现了良好的初始化效果

完整工程代码请联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

http://www.niftyadmin.cn/n/5219465.html

相关文章

YOLOv8独家原创改进: AKConv(可改变核卷积),即插即用的卷积,效果秒杀DSConv | 2023年11月最新发表

💡💡💡本文全网首发独家改进:可改变核卷积(AKConv),赋予卷积核任意数量的参数和任意采样形状,为网络开销和性能之间的权衡提供更丰富的选择,解决具有固定样本形状和正方形的卷积核不能很好地适应不断变化的目标的问题点,效果秒殺DSConv 1)AKConv替代标准卷积进行…

第二十三章 解析PR曲线、ROC曲线、AUC、AP(工具)

混淆矩阵Confusion Matrix 混淆矩阵定义 混淆矩阵是机器学习中总结分类模型预测结果的情形分析表,以矩阵形式将数据集中的记录按照真实的类别与分类模型预测的类别判断两个标准进行汇总。其中矩阵的行表示真实值,矩阵的列表示预测值,下面我…

笔记62:注意力汇聚 --- Nadaraya_Watson 核回归

本地笔记地址:D:\work_file\(4)DeepLearning_Learning\03_个人笔记\3.循环神经网络\第10章:动手学深度学习~注意力机制 a a a a a a a a a a a a a a a a

Intellij Idea 断点小圆变成灰色怎么处理

场景1:变成了灰色实心圆 原因 断点变成灰色通常表示该断点处于失效状态。这可能是由于无意中点击了debug调试下方的“mute breakpoints”按钮导致的。 解决方案 依次点击设置小图标->View Options->Mute BreakPoints. 点击后 Mute BrakPoints左侧显示✔ 符号…

H5 uniapp 接入wx sdk

uniapp因为要兼容小程序等,会重写wx对象,导致引入的jweixin-1.6.0.js中对象不生效。 综合网络资料,有两种解决方案: 一,通过npm工具引入 npm install jweixin-module --save 实际上是借用了wx的另一个对象jWeixin …

TDA4开发环境Docker化

文章目录 背景1. TDA4X Linux SDK编译环境镜像构建1.1 安装SDK1.2 验证制卡1.2.1 出现的问题:1.3 验证编译1.3.1 出现的问题2. TDA4X Linux-RT SDK编译环境镜像构建2.1 安装SDK2.2 出现的问题参考背景 开始阅读本篇前,假设你已经对docker有了一定了解,且有过docker换件搭建…

C和C++混合编程时CMakeLists编写需要注意的地方

事件还原: 组里小兄弟在做一个研发Demo的时候,基于一个开源C库和一个开源C库混合编程,cmake以后编译到100%,在链接阶段报错undefined reference to ‘xxxxxx’,此处报错的函数是C库中的一个函数,经检查CMak…

歌手荆涛演唱的《春节回家》,一种情感的表达和文化的传承

歌手荆涛演唱的《春节回家》,一种情感的表达和文化的传承 春节回家,是中国传统文化中最为重要的传统节日之一,也是亿万华夏儿女最为期待的日子。每当春节临近,无论身在何处,人们都会收拾行囊,踏上归途&…