在机器学习中,使用scikit学习库时,我们需要将经过训练的模型保存在文件中并进行还原,以便重用它以将模型与其他模型进行比较,以在新数据上测试模型。保存数据称为序列化,而恢复数据称为反序列化。
此外,我们处理不同类型和大小的数据。某些数据集易于训练,即它们花费的时间更少,但是大小较大(大于1GB)的数据集即使在使用GPU的情况下,在本地计算机上的训练也可能需要非常长的时间。当我们在不同的项目中或以后需要相同的经过训练的数据时,为了避免浪费训练时间,请存储经过训练的模型,以便将来可以在任何时候使用它。
我们可以通过两种方式将模型保存在scikit学习中:
- Pickle字符串:pickle模块实现了一个基本但功能强大的算法,用于对Python对象结构进行序列化和反序列化。
Pickle model provides the following functions –
pickle.dump
to serialize an object hierarchy, you simply use dump().pickle.load
to deserialize a data stream, you call the loads() function.示例:让我们在虹膜数据集上应用K最近邻,然后保存模型。
import numpy as np # Load dataset from sklearn.datasets import load_iris iris = load_iris() X = iris.data y = iris.target # Split dataset into train and test X_train, X_test, y_train, y_test = \ train_test_split(X, y, test_size = 0.3, random_state = 2018) # import KNeighborsClassifier model from sklearn.neighbors import KNeighborsClassifier as KNN knn = KNN(n_neighbors = 3) # train model knn.fit(X_train, y_train)
使用泡菜将模型保存到字符串–
import pickle # Save the trained model as a pickle string. saved_model = pickle.dumps(knn) # Load the pickled model knn_from_pickle = pickle.loads(saved_model) # Use the loaded pickled model to make predictions knn_from_pickle.predict(X_test)
输出:
- 使用joblib将酸洗的模型作为文件:Joblib是酸洗的替代品,因为它对携带大型numpy数组的对象更有效。这些函数还接受类似文件的对象而不是文件名。
joblib.dump
to serialize an object hierarchyjoblib.load
to deserialize a data stream使用joblib保存到腌制文件中–
from sklearn.externals import joblib # Save the model as a pickle in a file joblib.dump(knn, 'filename.pkl') # Load the model from the file knn_from_joblib = joblib.load('filename.pkl') # Use the loaded model to make predictions knn_from_joblib.predict(X_test)
输出: