Keras 中的数据集
Keras 是一个Python库,广泛用于训练深度学习模型。深度学习中的常见问题之一是为开发模型找到合适的数据集。在本文中,我们将看到已包含在keras.datasets
模块中的流行数据集列表。
MNIST(10位数字分类):
该数据集用于对手写数字进行分类。它包含训练集中的60,000张图像和测试集中的 10,000 张图像。每张图片的大小为28×28 。
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
回报:
- x_train, x_test:灰度图像数据的无符号整数 (0-255) 数组,形状为 (num_samples, 28, 28)。
- y_train, y_test:一个无符号整数 (0-255) 数字标签数组(0-9 范围内的整数),形状为 (num_samples,)。
Fashion-MNIST(10个时尚类别的分类):
该数据集可用作 MNIST 的替代品。它由 10 个时尚类别的 60,000 张 28×28 灰度图像以及 10,000 张图像的测试集组成。类标签是:
Label | Description |
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
from keras.datasets import fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
返回:
- x_train, x_test:灰度图像数据的无符号整数 (0-255) 数组,形状为 (num_samples, 28, 28)。
- y_train, y_test:一个无符号整数 (0-255) 数字标签数组(0-9 范围内的整数),形状为 (num_samples,)。
CIFAR10(10个图像标签的分类):
该数据集包含 10 种不同类别的图像,广泛用于图像分类任务。它由 50,000 个 32×32 彩色训练图像、超过 10 个类别和 10,000 个测试图像组成。数据集分为五个训练批次,每个批次有 10000 张图像。测试批次恰好包含来自每个类别的 1000 个随机选择的图像。训练批次包含随机顺序的剩余图像,但一些训练批次可能包含来自一个类的图像多于另一个。在它们之间,训练批次恰好包含来自每个类别的 5000 张图像。这些类是完全互斥的。汽车和卡车之间没有重叠。类标签是: Label Description 0 airplane 1 automobile 2 bird 3 cat 4 deer 5 dog 6 frog 7 horse 8 ship 9 truck
from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
回报:
- x_train, x_test: RGB 图像数据的无符号整数 (0-255) 数组,形状为 (num_samples, 3, 32, 32) 或 (num_samples, 32, 32, 3),分别基于 channels_first 或 channels_last 的 image_data_format 后端设置.形状中的值“3”指的是 3 个 RGB 通道。
- y_train, y_test:一个无符号整数 (0-255) 类别标签数组(0-9 范围内的整数),形状为 (num_samples, 1)。
CIFAR100(100个图像标签的分类):
该数据集包含 10 种不同类别的图像,广泛用于图像分类任务。它由 50,000 个 32×32 彩色训练图像、超过 10 个类别和 10,000 个测试图像组成。这个数据集就像 CIFAR-10,除了它有 100 个类,每个类包含 600 张图像。每个类有 500 个训练图像和 100 个测试图像。 CIFAR-100 中的 100 个类分为 20 个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗略”标签(它所属的超类)。
from keras.datasets import cifar100
(x_train, y_train), (x_test, y_test) = cifar100.load_data(label_mode='fine')
回报:
- x_train, x_test: RGB 图像数据的无符号整数 (0-255) 数组,形状为 (num_samples, 3, 32, 32) 或 (num_samples, 32, 32, 3),分别基于 channels_first 或 channels_last 的 image_data_format 后端设置.形状中的值“3”指的是 3 个 RGB 通道。
- y_train, y_test:一个无符号整数 (0-255) 类别标签数组(0-99 范围内的整数),形状为 (num_samples, 1)。
论据:
- label_mode :“精细”或“粗糙”。
波士顿房价(回归):
该数据集取自卡内基梅隆大学维护的 StatLib 库。该数据集包含 1970 年代后期波士顿郊区不同地点的 13 个房屋属性。目标是某个位置的房屋中值(以 k$ 为单位)。训练集包含 404 个不同家庭的数据,而测试集包含 102 个不同家庭的数据
from keras.datasets import boston_housing
(x_train, y_train), (x_test, y_test) = boston_housing.load_data()
回报:
- x_train, x_test:具有不同属性的 numpy 数组,形状为 (num_samples, 13) 。
- y_train, y_test:一个 numpy 数组,由不同属性的值组成,形状为 (num_samples, )。
论据:
- 种子:用于在计算测试拆分之前对数据进行洗牌的随机种子。
- test_split :保留作为测试集的数据的一部分。
IMDB 电影评论(情感分类):
该数据集用于评论的二元分类,即正面或负面。它包含来自 IMDB 的 25,000 条电影评论,按情绪(正面/负面)标记。这些评论已经过预处理,每条评论都被编码为一系列单词索引(整数)。这些词按它们在数据集中出现的总体频率进行索引。例如,整数“5”编码数据中第 5 个最常见的词。这允许快速过滤操作,例如仅考虑前 5000 个单词作为模型词汇等。
from keras.datasets import imdb
(x_train, y_train), (x_test, y_test) = imdb.load_data()
回报:
- x_train, x_test :序列列表,即索引列表(整数)。如果 num_words 参数是特定的,则最大可能的索引值为 num_words-1。如果指定了 maxlen 参数,则最大可能的序列长度为 maxlen。
- y_train, y_test :整数标签列表(1 表示正或 0 表示负)。
论据:
- num_words(int or None) :要考虑的最常见的单词。任何不太常见的单词都将在序列数据中显示为“oov_char”值。
- skip_top(int) :要忽略的最频繁出现的单词(它们将在序列数据中显示为 oov_char 值)。
- maxlen(int) :最大序列长度。任何更长的序列都将被截断。
- seed(int) :用于可重复数据混洗的种子。
- start_char(int) :序列的开始将用这个字符标记。设置为 1 因为 0 通常是填充字符。
- oov_char(int) :由于 num_words 或 skip_top 限制而被删除的单词将被替换为该字符。
- index_from(int) :使用此索引或更高索引来索引实际单词。
路透社新闻专线主题分类:
该数据集用于多类文本分类。它由来自路透社的 11,228 条新闻专线组成,标记了超过 46 个主题。就像 IMDB 数据集一样,每条线路都被编码为一系列单词索引(相同的约定)。
from keras.datasets import reuters
(x_train, y_train), (x_test, y_test) = reuters.load_data()
回报:
- x_train, x_test :序列列表,即索引列表(整数)。如果 num_words 参数是特定的,则最大可能的索引值为 num_words-1。如果指定了 maxlen 参数,则最大可能的序列长度为 maxlen。
- y_train, y_test :整数标签列表(1 表示正或 0 表示负)。
论据:
- num_words(int or None) :要考虑的最常见的单词。任何不太常见的单词都将在序列数据中显示为“oov_char”值。
- skip_top(int) :要忽略的最频繁出现的单词(它们将在序列数据中显示为 oov_char 值)。
- maxlen(int) :最大序列长度。任何更长的序列都将被截断。
- seed(int) :用于可重复数据混洗的种子。
- start_char(int) :序列的开始将用这个字符标记。设置为 1 因为 0 通常是填充字符。
- oov_char(int) :由于 num_words 或 skip_top 限制而被删除的单词将被替换为该字符。
- index_from(int) :使用此索引或更高索引来索引实际单词。