How to use datasets.fetch_mldata() in sklearn – Python?
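mldata.org has no enforced convention for storing data or naming the columns in a dataset. The default behavior of this function works for the most common cases, listed below:
- The data values are stored in the column named 'data' and the target values in the column named 'label'.
- Alternatively, the first array stores the target values and the second stores the data.
- The data array is stored as features × samples and therefore needs to be transposed to match the sklearn convention of samples × features.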
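The three default rules above can be sketched in a few lines of NumPy. This is a simplified illustration, not scikit-learn's implementation: `pick_columns` is a hypothetical helper, and the ordered dict of named arrays is an assumed stand-in for the columns of a downloaded mldata.org file.

```python
import numbers

import numpy as np


def pick_columns(columns, target_name="label", data_name="data",
                 transpose_data=True):
    """Hypothetical sketch of the default column rules (not sklearn code).

    `columns` maps column name -> NumPy array, in file order (an assumed
    stand-in for the arrays stored in an mldata.org file).
    """
    names = list(columns)
    # Integer arguments are treated as column indices and mapped to names.
    if isinstance(target_name, numbers.Integral):
        target_name = names[target_name]
    if isinstance(data_name, numbers.Integral):
        data_name = names[data_name]

    # Rule 1: use the columns called target_name / data_name if present.
    # Rule 2: otherwise fall back to the first (target) and second (data) arrays.
    target = columns[target_name] if target_name in columns else columns[names[0]]
    data = columns[data_name] if data_name in columns else columns[names[1]]

    # Rule 3: the data is stored as (n_features, n_samples); transposing
    # gives the (n_samples, n_features) layout sklearn expects.
    if transpose_data:
        data = data.T
    return data, target


# Toy columns with iris-like shapes (4 features x 150 samples) and
# no 'label'/'data' names, so rule 2 applies.
raw = {"species": np.arange(150) % 3, "measurements": np.zeros((4, 150))}
X, y = pick_columns(raw)
print(X.shape, y.shape)  # (150, 4) (150,)
```

Passing `transpose_data=False` keeps the stored (n_features, n_samples) layout, which is what Example 1 below does.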
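Fetches a machine-learning dataset, downloading it automatically from mldata.org if the file does not already exist locally.
The sklearn.datasets package provides this function for loading such datasets directly: sklearn.datasets.fetch_mldata()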
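fetch_mldata only goes to the network when no cached copy of the dataset exists. The sketch below shows that check under stated assumptions: the hypothetical helpers `mldata_cache_path` and `needs_download` are assumed to mirror the scikit-learn 0.19 layout (files kept in an `mldata` subfolder of `data_home`, named after the lowercased dataset name with spaces turned into hyphens and a `.mat` extension), but they are illustrations, not library API. Since mldata.org has gone offline and fetch_mldata was removed in scikit-learn 0.22, the sketch only computes paths and never downloads anything.

```python
import os
import re


def mldata_cache_path(dataname, data_home=None):
    """Hypothetical helper: where a dataset is assumed to be cached locally."""
    if data_home is None:
        # Assumed default: the SCIKIT_LEARN_DATA environment variable,
        # falling back to ~/scikit_learn_data.
        data_home = os.environ.get("SCIKIT_LEARN_DATA",
                                   os.path.join("~", "scikit_learn_data"))
    data_home = os.path.expanduser(data_home)
    # Assumed naming: lowercase, spaces -> hyphens, drop "(", ")" and ".".
    filename = re.sub(r"[().]", "", dataname.lower().replace(" ", "-"))
    return os.path.join(data_home, "mldata", filename + ".mat")


def needs_download(dataname, data_home=None):
    """True when no cached copy exists, i.e. a download would be attempted."""
    return not os.path.exists(mldata_cache_path(dataname, data_home))


print(mldata_cache_path("MNIST original", data_home="/tmp/demo_data"))
# /tmp/demo_data/mldata/mnist-original.mat on POSIX systems
```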
Syntax: sklearn.datasets.fetch_mldata(dataname, target_name='label', data_name='data', transpose_data=True, data_home=None)
Parameters:
- dataname: (str) Name of the dataset on mldata.org, e.g. "Iris", "mnist", "leukemia".
- target_name: (optional, default: 'label') Name or index of the column containing the target values.
- data_name: (optional, default: 'data') Name or index of the column containing the data.
- transpose_data: (optional, default: True) If True, the loaded data array is transposed from features × samples to samples × features.
- data_home: (optional, default: None) Alternative download and cache folder for the datasets. By default, all sklearn data is stored in subfolders of '~/scikit_learn_data'.
Returns: data (Bunch), a dictionary-like object whose interesting attributes are 'data', the data to learn; 'target', the classification labels; 'DESCR', the description of the dataset; and 'COL_NAMES', the original names of the dataset columns.
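The returned Bunch is a dictionary whose keys can also be read as attributes, which is why both `iris.data` and `iris['data']` work. The class below is a hypothetical minimal stand-in written for illustration (it is not sklearn's own class), and the field values are placeholders rather than real mldata.org content.

```python
class MiniBunch(dict):
    """Hypothetical stand-in for sklearn's Bunch: a dict whose keys are also attributes."""

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        # Attribute assignment stores a dictionary key.
        self[key] = value


result = MiniBunch(data=[[0.0, 1.0]], target=[0], DESCR="placeholder text",
                   COL_NAMES=["label", "data"])
print(result.data is result["data"])  # True
print(sorted(result))  # ['COL_NAMES', 'DESCR', 'data', 'target']
```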
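Let's look at some examples:
Example 1: Load the "iris" dataset from mldata. Its data array is stored as features × samples and would normally need transposing; here transposition is turned off to show the raw layout.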
Python3
# import fetch_mldata function
from sklearn.datasets.mldata import fetch_mldata
# load data without transposing it, so the array keeps
# the mldata.org layout (n_features, n_samples)
iris = fetch_mldata('iris',
                    transpose_data=False)
# print only the shape of the data array
# instead of the whole Bunch
# print(iris)
print(iris.data.shape)
Output:
(4, 150)
Example 2: Load the MNIST digit-recognition dataset from mldata.
Python3
# import fetch_mldata function
from sklearn.datasets.mldata import fetch_mldata
# load data
mnist = fetch_mldata('MNIST original')
# mnist data is very large
# so print the shape of data
print(mnist.data.shape)
Output:
(70000, 784)
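The printed shape (70000, 784) means 70,000 samples with 784 features each; for MNIST the features are the pixels of a 28 × 28 grayscale image (28 × 28 = 784). The sketch below uses a zero-filled placeholder array of the same shape (an assumption standing in for the real pixel data) to show how rows can be viewed as images and flattened back.

```python
import numpy as np

# Placeholder with the shape printed above; real values would come
# from fetch_mldata('MNIST original').data.
X = np.zeros((70000, 784), dtype=np.uint8)

# Each row is one 28 x 28 image flattened into 784 features.
images = X.reshape(-1, 28, 28)
print(images.shape)  # (70000, 28, 28)

# Back to the (n_samples, n_features) layout sklearn estimators expect;
# for a contiguous array both reshapes return views, so no copy is made.
flat = images.reshape(len(images), -1)
print(flat.shape)  # (70000, 784)
```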
Note: This article is based on Scikit-learn (version 0.19).