📌  相关文章
📜  fit_generator() 得到了一个意外的关键字参数 'samples_per_epoch' (1)

📅  最后修改于: 2023-12-03 15:00:46.243000             🧑  作者: Mango

fit_generator() 得到了一个意外的关键字参数 samples_per_epoch

问题描述

调用 fit_generator() 方法时,出现了一个意外的关键字参数 samples_per_epoch 的错误提示。

原因分析

在 Keras 2.0 版本后,fit_generator() 方法的参数发生了变化,samples_per_epoch 参数被改为了 steps_per_epoch,而在该版本之前,是存在 samples_per_epoch 参数的。因此,当代码中使用了该参数时,就会提示关键字参数错误。

解决方案

  • 方案一:如果使用的是旧版本的 Keras,则可以保留 samples_per_epoch 参数,不会出现错误,但建议及时更新 Keras 版本或者修改代码中的参数名称。

  • 方案二:如果使用的是新版本的 Keras,则需要将 samples_per_epoch 参数替换为 steps_per_epoch,即修改如下:

    model.fit_generator(
         generator=train_generator,
         steps_per_epoch=1000,
         epochs=10,
         validation_data=val_generator,
         validation_steps=50
    )
    

    注意:steps_per_epoch 表示每一个 epoch 中包含的 batch 数量,该参数与数据集大小、batch 大小等相关。可以通过以下方式计算:

    steps_per_epoch = dataset_size // batch_size
    

    其中,dataset_size 表示数据集大小,batch_size 表示 batch 大小。

以上就是关于 fit_generator() 方法中意外的关键字参数 samples_per_epoch 的介绍。建议开发者在使用该方法时,结合 Keras 版本及参数名称的变化,及时调整相关代码。