📅  最后修改于: 2023-12-03 15:21:18.332000             🧑  作者: Mango
save_model
函数是 XGBoost 库中一个用于保存模型的 Python API。它允许用户将已训练的模型保存到硬盘上,以备将来使用。本文将介绍如何使用 save_model
来保存 XGBoost 模型,并说明一些注意点。
save_model
函数将模型保存为二进制格式,并使用 pickle 序列化模型对象。它的参数包括:
目前,XGBoost 支持两种保存格式:
当 dump_format='auto'
时,函数会自动根据文件后缀名选择格式。
在调用 save_model
之前,我们需要先使用 train
函数训练出一个模型。下面是一个简单的示例:
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 加载数据
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
# 定义参数
params = {
'max_depth': 3,
'eta': 0.1,
'objective': 'multi:softmax',
'num_class': 3
}
# 训练模型
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
model = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtest, 'test')])
在模型训练完成后,我们可以使用 save_model
将模型保存到硬盘上:
# 保存模型
model.save_model('my_model.bin')
此时,当前目录下就会生成一个名为 my_model.bin
的文件,它就是我们刚刚保存的 XGBoost 模型。如果我们想要读取这个模型并进行预测,可以使用 xgb.Booster
类中的 load_model
函数:
# 读取模型
loaded_model = xgb.Booster()
loaded_model.load_model('my_model.bin')
# 预测
dtest = xgb.DMatrix(X_test)
preds = loaded_model.predict(dtest)
这样,我们就可以使用之前训练好的模型进行新数据的预测了。
save_model
只能用于保存 XGBoost 模型,不能用于保存 Booster 模型。如果要保存 Booster 模型,可以使用 booster().save_model
或 booster.dump_model
函数。pickle
库中的 HIGHEST_PROTOCOL
,例如 pickle.dump(model, open('my_model.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
。save_model
函数会抛出一个 PickleError 异常,解决方法是:将训练线程数设置为 1 或关闭 Pickle 多进程支持。save_model
是一个重要的 Python API,它可以使我们更加方便地保存和读取 XGBoost 模型。在使用它时,需要注意保存格式和模型对象的兼容性问题。