📅  最后修改于: 2023-12-03 14:52:23.152000             🧑  作者: Mango
当我们使用 Keras 进行图像分类或目标检测时,通常需要将数据集分成训练集和测试集。Keras 提供了 ImageDataGenerator 类来方便地进行数据增强和扩充,本文将介绍如何在 ImageDataGenerator 中进行训练测试拆分。
在进行训练测试拆分前,我们需要先将图像数据准备好。我们以 CIFAR-10 数据集为例,该数据集包含 50,000 张训练图像和 10,000 张测试图像。我们可以通过以下方式加载数据集:
from keras.datasets import cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
训练集与测试集分别被加载到 X_train 和 X_test 二维数组中,标签集被加载到 y_train 和 y_test 一维数组中。
我们可以使用 Scikit-Learn 的 train_test_split 函数将训练集进一步分成训练集和验证集:
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
其中,test_size 参数指定验证集所占的比例,random_state 参数将随机数发生器的种子设为了固定值 42,以确保每次运行代码时划分结果相同。
接下来,我们使用 ImageDataGenerator 对训练集和验证集进行数据增强和扩充。
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.1,
zoom_range=0.1,
horizontal_flip=True
)
val_datagen = ImageDataGenerator(
rescale=1./255
)
train_generator = train_datagen.flow(
X_train,
y_train,
batch_size=32
)
val_generator = val_datagen.flow(
X_val,
y_val,
batch_size=32
)
ImageDataGenerator 的参数中,rescale 对图像进行缩放归一化,shear_range 和 zoom_range 通过随机变换来增加数据集的随机性,horizontal_flip 对数据进行水平翻转。
flow 函数将图像和标签以批量形式生成,并支持在训练和验证期间使用不同的数据增强和扩充。
本文介绍了如何使用 Scikit-Learn 和 Keras 的 ImageDataGenerator 实现训练测试拆分,并进行了简单的数据增强和扩充。通过训练测试拆分,我们可以有效评估模型的性能和泛化能力,提高模型的准确性和鲁棒性。
参考资料: