📌  相关文章
📜  torchvision.transforms - Python (1)

📅  最后修改于: 2023-12-03 15:20:40.001000             🧑  作者: Mango

torchivision.transforms - Python

torchvision.transforms是PyTorch中专门用于处理图像数据的模块。它提供了一系列用于数据增强、裁剪、变换等操作的函数和类,可以帮助我们更方便地对图像数据进行预处理。

安装

torchvision.transformstorchvision的子模块,安装torchvision即可使用。可以通过下面的命令安装:

pip install torchvision
常用函数

Compose

Composetorchvision.transforms中最常用的函数之一,它可以将一系列的变换操作组合在一起,用于组合多种预处理方法。下面是一个例子:

from torchvision import transforms

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])
                               ])

这段代码定义了一个transform对象,通过Compose函数将三个变换操作依次组合在一起,分别是将图像大小调整为$224\times224$,将图像转换为张量以便用于神经网络输入,并对每个通道进行归一化操作。

RandomCrop

RandomCrop可以对输入的图像进行随机的裁剪操作,可用于增加数据的多样性:

transform = transforms.RandomCrop(size=(224,224), padding=4)

上述代码对图像进行了$224\times224$的随机裁剪,padding参数是可选的,当需要在边缘填充数据时可以使用。

RandomHorizontalFlip

RandomHorizontalFlip可以随机地将图像水平翻转,用于增加数据的多样性:

transform = transforms.RandomHorizontalFlip(p=0.5)

上述代码以$p=0.5$的概率对图像进行水平翻转。

ColorJitter

ColorJitter可以对输入的图像进行颜色抖动操作,可用于增加数据的多样性:

transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)

上述代码对图像进行了亮度、对比度、饱和度、色相四个方面的随机变换,可用于模拟不同光照条件下的图像。

使用示例
from PIL import Image
from torchvision import transforms

# 定义变换
transform = transforms.Compose([transforms.Resize((256, 256)),
                                transforms.RandomCrop(size=(224,224), padding=4),
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])
                               ])

# 读取图像
img = Image.open('test.jpg')

# 应用变换
img_transformed = transform(img)

# 显示结果
import matplotlib.pyplot as plt

img_transformed = img_transformed.permute(1, 2, 0)
plt.imshow(img_transformed)
plt.show()

上述代码将读入一张图像,并对其应用上述的变换操作,最终展示结果图像。