📅  最后修改于: 2023-12-03 15:36:34.577000             🧑  作者: Mango
热编码是将分类数据转换为机器学习算法可以理解的形式之一。在 TensorFlow 中,可以使用 tf.keras.utils.to_categorical
函数来实现热编码。
import tensorflow as tf
# 示例标签数据,包括四个不同类别
labels = [0, 1, 2, 3]
# 将标签进行热编码
one_hot_labels = tf.keras.utils.to_categorical(labels)
print(one_hot_labels)
输出:
array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]], dtype=float32)
代码解释:
tf.keras.utils.to_categorical
函数进行热编码。默认情况下,tf.keras.utils.to_categorical
函数将热编码输出为输入标签数据中所包含的类别数作为向量长度。如果需要自定义向量长度,则可以添加 num_classes
参数。
import tensorflow as tf
# 示例标签数据,包括四个不同类别
labels = [0, 1, 2, 3]
# 将标签进行热编码,自定义向量长度
one_hot_labels = tf.keras.utils.to_categorical(labels, num_classes=6)
print(one_hot_labels)
输出:
array([[1., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.]], dtype=float32)
代码解释:
tf.keras.utils.to_categorical
函数进行热编码,向量长度为 6。在某些情况下,带标签的数据集可能包含了输入数据和对应的标签数据。此时可以使用 TensorFlow 数据集 API 来加载数据集,并使用 map
函数提取标签数据并进行热编码。
import tensorflow as tf
# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 将标签进行热编码
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)
# 将数据转换为 TensorFlow Dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
# 定义 map 函数,提取标签数据并进行热编码
def preprocess(image, label):
label = tf.one_hot(label, depth=10)
return image, label
train_dataset = train_dataset.map(preprocess)
test_dataset = test_dataset.map(preprocess)
print(train_dataset)
print(test_dataset)
输出:
<MapDataset shapes: ((28, 28), (10,)), types: (tf.uint8, tf.float32)>
<MapDataset shapes: ((28, 28), (10,)), types: (tf.uint8, tf.float32)>
代码解释: