The BatchNormalization Layer Degrades Model Training in TensorFlow 2.3#
Table of Contents#
[toc]
1. Background#
I recently upgraded TensorFlow from 1.13 to 2.3. When I reran an old captcha-recognition project, the model that reached over 95% accuracy on 1.13 only reached about 1% on 2.3.
The complete code is available on GitHub.
2. Model#
The model structure is as follows.
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 80, 120, 1) 0
_________________________________________________________________
conv2d (Conv2D) (None, 78, 118, 32) 320
_________________________________________________________________
conv2d_1 (Conv2D) (None, 76, 116, 32) 9248
_________________________________________________________________
batch_normalization_v1 (Batc (None, 76, 116, 32) 128
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 38, 58, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 36, 56, 64) 18496
_________________________________________________________________
conv2d_3 (Conv2D) (None, 34, 54, 64) 36928
_________________________________________________________________
batch_normalization_v1_1 (Ba (None, 34, 54, 64) 256
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 17, 27, 64) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 15, 25, 128) 73856
_________________________________________________________________
conv2d_5 (Conv2D) (None, 13, 23, 128) 147584
_________________________________________________________________
batch_normalization_v1_2 (Ba (None, 13, 23, 128) 512
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 11, 128) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 4, 9, 256) 295168
_________________________________________________________________
conv2d_7 (Conv2D) (None, 2, 7, 256) 590080
_________________________________________________________________
batch_normalization_v1_3 (Ba (None, 2, 7, 256) 1024
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 3, 256) 0
_________________________________________________________________
flatten (Flatten) (None, 768) 0
_________________________________________________________________
dropout (Dropout) (None, 768) 0
_________________________________________________________________
dense (Dense) (None, 248) 190712
_________________________________________________________________
reshape (Reshape) (None, 4, 62) 0
=================================================================
Total params: 1,364,312
Trainable params: 1,363,352
Non-trainable params: 960
_________________________________________________________________
The code that builds the model:
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     MaxPooling2D, Flatten, Dropout,
                                     Dense, Reshape)
from tensorflow.keras.models import Model

def captcha_model():
    input_tensor = Input((height, width, channel))
    x = input_tensor
    # Four blocks of Conv-Conv-BN-MaxPool, doubling the filter count each block
    for i in range(4):
        x = Conv2D(32 * 2**i, (3, 3), activation='relu', data_format='channels_last')(x)
        x = Conv2D(32 * 2**i, (3, 3), activation='relu', data_format='channels_last')(x)
        x = BatchNormalization(axis=-1)(x)
        x = MaxPooling2D((2, 2), data_format='channels_last')(x)
    x = Flatten()(x)
    x = Dropout(0.7)(x)
    x = Dense(n_len * n_class, activation='softmax')(x)
    x = Reshape([n_len, n_class])(x)
    model = Model(inputs=input_tensor, outputs=x)
    return model
model = captcha_model()
model.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])

log = model.fit(gen(trainDir),
                steps_per_epoch=125,
                epochs=50,
                workers=1,
                validation_data=gen(valDir),
                validation_steps=30,
                callbacks=callbacks,
                shuffle=False,
                initial_epoch=0)
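`fit` is driven by a data generator, `gen(...)`, whose implementation lives in the repository. As a rough illustration of the interface it must satisfy, here is a minimal sketch that yields batches with the shapes the model expects (random data stands in for the real captcha images; `gen`'s real signature and directory handling in the repo may differ):

```python
import numpy as np

height, width, channel = 80, 120, 1   # input size from the model summary
n_len, n_class = 4, 62                # 4 characters, 62 classes each

def gen(directory=None, batch_size=32):
    """Hypothetical stand-in for the project's data generator.

    Yields (images, labels) forever: images of shape
    (batch, height, width, channel), labels one-hot encoded
    per character with shape (batch, n_len, n_class).
    """
    while True:
        x = np.random.rand(batch_size, height, width, channel).astype("float32")
        y = np.zeros((batch_size, n_len, n_class), dtype="float32")
        idx = np.random.randint(n_class, size=(batch_size, n_len))
        for i in range(batch_size):
            y[i, np.arange(n_len), idx[i]] = 1.0  # one-hot per character position
        yield x, y
```

Each label slice `y[i]` is a `(4, 62)` matrix with exactly one `1` per row, matching the `Reshape([n_len, n_class])` output of the model.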
The training log from TensorFlow 2.3:
epoch,accuracy,loss,val_accuracy,val_loss
0,0.015124999918043613,8.563679695129395,0.01588541641831398,4.2869744300842285
1,0.01668749935925007,8.47919750213623,0.02031249925494194,4.490324020385742
2,0.014875000342726707,8.407829284667969,0.01875000074505806,4.638314247131348
3,0.015312500298023224,8.317134857177734,0.01875000074505806,4.755492687225342
4,0.016374999657273293,8.158424377441406,0.01640624925494194,4.840750217437744
5,0.017625000327825546,8.092458724975586,0.01822916604578495,5.16778564453125
6,0.016062499955296516,8.025052070617676,0.01927083358168602,5.403683185577393
7,0.01924999989569187,7.923693656921387,0.01848958246409893,5.478787899017334
8,0.017249999567866325,7.8893632888793945,0.01796874962747097,5.470637321472168
9,0.018187500536441803,7.766856670379639,0.01640624925494194,5.435889720916748
10,0.015687499195337296,7.725063323974609,0.01796874962747097,5.396360397338867
11,0.017374999821186066,7.6679534912109375,0.01848958246409893,5.356987476348877
12,0.018437499180436134,7.564635753631592,0.01848958246409893,5.319505214691162
13,0.017625000327825546,7.515840530395508,0.01692708395421505,5.276818752288818
14,0.017812499776482582,7.390673637390137,0.01770833320915699,5.235841751098633
15,0.015937499701976776,7.373181343078613,0.01666666753590107,5.200058937072754
16,0.017000000923871994,7.2762250900268555,0.01640624925494194,5.161618709564209
17,0.01693749986588955,7.20680046081543,0.01588541641831398,5.12495231628418
18,0.018124999478459358,7.1355390548706055,0.01666666753590107,5.095307350158691
19,0.017000000923871994,7.079192161560059,0.01640624925494194,5.062613487243652
20,0.020375000312924385,7.007359981536865,0.01640624925494194,5.035354137420654
21,0.017124999314546585,6.972399711608887,0.01666666753590107,5.006525993347168
22,0.02031249925494194,6.923651218414307,0.01692708395421505,4.98254919052124
23,0.017937500029802322,6.880887508392334,0.01718750037252903,4.952263355255127
24,0.019874999299645424,6.814514636993408,0.01718750037252903,4.927443981170654
25,0.017625000327825546,6.816472053527832,0.01718750037252903,4.906290054321289
26,0.017500000074505806,6.758553504943848,0.01614583283662796,4.8787150382995605
27,0.02031249925494194,6.677739143371582,0.01718750037252903,4.855156898498535
28,0.01875000074505806,6.657727241516113,0.01718750037252903,4.835511684417725
29,0.018812499940395355,6.64086389541626,0.01692708395421505,4.814121723175049
30,0.017625000327825546,6.576359272003174,0.01692708395421505,4.791886806488037
The training log from TensorFlow 1.13:
epoch,acc,loss,val_acc,val_loss
0,0.0255625,5.077181438446045,0.018489582,4.2680645942687985
1,0.07375,4.232825519561768,0.021614583,4.310072708129883
2,0.189125,3.3906055221557616,0.09348958,3.697092080116272
3,0.3220625,2.782533073425293,0.3734375,2.67845033009847
4,0.4459375,2.2324572772979736,0.54010415,2.03217138449351
5,0.556875,1.7689155778884889,0.6799479,1.4857364217440288
6,0.627625,1.4569332599639893,0.70182294,1.284301801522573
7,0.6884375,1.1837729458808899,0.7653646,0.9562915066878
8,0.7413125,0.9627125940322876,0.8127604,0.7591553310553233
9,0.7959375,0.7573423423767089,0.83229166,0.6807228187719981
10,0.825875,0.6350686495304108,0.8546875,0.5848873257637024
11,0.853,0.5289459655284882,0.88411456,0.4935603012641271
12,0.884875,0.423498391866684,0.89765626,0.4287114977836609
13,0.8994375,0.36352477538585665,0.90260416,0.38502170890569687
14,0.9164375,0.29950085139274596,0.91197914,0.3697542185584704
15,0.9298125,0.2621841303110123,0.9252604,0.3068610819677512
16,0.9365,0.22799357956647873,0.9278646,0.29371654242277145
17,0.9459375,0.20434909224510192,0.9328125,0.26822425176699954
18,0.9556875,0.1652733843922615,0.93671876,0.24204065936307112
19,0.9613125,0.14448723024129867,0.94427085,0.22956395708024502
20,0.96525,0.13079090192914009,0.9427083,0.24467742964625358
21,0.9688125,0.1233107333779335,0.953125,0.2123609036207199
22,0.97225,0.10828567123413085,0.95234376,0.18939945536355177
23,0.9753125,0.09439545741677284,0.94635415,0.21550938238700232
24,0.978,0.08019113063812255,0.9427083,0.23453998242815335
25,0.9799375,0.07622397544980049,0.95416665,0.18319092725093167
26,0.981875,0.07041979816555977,0.95520836,0.18539929799735547
27,0.981625,0.06994605718553067,0.9557292,0.1825941739603877
28,0.9823125,0.06756504848599434,0.9578125,0.17843913696706296
29,0.9873125,0.05395500122010708,0.95390624,0.17013602449248236
30,0.987625,0.053968257874250414,0.9505208,0.1960703587780396
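The two logs above are standard Keras `CSVLogger` output, except that TF 2.x names the column `val_accuracy` while TF 1.x names it `val_acc`. A small helper that handles both (the inlined sample rows below are truncated from the logs above):

```python
import csv
import io

def final_val_accuracy(csv_text, col_candidates=("val_accuracy", "val_acc")):
    """Return the last-epoch validation accuracy from a Keras CSVLogger dump."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    col = next(c for c in col_candidates if c in rows[0])
    return float(rows[-1][col])

tf2_log = """epoch,accuracy,loss,val_accuracy,val_loss
0,0.015124999918043613,8.563679695129395,0.01588541641831398,4.2869744300842285
30,0.017625000327825546,6.576359272003174,0.01692708395421505,4.791886806488037"""

tf1_log = """epoch,acc,loss,val_acc,val_loss
0,0.0255625,5.077181438446045,0.018489582,4.2680645942687985
30,0.987625,0.053968257874250414,0.9505208,0.1960703587780396"""

print(final_val_accuracy(tf2_log))
print(final_val_accuracy(tf1_log))
```

After 30 epochs the TF 2.3 run is still below 2% validation accuracy while the TF 1.13 run is above 95%.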
3. Troubleshooting#
I asked about this on Zhihu, and most respondents suspected the BatchNormalization layer. After I removed the BN layers from the model, TensorFlow 1.13 and 2.3 produced essentially the same results; further experiments also ruled out the Python version, the model hyperparameters, and other suspects. So it is fairly safe to conclude that something about the BN layer in TensorFlow 2.3 is degrading the model's performance.
4. Solution#
TensorFlow 2.3 likely still has many rough edges; its GitHub repository has around 3.8K open issues. So when a model needs BatchNormalization layers, prefer the TensorFlow 1.x line for now.
Contact email: curren_wong@163.com
CSDN: https://me.csdn.net/qq_41729780
Zhihu: https://zhuanlan.zhihu.com/c_1225417532351741952
WeChat official account: 复杂网络与机器学习 (Complex Networks and Machine Learning)
Feel free to follow or repost; questions are welcome via email.