📜  pytorch 保存模型 - Python (1)

📅  最后修改于: 2023-12-03 15:34:33.017000             🧑  作者: Mango

PyTorch 保存模型 - Python

当我们训练了一个 PyTorch 模型并已调优好了超参数,我们需要将其保存到磁盘中,以后作为推理时使用。本篇文章将介绍如何保存 PyTorch 模型。

保存整个模型

如果要保存整个模型,包括超参数和权重等信息,可以将模型保存为 .pth.pt 格式的文件。使用 torch.save() 函数即可实现。

import torch

# 定义模型并训练...

# 保存模型
PATH = "model.pt"
torch.save(model, PATH)

这将把整个模型及其参数都存储在 model.pt 文件中。

保存模型参数

如果仅保存模型的参数,可以将其保存为 .pth.pt 格式的文件。使用 torch.save() 函数,并传递模型的 state_dict() 作为参数即可实现。

import torch

# 定义模型并训练...

# 保存模型参数
PATH = "model_parameters.pth"
torch.save(model.state_dict(), PATH)

这将把模型的参数存储在 model_parameters.pth 文件中,方便以后加载。

加载模型

要加载整个模型,可以使用 torch.load() 函数,如下所示:

import torch

# 加载模型
PATH = "model.pt"
model = torch.load(PATH)

要加载模型参数,可以先创建模型,并使用 load_state_dict() 方法将参数加载到模型中。如下所示:

import torch

# 定义模型
model = Net()

# 加载模型参数
PATH = "model_parameters.pth"
model.load_state_dict(torch.load(PATH))

这样,我们就可以在 PyTorch 中轻松地保存和加载模型了。