📜  如何在 pytorch 中加载预训练模型 - Python (1)

📅  最后修改于: 2023-12-03 15:24:20.495000             🧑  作者: Mango

如何在 PyTorch 中加载预训练模型

PyTorch 是一个广泛使用的深度学习框架,它内置了许多已经预训练好的模型,比如 VGG,ResNet,Inception 等等。在实际应用中,经常需要使用这些预训练好的模型进行迁移学习或者 fine-tune。那么,如何在 PyTorch 中加载预训练模型呢?

1. 下载预训练模型

PyTorch 官方提供了许多已经预训练好的模型,这些模型可以直接使用,也可以继续在其基础上进行微调。在 PyTorch 官网上可以找到这些模型的下载链接。以 ResNet50 为例,模型参数下载链接为:https://download.pytorch.org/models/resnet50-19c8e357.pth。

2. 导入 PyTorch 库

在加载预训练模型前,需要先导入 PyTorch 库。

import torch
import torchvision.models as models
3. 加载预训练模型

在 PyTorch 中加载预训练模型通常包括两个步骤:首先创建一个空的模型实例,然后将下载的预训练模型参数加载到该实例中。这里以 ResNet50 为例进行说明。

# 创建一个空的 ResNet50 模型实例
model = models.resnet50(pretrained=False)

# 加载下载的预训练模型参数
model.load_state_dict(torch.load('resnet50-19c8e357.pth'))

需要注意的是,当 pretrained 参数为 True 时,resnet50 函数会自动从 PyTorch 官方服务器下载 ResNet50 的预训练模型参数。

4. 使用预训练模型进行推断

加载预训练模型后,我们可以使用该模型进行推断。以 ResNet50 为例,我们可以使用以下代码进行推断:

import torch.nn.functional as F

# 将模型设置为评估模式
model.eval()

# 加载一张图片
image = torch.randn(1, 3, 224, 224)

# 使用模型进行推断
output = model(image)
output = F.softmax(output, dim=1)

# 打印预测结果
print(output)

如果需要使用 GPU 进行加速,可以将模型和数据移动到 GPU 上进行计算:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 将数据和模型移动到 GPU 上
image = image.to(device)
model.to(device)

# 使用模型进行推断
output = model(image)
output = F.softmax(output, dim=1)

# 将数据移回 CPU,并打印预测结果
output = output.cpu().detach().numpy()
print(output)

至此,在 PyTorch 中加载预训练模型的方法就介绍完了。通过本文的介绍,我们了解了如何下载预训练模型、如何在 PyTorch 中加载预训练模型,以及如何使用预训练模型进行推断。