背景
论文:《Batch Normalization: Accelerating Deep Network Training byReducing Internal Covariate Shift 》
深度神经网络的训练事实上是复杂的,因为随着前一层的参数的变化,每一层的输入分布都会发生变化。这使得训练时需要更低的学习速率和更细致的参数初始化,导致训练速度降低,而且也很难训练出具有饱和非线性【将模型输出限定到有限区间】的模型。这种现象称为“内部协变量位移”(internal convariate shift),解决这种现象的方法就是归一化每一层的输入。
这种“协变量位移”可以理解为输入值的分布不同,也可以理解为输入特征值的scale差异较大,与权重进行矩阵相乘后,会产生一些偏离较大地差异值;而深度学习网络需要通过训练不断进行参数更新,那么差异值产生的变化都会深深影响后面层,偏离越大表现越为明显;因此,对于反向传播来说,这些现象都会导致梯度发散,从而需要更多的训练迭代次数来抵消scale不同带来的影响,因此导致最终训练收敛时间变得很长。
而Batch Normalization的作用就是将这些输入值进行归一化,将scale的差异降低至同一个范围内。这样做的好处在于一方面提高梯度的收敛程度,加快训练速度;另一方面使得每一层可以尽量面对同一特征分布的输入值,减少了变化带来的不确定性,也降低了对后面网络层的影响,使得各层网络变得相对独立。
Batch Normalization作用总结
优点
- 可以使用较大的学习率,不用过分担心网络参数初始化问题。
- 加速训练收敛速度。
- 允许在深层网络中使用sigmoid这种易导致梯度消失的激活函数;
- 可以作为一个正则化器,在某些情况下可以消除对Dropout的需要,具有防止网络过拟合的作用,提高网络的泛化能力。
- 可以打乱样本训练顺序(这样就不可能出现同一张照片被多次选择用来训练)论文中提到可以提高1%的精度。
缺点
- BN依赖于batch_size,batch_size太小(小于32)模型效果不佳,batch_size减小模型效果急剧下降。
- 对于图片生成任务如:图片超分辨率和风格迁移,效果不佳。因为此类任务更加关注单张图片本身特有的一些细节信息。而BN更加关注数据整体的分布,在判别模型如图像分类识别方面效果较好。
- RNN 、LSTM等动态网络使用 BN 效果不佳(应该选择Instance Normalization)
- BN的测试过程和训练过程不连续。测试过程使用的均值和方差来自于训练集,若测试集与训练集分布不一致,测试集以此均值和方差进行归一化,并不合理。
Batch Normalization
Batch normalization计算公式如下:
从上图可以看出,通过计算mini-batch输入数据均值和方差后,对mini-batch的数据进行归一化,再加上β和γ对归一化后的数据进行平移和缩放。
通过对cirfar10数据集进行batch normalization分析,可以清晰的知道BN层的作用,如下图:
从图像分析可知,数据经过标准化后,其形状保持大致不变,但尺寸被我们压缩至(-1, 1)之间,而原尺寸在(-80,80)之间。通过平移和缩放,BN可以使数据被限定在我们想要的范围内,所以每层的输出数据都进行BN的话,可以使后续网络层具有稳定的输入值,降低梯度发散的可能,从而加快训练速度;同时也意味着允许使用大点的学习率,加快收敛过程。
因为可以对归一化后的数据进行缩放和平移,将归一化后的数据限制到某一范围,使得sigmoid或tanh等激活函数在深层网络中不会因为数据差异值导致激活函数失效或者对数据不敏感,比如说归一化后数据分布在(-10,10)这个范围内,对tanh函数来说输入值绝对值大于3后就对输入数据不再敏感,因为激活输出值要么是0要么是1,也就是说激活函数失效了,随着网络层数的增加会导致梯度消失,因此对归一化后的数据进行缩放和平移很重要,可以确保归一化的输出值在进入激活函数前分布在激活函数起作用的区域(也可以理解为非线性激活函数的线性区域),并且在实际应用中β和γ参数是可以学习的,也可以做到自适应激活函数,具体参考如下图。
这也是为什么batch normalization层一般情况下放在激活层的前面的原因。
BN完整训练流程
BN反向传播论文公式推导
Batch normalization前向传播示意图:
其中,表示输入,先计算输入均值和方差,得到归一化后的输入
,最后执行平移和缩放。
论文中BN反向传播公式:
其中,ℓ表示loss,其详细推到过程如下:
(1)
其中需要注意的是求和符号,因为
是针对batch的,所以其偏导应该是batch中所有样本对其的偏导之和。也可以理解为
和所有
都相关,而不是特定的某一个,所以其链式推导的中间变量是
而不是
.
同理:
(2)
由公式(3)可知,可以将
视为
的多元复合函数,则有:
(4)
可以进一步对上式进行拆分,因为是关于
和
的复合函数即可以将
视为
。
(5)
因此可以将等式(4)拆分为如下几项(其实也可以将理解为关于
的复合函数,其中
都与
有关):
(6)
其中:
(7)
(8)
(9)
(10)
(11)
(12)
(13)
(14)
(15)
其中公式12中因为所有样本i=1,…,m都与有关系,所以需要用求和公式,同样由等式(3)和等式
(16)可知,可以将
理解为关于自变量
的复合函数即
,因此可以得到等式(11)。同时由等式(3)可知,可以将
理解为关于自变量
的函数,因此可以得到等式(9).
故由等式(6)-(15)可得:
其中:
到此,就完成了论文中反向传播公式推导,但是在实际BN反向传播代码实现中已知的梯度只有,因此我们可以将
换成
的表达式。
公式拓展
计算:
(16)
计算:
计算:
计算:
多元复合函数链式法则
上述公式推导涉及到了多元复合函数链式法则,公式如下:
有复合函数,其中
都是关于
的函数,即
则函数h对
的偏导分别为:
Python代码实现
def batchnorm_forward(x, gamma, beta, bn_param):
"""
Forward pass for batch normalization.
During training the sample mean and (uncorrected) sample variance are
computed from minibatch statistics and used to normalize the incoming data.
During training we also keep an exponentially decaying running mean of the mean
and variance of each feature, and these averages are used to normalize data
at test-time.
At each timestep we update the running averages for mean and variance using
an exponential decay based on the momentum parameter:
running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
Note that the batch normalization paper suggests a different test-time
behavior: they compute sample mean and variance for each feature using a
large number of training images rather than using a running average. For
this implementation we have chosen to use running averages instead since
they do not require an additional estimation step; the torch7 implementation
of batch normalization also uses running averages.
Input:
- x: Data of shape (N, D)
- gamma: Scale parameter of shape (D,)
- beta: Shift paremeter of shape (D,)
- bn_param: Dictionary with the following keys:
- mode: 'train' or 'test'; required
- eps: Constant for numeric stability
- momentum: Constant for running mean / variance.
- running_mean: Array of shape (D,) giving running mean of features
- running_var Array of shape (D,) giving running variance of features
Returns a tuple of:
- out: of shape (N, D)
- cache: A tuple of values needed in the backward pass
"""
mode = bn_param['mode']
eps = bn_param.get('eps', 1e-5)
momentum = bn_param.get('momentum', 0.9)
N, D = x.shape
running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))
out, cache = None, None
if mode == 'train':
sample_mean = np.sum(x, axis = 0)/N
sample_var = np.sum((x - sample_mean)**2, axis = 0)/N
x_norm = (x - sample_mean)/np.sqrt(sample_var + eps)
out = gamma*x_norm + beta
running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1-momentum) * sample_var
cache = (x, x_norm, sample_mean, sample_var, gamma, eps)
elif mode == 'test':
x_norm = (x - running_mean)/np.sqrt(running_var + eps)
out = gamma * x_norm + beta
else:
raise ValueError('Invalid forward batchnorm mode "%s"' % mode)
# Store the updated running means back into bn_param
bn_param['running_mean'] = running_mean
bn_param['running_var'] = running_var
return out, cache
def batchnorm_backward(dout, cache):
"""
Backward pass for batch normalization.
For this implementation, you should write out a computation graph for
batch normalization on paper and propagate gradients backward through
intermediate nodes.
Inputs:
- dout: Upstream derivatives, of shape (N, D)
- cache: Variable of intermediates from batchnorm_forward.
Returns a tuple of:
- dx: Gradient with respect to inputs x, of shape (N, D)
- dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
- dbeta: Gradient with respect to shift parameter beta, of shape (D,)
"""
dx, dgamma, dbeta = None, None, None
(N, D) = dout.shape
(x, x_norm, sample_mean, sample_var, gamma, eps) = cache
# derivatives as defined in paper
dx_norm = gamma*dout
dvar = np.sum(dx_norm*(x-sample_mean)*((sample_var + eps)**(-3.0/2))*(-1.0/2), axis = 0)
dmean = np.sum(dx_norm*(-1.0/np.sqrt(sample_var + eps)),axis = 0) + dvar*(np.sum(-2*(x-sample_mean), axis = 0)*(1.0/N))
dx = dx_norm*(1.0/np.sqrt(sample_var + eps)) + dvar*(2.0*(x-sample_mean)/N) + dmean*(1.0/N)
dgamma = np.sum(dout*x_norm, axis = 0)
dbeta = np.sum(dout, axis = 0)
return dx, dgamma, dbeta
参考链接: