📅  最后修改于: 2020-12-11 05:05:31             🧑  作者: Mango
在我们将数据提供给网络之前,必须将其转换为网络所需的格式。这称为为网络准备数据。它通常包括将多维输入转换为一维向量,并对数据点进行归一化。
我们的数据集中的图像包含28 x 28像素。必须将其转换为大小为28 * 28 = 784的一维向量,才能将其馈送到我们的网络中。我们通过在向量上调用reshape方法来实现。
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
现在,我们的训练向量将由60000个数据点组成,每个数据点均由大小为784的单个维向量组成。类似地,我们的测试向量将由大小为784的单维向量的10000个数据点组成。
输入向量包含的数据当前具有介于0到255之间的离散值-灰度级。将这些像素值标准化为0到1有助于加快训练速度。当我们将使用随机梯度下降法时,对数据进行归一化也将有助于减少陷入局部最优的机会。
为了规范化数据,我们将其表示为浮点类型,然后将其除以255,如以下代码片段所示-
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
现在让我们看一看标准化数据的样子。
要查看标准化数据,我们将调用直方图函数,如下所示:
plot.hist(X_train[0])
plot.title("Digit: {}".format(y_train[0]))
在这里,我们绘制X_train矢量的第一个元素的直方图。我们还将打印此数据点表示的数字。运行上述代码的输出如下所示-
您会发现点的密集密度值接近零。这些是图像中的黑点,这显然是图像的主要部分。其余接近白色的灰度点代表数字。您可以查看其他数字的像素分布。下面的代码在训练数据集中打印索引为2的数字的直方图。
plot.hist(X_train[2])
plot.title("Digit: {}".format(y_train[2])
运行上面的代码的输出如下所示-
比较上面的两个图,您会注意到两个图像中白色像素的分布不同,表示上面两个图像中的不同数字表示为“ 5”和“ 4”。
接下来,我们将检查完整训练数据集中的数据分布。
在我们的数据集上训练机器学习模型之前,我们应该知道数据集中唯一数字的分布。我们的图像代表10个不同的数字,范围从0到9。我们想知道数据集中的数字0、1等。我们可以使用Numpy的独特方法来获取此信息。
使用以下命令来打印唯一值的数量以及每个值的出现次数
print(np.unique(y_train, return_counts=True))
运行上面的命令时,您将看到以下输出-
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8), array([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]))
它显示了10个不同的值-0到9。存在5923个数字0,6742个数字1,依此类推。输出的屏幕截图显示在这里-
作为数据准备的最后一步,我们需要对数据进行编码。
我们的数据集中有十个类别。因此,我们将使用一键编码将输出分类为这十个类别。我们使用Numpy实用程序的to_categorial方法执行编码。编码输出数据后,每个数据点将转换为大小为10的一维向量。例如,数字5现在将表示为[0,0,0,0,0,1,0,0,0 ,0]。
使用以下代码对数据进行编码-
n_classes = 10
Y_train = np_utils.to_categorical(y_train, n_classes)
您可以通过打印分类的Y_train向量的前5个元素来检查编码结果。
使用以下代码打印前5个向量-
for i in range(5):
print (Y_train[i])
您将看到以下输出-
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
第一个元素表示数字5,第二个元素表示数字0,依此类推。
最后,您还必须对测试数据进行分类,这是使用以下语句完成的:
Y_test = np_utils.to_categorical(y_test, n_classes)
在此阶段,您的数据已准备就绪,可以馈入网络。
接下来,是最重要的部分,那就是训练我们的网络模型。