📜  Python保存机器学习模型(1)

📅  最后修改于: 2023-12-03 15:19:31.826000             🧑  作者: Mango

Python保存机器学习模型

在机器学习领域中,建立模型是非常耗时费力的,而建立出好的模型更是一件不容易的事情。因此,我们需要将建立好的模型保存下来,以便后续的使用。Python中提供了多种保存机器学习模型的方式,本文将为大家介绍其中的几种方法。

保存为pickle文件

Python中的pickle模块可以将Python对象序列化为二进制文件,并在需要的时候重新反序列化回来。我们可以将机器学习模型保存为pickle文件,以后需要使用时再载入即可。

以下是一个使用pickle保存和加载模型的例子:

import pickle
from sklearn import svm, datasets

# 加载数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 建立SVM模型
clf = svm.SVC()

# 训练模型
clf.fit(X, y)

# 保存模型
with open('svm.pickle', 'wb') as f:
    pickle.dump(clf, f)

# 加载模型
with open('svm.pickle', 'rb') as f:
    clf = pickle.load(f)

# 使用模型进行预测
print(clf.predict(X[[0, 50, 100]]))

在上面的例子中,我们首先使用sklearn库中的load_iris()函数加载了一个数据集,然后建立了一个SVM模型并进行训练。接着,使用pickle.dump()将模型保存在svm.pickle文件中。最后,我们使用pickle.load()从文件中载入模型,并使用它进行预测。

保存为joblib文件

与pickle模块相比,joblib模块更适合大型的数据集、科学计算和建模方面的处理,因为它可以在内存中缓存numpy数组,并读/写磁盘文件。因此,将机器学习模型保存为joblib文件的方式也是非常常见的。

以下是一个使用joblib保存和加载模型的例子:

from sklearn import svm, datasets
from joblib import dump, load

# 加载数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 建立SVM模型
clf = svm.SVC()

# 训练模型
clf.fit(X, y)

# 保存模型
dump(clf, 'svm.joblib')

# 加载模型
clf = load('svm.joblib')

# 使用模型进行预测
print(clf.predict(X[[0, 50, 100]]))

在上面的例子中,我们依然使用sklearn库中的load_iris()函数加载数据集,建立了一个SVM模型并进行训练。接着,使用dump()函数将模型保存在svm.joblib文件中。最后,我们使用load()函数从文件中载入模型,并使用它进行预测。

保存为ONNX文件

ONNX(Open Neural Network Exchange)是一种开放式的AI模型格式,可以用于在各种AI框架之间共享模型。由于它是开放式的标准格式,因此在机器学习领域中正在变得越来越流行。

以下是一个使用ONNX保存和加载模型的例子:

import numpy as np
import onnx
import onnxruntime as ort
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from onnxmltools.convert.sklearn import to_onnx

# 加载数据集
iris = datasets.load_iris()
X, y = iris.data.astype(np.float32), iris.target.astype(np.float32)

# 建立逻辑回归模型
clf = LogisticRegression()

# 训练模型
clf.fit(X, y)

# 将模型转换为ONNX格式
onnx_model = to_onnx(clf, X.astype(np.float32))

# 保存模型
onnx.save_model(onnx_model, 'logistic_regression.onnx')

# 加载模型
sess = ort.InferenceSession('logistic_regression.onnx')

# 使用模型进行预测
input_data = X[0:3]
preds = sess.run(None, {'X': input_data})
print(np.argmax(preds, axis=1))

在上面的例子中,我们使用sklearn库中的load_iris()函数加载数据集,建立了一个逻辑回归模型并进行训练。接着,使用to_onnx()函数将模型转换为ONNX格式。最后,我们使用onnx.save_model()将模型保存在logistic_regression.onnx文件中。在加载模型时,我们使用onnxruntime库中的InferenceSession类。最后,我们使用它进行预测。