在Python中将图像转换为 Torch 张量
在本文中,我们将了解如何将图像转换为 PyTorch 张量。 PyTorch 中的张量就像一个 NumPy 数组,包含相同 dtype 的元素。
张量可以是标量类型、一维或多维的。为了在 PyTorch 中将图像转换为张量,我们使用PILToTensor()和ToTensor()转换。这些转换在torchvision.transforms包中提供。使用这些转换,我们可以转换 PIL 图像或numpy.ndarray 。 numpy.ndarray必须是 [H, W, C] 格式,其中 H、W 和 C 是图像的高度、宽度和通道数。
transform = transforms.Compose([transforms.PILToTensor()])
tensor = transform(img)
此转换将PIL 图像转换为数据类型为torch.uint8的张量,范围在0 到 255之间。这里的img是一个 PIL 图像。
transform = transforms.Compose([transforms.ToTensor()])
tensor = transform(img)
此转换将任何numpy.ndarray转换为范围为 0 和 1的数据类型torch.float32的火炬张量。这里img是一个numpy.ndarray 。
方法:
- 导入所需的库。
- 读取输入图像。输入图像是 PIL 图像或 NumPy N 维数组。
- 定义将图像转换为 Torch 张量的变换。我们使用transforms.Compose()定义一个变换。您可以直接使用transforms.PILToTensor()或transforms.ToTensor() 。
- 使用上面定义的变换将图像转换为张量。
- 打印张量值。
下图在两个示例中都用作输入图像:
示例 1:
在下面的示例中,我们将 PIL 图像转换为 Torch 张量。
Python3
# Import necessary libraries
import torch
from PIL import Image
import torchvision.transforms as transforms
# Read a PIL image
image = Image.open('iceland.jpg')
# Define a transform to convert PIL
# image to a Torch tensor
transform = transforms.Compose([
transforms.PILToTensor()
])
# transform = transforms.PILToTensor()
# Convert the PIL image to Torch tensor
img_tensor = transform(image)
# print the converted Torch tensor
print(img_tensor)
Python3
# Import required libraries
import torch
import cv2
import torchvision.transforms as transforms
# Read the image
image = cv2.imread('iceland.jpg')
# Convert BGR image to RGB image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Define a transform to convert
# the image to torch tensor
transform = transforms.Compose([
transforms.ToTensor()
])
# Convert the image to Torch tensor
tensor = transform(image)
# print the converted image tensor
print(tensor)
输出:
请注意,输出张量的数据类型是torch.uint8并且值在[0,255]范围内。
示例 2:
在此示例中,我们使用OpenCV读取 RGB 图像。使用 OpenCV 读取的图像类型是numpy.ndarray 。我们使用变换ToTensor()将其转换为火炬张量。
Python3
# Import required libraries
import torch
import cv2
import torchvision.transforms as transforms
# Read the image
image = cv2.imread('iceland.jpg')
# Convert BGR image to RGB image
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Define a transform to convert
# the image to torch tensor
transform = transforms.Compose([
transforms.ToTensor()
])
# Convert the image to Torch tensor
tensor = transform(image)
# print the converted image tensor
print(tensor)
输出:
请注意,输出张量的数据类型是torch.float32 ,值在[0, 1] 范围内。