代码链接:

https://download.csdn.net/download/qq_38649386/12667825

 

实验结果:

问题&解决办法:

 

1.“csv_{0}.log”,log日志文件,路径设置:

CSVLogger_ = keras.callbacks.CSVLogger('csv/csv_{0}.log'.format(N_ADDITION), separator=',', append=False)

优化:

# 增加的神经元数量
# N_ADDITION=0,128,256,512,1024,2048,并观察实验结果
N_ADDITION = 512

还能继续优化不用手动更新

产生文件:

 

2.RuntimeError: You must compile your model before training/testing. Use `model.compile(optimizer, loss)
 

错误代码:

CSVLogger_ = keras.callbacks.CSVLogger('csv/csv_{0}.log'.format(N_ADDITION), separator=',', append=False)


model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH, verbose=VERBOSE,validation_split=VALIDATION_SPLIT,callbacks=[CSVLogger_])

正确代码:

CSVLogger_ = keras.callbacks.CSVLogger('csv/csv_{0}.log'.format(N_ADDITION), separator=',', append=False)

model.compile(loss='categorical_crossentropy', optimizer=OPTIMIZER,metrics=['accuracy'])

model.fit(X_train, y_train, batch_size=BATCH_SIZE, epochs=NB_EPOCH, verbose=VERBOSE,validation_split=VALIDATION_SPLIT,callbacks=[CSVLogger_])

 

 

3.错误代码:

path = 'weights/csv/csv_{0}.log'.format(j)

正确代码:依照上文自己设定的path路径填写

path = 'csv/csv_{0}.log'.format(j)

 

4.KeyError: "['acc' 'val_acc'] not in index"

错误代码:

sns.lineplot(data=data[['acc','val_acc']])

正确写法:

#绘制曲线
    sns.lineplot(data=data[['accuracy','val_accuracy']])

 

5.AttributeError: 'numpy.float64' object has no attribute 'values'

代码错误定位:

正确:

 

6.plt.show() 图标重叠,如图:

添加代码:

# 方法一 效果最好
plt.tight_layout()

# 方法二 设置表格大小
#plt.figure(figsize=(16, 12))
# 方法三 设置尺寸
#plt.tight_layout(pad=0.1, w_pad=1.0, h_pad=1.0)

 

实验原理&代码讲解:

--》有空再写

 

Logo

技术共进,成长同行——讯飞AI开发者社区

更多推荐