📅  最后修改于: 2023-12-03 15:32:28.135000             🧑  作者: Mango
Keras 是一个高级神经网络API,能够在 TensorFlow, CNTK 或 Theano 等后端引擎之上运行。Keras 中的 fit()
和 fit_generator()
方法是用于训练模型的两个主要方法。这两个方法的区别在于输入数据的形式不同,fit()
方法适用于小规模数据,而 fit_generator()
适用于大规模数据。
fit()
方法是 Keras 中最常用的训练方法之一。它接受四个参数:
x
:训练数据,采用 Numpy 数组格式。y
:训练标签,采用 Numpy 数组格式。batch_size
:每个批次的样本数,一般取2的幂,例如32、64、128等。epochs
: 训练数据的迭代次数,通常建议运行 20 至 50 个 epochs。下面是 fit()
方法的示例代码:
model.fit(x_train, y_train, batch_size=32, epochs=50)
fit_generator()
方法与 fit()
方法类似,不同之处在于它需要传递一个数据生成器作为输入,而不是一次性传递所有的数据。这个方法接受三个参数:
generator
:数据生成器,用于逐个生成数据批次。steps_per_epoch
:每个 epoch 中的步数,其中一个 epoch 完成后,我们将启动下一 epoch。对于 Sequence 输入,它应该通常等于数据集大小除以批量大小。例如,数据集大小为 1000,批量大小为 32,则 steps_per_epoch
应该为 31。epochs
:训练数据的迭代次数,通常建议运行 20 至 50 个 epochs。下面是 fit_generator()
方法的示例代码:
model.fit_generator(train_datagen.flow(train_images, train_labels, batch_size=32),
steps_per_epoch=len(train_images)/32, epochs=50)
这两个方法都非常实用,具体选择哪一个,需要根据数据集的大小、计算资源和任务需求来决定,对于大规模数据,建议使用 fit_generator()
方法。无论哪种方法,都需要认真调试超参数,以达到最佳性能。