📜  加载保存的模型 - Python (1)

📅  最后修改于: 2023-12-03 15:36:59.635000             🧑  作者: Mango

加载保存的模型 - Python

当我们训练完一个模型并保存好了它的参数,我们可以加载这个模型并使用它来进行预测或继续训练。在本文中,我们将介绍如何在 Python 中加载已保存的模型。

加载模型

要加载保存的模型,我们可以使用以下代码:

import torch.nn as nn
import torch.optim as optim

# 创建模型
model = nn.Sequential(
    nn.Linear(10, 100),
    nn.ReLU(),
    nn.Linear(100, 1)
)

# 定义优化器和损失函数
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()

# 加载已保存的模型
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# 设置为评估模式
model.eval()

在这个例子中,我们首先定义了一个简单的模型,然后定义了优化器和损失函数。接着,我们使用 torch.load() 函数加载保存的参数,并使用 load_state_dict() 函数将参数加载到模型和优化器中。最后,我们可以将模型设置为评估模式,以便进行预测。

保存模型

如果我们想保存训练好的模型,我们可以使用以下代码:

# 训练模型

# 定义需要保存的内容
state = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}

# 保存模型
torch.save(state, 'model.pth')

在这个例子中,我们首先定义了需要保存的内容,包括当前的训练轮数、模型和优化器的参数以及当前的损失。然后,我们使用 torch.save() 函数将这些内容保存到一个文件中。

总之,加载已保存的模型十分简单,只需使用 torch.load()load_state_dict() 函数即可。保存模型也很容易,只需将需要保存的内容打包成一个字典并使用 torch.save() 函数保存即可。