📅  最后修改于: 2023-12-03 15:00:46.243000             🧑  作者: Mango
fit_generator()
得到了一个意外的关键字参数 samples_per_epoch
问题描述
调用 fit_generator()
方法时,出现了一个意外的关键字参数 samples_per_epoch
的错误提示。
原因分析
在 Keras 2.0 版本后,fit_generator()
方法的参数发生了变化,samples_per_epoch
参数被改为了 steps_per_epoch
,而在该版本之前,是存在 samples_per_epoch
参数的。因此,当代码中使用了该参数时,就会提示关键字参数错误。
解决方案
方案一:如果使用的是旧版本的 Keras,则可以保留 samples_per_epoch
参数,不会出现错误,但建议及时更新 Keras 版本或者修改代码中的参数名称。
方案二:如果使用的是新版本的 Keras,则需要将 samples_per_epoch
参数替换为 steps_per_epoch
,即修改如下:
model.fit_generator(
generator=train_generator,
steps_per_epoch=1000,
epochs=10,
validation_data=val_generator,
validation_steps=50
)
注意:steps_per_epoch
表示每一个 epoch 中包含的 batch 数量,该参数与数据集大小、batch 大小等相关。可以通过以下方式计算:
steps_per_epoch = dataset_size // batch_size
其中,dataset_size
表示数据集大小,batch_size
表示 batch 大小。
以上就是关于 fit_generator()
方法中意外的关键字参数 samples_per_epoch
的介绍。建议开发者在使用该方法时,结合 Keras 版本及参数名称的变化,及时调整相关代码。