📅  最后修改于: 2023-12-03 15:24:42.722000             🧑  作者: Mango
如果您正在使用 Keras 构建神经网络,您可能会需要导入它的默认数据集之一:MNIST,这是一个手写数字数据集,有 60,000 张训练图片和 10,000 张测试图片。
以下是在 Python 中使用 Keras 导入 MNIST 数据集的步骤。
首先,我们需要导入 Keras 库和 MNIST 数据集:
from keras.datasets import mnist
# 加载 MNIST 数据集,并将数据集分为训练数据和测试数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
在这里,我们使用 mnist.load_data()
方法来下载和加载 MNIST 数据集。返回的数据集具有以下属性:
train_images
和 train_labels
:这是训练集,即用于训练模型的输入数据和真实标签。test_images
和 test_labels
:这是测试集,即用于评估模型的输入数据和真实标签。现在,我们已经成功加载了 MNIST 数据集。在继续之前,让我们先看一下数据集的结构。
# 查看训练数据集的形状
print(train_images.shape) # (60000, 28, 28)
print(train_labels.shape) # (60000,)
# 查看测试数据集的形状
print(test_images.shape) # (10000, 28, 28)
print(test_labels.shape) # (10000,)
我们可以看到训练数据集和测试数据集形状相同,都有 28x28 像素的图像,但它们的数量不同。
接下来,我们可以将数据集可视化,以便更好地了解它们的结构和特征。在这个例子中,我们将显示前 10 张图像,并添加它们的真实标签。
import matplotlib.pyplot as plt
# 将数据集可视化
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i, ax in enumerate(axes.flat):
ax.imshow(train_images[i], cmap='binary', interpolation='nearest')
ax.set(xticks=[], yticks=[], xlabel=train_labels[i])
我们可以看到这些手写数字都在黑白图像中呈现,它们的分辨率相对较低,但我们对它们的形状和特征可以做出准确的预测。
到此为止,您已经学会了如何导入 MNIST 数据集到 Keras 中,如何查看数据集的结构和特征,并可视化这些图像。接下来,您可以利用这些数据训练神经网络模型并进行预测。