📅  最后修改于: 2023-12-03 15:34:33.017000             🧑  作者: Mango
当我们训练了一个 PyTorch 模型并已调优好了超参数,我们需要将其保存到磁盘中,以后作为推理时使用。本篇文章将介绍如何保存 PyTorch 模型。
如果要保存整个模型,包括超参数和权重等信息,可以将模型保存为 .pth
或 .pt
格式的文件。使用 torch.save()
函数即可实现。
import torch
# 定义模型并训练...
# 保存模型
PATH = "model.pt"
torch.save(model, PATH)
这将把整个模型及其参数都存储在 model.pt
文件中。
如果仅保存模型的参数,可以将其保存为 .pth
或 .pt
格式的文件。使用 torch.save()
函数,并传递模型的 state_dict()
作为参数即可实现。
import torch
# 定义模型并训练...
# 保存模型参数
PATH = "model_parameters.pth"
torch.save(model.state_dict(), PATH)
这将把模型的参数存储在 model_parameters.pth
文件中,方便以后加载。
要加载整个模型,可以使用 torch.load()
函数,如下所示:
import torch
# 加载模型
PATH = "model.pt"
model = torch.load(PATH)
要加载模型参数,可以先创建模型,并使用 load_state_dict()
方法将参数加载到模型中。如下所示:
import torch
# 定义模型
model = Net()
# 加载模型参数
PATH = "model_parameters.pth"
model.load_state_dict(torch.load(PATH))
这样,我们就可以在 PyTorch 中轻松地保存和加载模型了。