📜  PyBrain – 为数据集导入数据

📅  最后修改于: 2022-05-13 01:55:17.662000             🧑  作者: Mango

PyBrain – 为数据集导入数据

在本文中,我们将学习如何在 PyBrain 中为数据集导入数据。

数据集是用于在网络上测试、验证和训练的数据。要使用的数据集类型取决于我们将使用机器学习执行的任务。 Pybrain 支持的最常用的数据集是SupervisedDataSetClassificationDataSet 。顾名思义,ClassificationDataSet 用于分类问题,SupervisedDataSet 用于监督学习任务。

方法 1:使用 CSV 文件导入数据集的数据

这是从 CSV 文件导入任何数据集的最简单方法。为此,我们将使用 Pandas,因此必须导入 Pandas 库。

考虑我们要导入的 CSV 文件是 price.csv。

Python3
import pandas as pd
  
print('Read data...')
  
# enter the complete path of the csv file
df = pd.read_csv('../price.csv',header=0).head(1000) 
data = df.values


Python3
from pybrain.datasets import ClassificationDataSet
from sklearn import datasets
  
nums = datasets.load_iris()
x, y = nums.data, nums.target
ds = ClassificationDataSet(4, 1, nb_classes=3)
for j in range(len(x)):
    ds.addSample(x[j], y[j])
ds


Python3
from sklearn import datasets
from pybrain.datasets import ClassificationDataSet
  
digits = datasets.load_digits()
X, y = digits.data, digits.target
ds = ClassificationDataSet(64, 1, nb_classes=10)
for i in range(len(X)):
    ds.addSample(ravel(X[i]), y[i])


Python3
import sklearn as sk
sk.datasets.fetch_california_housing


Python3
from sklearn.datasets import make_moon
from matplotlib import pyplot as plt
from matplotlib import style
   
X, y = make_moons(n_samples = 1000, noise = 0.1)
plt.scatter(X[:, 0], X[:, 1], s = 40, color ='g')
plt.xlabel("X")
plt.ylabel("Y")
   
plt.show()
plt.clf()


方法 2:使用 Sklearn导入数据集的数据

Sklearn 库中有许多预制数据集。根据所需的数据集类型,可以使用三种主要类型的数据集接口来获取数据集。

  • 数据集加载器——它们可用于加载小型标准数据集,如玩具数据集部分所述。

示例 1:加载 Iris 数据集

Python3

from pybrain.datasets import ClassificationDataSet
from sklearn import datasets
  
nums = datasets.load_iris()
x, y = nums.data, nums.target
ds = ClassificationDataSet(4, 1, nb_classes=3)
for j in range(len(x)):
    ds.addSample(x[j], y[j])
ds

输出:

示例 2:加载数字数据集

Python3

from sklearn import datasets
from pybrain.datasets import ClassificationDataSet
  
digits = datasets.load_digits()
X, y = digits.data, digits.target
ds = ClassificationDataSet(64, 1, nb_classes=10)
for i in range(len(X)):
    ds.addSample(ravel(X[i]), y[i])

输出:

  • 数据集提取器——它们可用于下载和加载更大的数据集

例子:

Python3

import sklearn as sk
sk.datasets.fetch_california_housing

输出:

  • 数据集生成函数——它们可用于生成受控合成数据集,如生成的数据集部分所述。这些函数返回一个元组 (X, y),它由一个 n_samples * n_features NumPy 数组 X 和一个长度为 n_samples 的包含目标 y 的数组组成。

例子:

Python3

from sklearn.datasets import make_moon
from matplotlib import pyplot as plt
from matplotlib import style
   
X, y = make_moons(n_samples = 1000, noise = 0.1)
plt.scatter(X[:, 0], X[:, 1], s = 40, color ='g')
plt.xlabel("X")
plt.ylabel("Y")
   
plt.show()
plt.clf()

输出: