📅  最后修改于: 2023-12-03 15:36:59.635000             🧑  作者: Mango
当我们训练完一个模型并保存好了它的参数,我们可以加载这个模型并使用它来进行预测或继续训练。在本文中,我们将介绍如何在 Python 中加载已保存的模型。
要加载保存的模型,我们可以使用以下代码:
import torch.nn as nn
import torch.optim as optim
# 创建模型
model = nn.Sequential(
nn.Linear(10, 100),
nn.ReLU(),
nn.Linear(100, 1)
)
# 定义优化器和损失函数
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()
# 加载已保存的模型
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# 设置为评估模式
model.eval()
在这个例子中,我们首先定义了一个简单的模型,然后定义了优化器和损失函数。接着,我们使用 torch.load()
函数加载保存的参数,并使用 load_state_dict()
函数将参数加载到模型和优化器中。最后,我们可以将模型设置为评估模式,以便进行预测。
如果我们想保存训练好的模型,我们可以使用以下代码:
# 训练模型
# 定义需要保存的内容
state = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
# 保存模型
torch.save(state, 'model.pth')
在这个例子中,我们首先定义了需要保存的内容,包括当前的训练轮数、模型和优化器的参数以及当前的损失。然后,我们使用 torch.save()
函数将这些内容保存到一个文件中。
总之,加载已保存的模型十分简单,只需使用 torch.load()
和 load_state_dict()
函数即可。保存模型也很容易,只需将需要保存的内容打包成一个字典并使用 torch.save()
函数保存即可。