📅  最后修改于: 2023-12-03 15:24:20.495000             🧑  作者: Mango
PyTorch 是一个广泛使用的深度学习框架,它内置了许多已经预训练好的模型,比如 VGG,ResNet,Inception 等等。在实际应用中,经常需要使用这些预训练好的模型进行迁移学习或者 fine-tune。那么,如何在 PyTorch 中加载预训练模型呢?
PyTorch 官方提供了许多已经预训练好的模型,这些模型可以直接使用,也可以继续在其基础上进行微调。在 PyTorch 官网上可以找到这些模型的下载链接。以 ResNet50 为例,模型参数下载链接为:https://download.pytorch.org/models/resnet50-19c8e357.pth。
在加载预训练模型前,需要先导入 PyTorch 库。
import torch
import torchvision.models as models
在 PyTorch 中加载预训练模型通常包括两个步骤:首先创建一个空的模型实例,然后将下载的预训练模型参数加载到该实例中。这里以 ResNet50 为例进行说明。
# 创建一个空的 ResNet50 模型实例
model = models.resnet50(pretrained=False)
# 加载下载的预训练模型参数
model.load_state_dict(torch.load('resnet50-19c8e357.pth'))
需要注意的是,当 pretrained
参数为 True
时,resnet50
函数会自动从 PyTorch 官方服务器下载 ResNet50 的预训练模型参数。
加载预训练模型后,我们可以使用该模型进行推断。以 ResNet50 为例,我们可以使用以下代码进行推断:
import torch.nn.functional as F
# 将模型设置为评估模式
model.eval()
# 加载一张图片
image = torch.randn(1, 3, 224, 224)
# 使用模型进行推断
output = model(image)
output = F.softmax(output, dim=1)
# 打印预测结果
print(output)
如果需要使用 GPU 进行加速,可以将模型和数据移动到 GPU 上进行计算:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 将数据和模型移动到 GPU 上
image = image.to(device)
model.to(device)
# 使用模型进行推断
output = model(image)
output = F.softmax(output, dim=1)
# 将数据移回 CPU,并打印预测结果
output = output.cpu().detach().numpy()
print(output)
至此,在 PyTorch 中加载预训练模型的方法就介绍完了。通过本文的介绍,我们了解了如何下载预训练模型、如何在 PyTorch 中加载预训练模型,以及如何使用预训练模型进行推断。