📜  以 ONNX 运行时格式导出 PyTorch 模型 - Python (1)

📅  最后修改于: 2023-12-03 15:06:40.068000             🧑  作者: Mango

以 ONNX 运行时格式导出 PyTorch 模型

在深度学习中,PyTorch 是一个受欢迎的框架,但它不能用于某些低功耗设备或片上系统。ONNX 是一种格式,它可以让模型在几乎所有设备上运行。因此,本文将介绍如何利用 PyTorch 将模型导出为 ONNX 格式。

准备工作

要导出 PyTorch 模型,需要以下组件:

  • PyTorch(安装方法视平台而定)
  • ONNX(可以使用以下命令进行安装:pip install onnx
步骤
  1. 将 PyTorch 模型加载到内存中。在此之前,需要确保 PyTorch 模型已经训练完毕,并且在 PyTorch 中已经定义。以下代码是一个样例:
import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
model.eval()
  1. 创建一个具有示例输入的虚拟张量。
batch_size = 1
input_dims = (3, 224, 224)
input_data = torch.randn(batch_size, *input_dims)
  1. 转换 PyTorch 模型为 ONNX 模型。以下代码实现了将 PyTorch 模型转换为 ONNX 模型:
import onnx
import onnx.utils
import torch.onnx
import os

# Export the PyTorch model to ONNX format
dummy_input = torch.randn(1, input_dims[0], input_dims[1], input_dims[2])
output_path = "resnet18.onnx"

# Use the "export" function to convert the PyTorch model to an ONNX model
torch.onnx.export(model, dummy_input, output_path, opset_version=11)

# Verify that the ONNX model is valid
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)

注意,此处的 opset_version 参数是指使用的 ONNX 版本号。尽量使用最新版。

  1. 测试 ONNX 模型。以下代码使用 ONNX 运行时测试 ONNX 模型:
import onnxruntime

ort_session = onnxruntime.InferenceSession(output_path)

# Provide inputs as a list of NumPy arrays or ONNX Tensors
ort_inputs = {ort_session.get_inputs()[0].name: input_data.cpu().numpy()}

# Run the model
ort_outputs = ort_session.run(None, ort_inputs)

print(f"PyTorch model output: {model(input_data)}")
print(f"ONNX model output: {ort_outputs[0]}")

到此,我们就成功将 PyTorch 模型导出为 ONNX 模型,并成功在 ONNX 运行时中运行了它。

结论

本文介绍了如何将 PyTorch 模型导出为 ONNX 格式。ONNX 可以让模型在几乎所有设备上运行,因此值得使用。