📅  最后修改于: 2023-12-03 15:06:40.068000             🧑  作者: Mango
在深度学习中,PyTorch 是一个受欢迎的框架,但它不能用于某些低功耗设备或片上系统。ONNX 是一种格式,它可以让模型在几乎所有设备上运行。因此,本文将介绍如何利用 PyTorch 将模型导出为 ONNX 格式。
要导出 PyTorch 模型,需要以下组件:
pip install onnx
)import torch
import torchvision
model = torchvision.models.resnet18(pretrained=True)
model.eval()
batch_size = 1
input_dims = (3, 224, 224)
input_data = torch.randn(batch_size, *input_dims)
import onnx
import onnx.utils
import torch.onnx
import os
# Export the PyTorch model to ONNX format
dummy_input = torch.randn(1, input_dims[0], input_dims[1], input_dims[2])
output_path = "resnet18.onnx"
# Use the "export" function to convert the PyTorch model to an ONNX model
torch.onnx.export(model, dummy_input, output_path, opset_version=11)
# Verify that the ONNX model is valid
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
注意,此处的 opset_version
参数是指使用的 ONNX 版本号。尽量使用最新版。
import onnxruntime
ort_session = onnxruntime.InferenceSession(output_path)
# Provide inputs as a list of NumPy arrays or ONNX Tensors
ort_inputs = {ort_session.get_inputs()[0].name: input_data.cpu().numpy()}
# Run the model
ort_outputs = ort_session.run(None, ort_inputs)
print(f"PyTorch model output: {model(input_data)}")
print(f"ONNX model output: {ort_outputs[0]}")
到此,我们就成功将 PyTorch 模型导出为 ONNX 模型,并成功在 ONNX 运行时中运行了它。
本文介绍了如何将 PyTorch 模型导出为 ONNX 格式。ONNX 可以让模型在几乎所有设备上运行,因此值得使用。