📅  最后修改于: 2023-12-03 15:05:32.363000             🧑  作者: Mango
CIFAR-10 是一个经典的图像分类数据集,包括 60,000 个 32x32 像素的彩色图像,分为 10 个不同的类别,每类有 6,000 个图片。本文将介绍如何使用 TensorFlow 进行 CIFAR-10 图像分类。
准备数据
tf.data.Dataset
API 加载数据。具体实现可以参照以下代码片段:import tensorflow as tf
BATCH_SIZE = 128
def load_data():
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0
# Convert labels to categorical one-hot encoding
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.shuffle(buffer_size=50000).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_dataset = test_dataset.batch(BATCH_SIZE)
return train_dataset, test_dataset
建立模型
tf.keras
建立卷积神经网络模型,可以参照以下代码片段:from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
NUM_CLASSES = 10
def build_model():
model = tf.keras.Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
Flatten(),
Dense(64, activation='relu'),
Dropout(0.5),
Dense(NUM_CLASSES, activation='softmax')
])
return model
训练模型
tf.keras.Model
中的 fit()
函数,可以参照以下代码片段:EPOCHS = 10
def train(model, train_dataset, test_dataset):
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_dataset, epochs=EPOCHS,
validation_data=test_dataset, verbose=2)
评估模型
tf.keras.Model
中的 evaluate()
函数评估模型,可以参照以下代码片段:def evaluate(model, test_dataset):
loss, accuracy = model.evaluate(test_dataset, verbose=2)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')
预测
tf.keras.Model
中的 predict()
函数预测新的图片,可以参照以下代码片段:def predict(model, new_images):
predictions = model.predict(new_images)
return predictions
本文介绍了如何使用 TensorFlow 对 CIFAR-10 图像分类,涉及数据加载、模型建立、训练、评估和预测。如果你希望深入了解 TensorFlow,建议查看 TensorFlow 的官方文档。