📅  最后修改于: 2023-12-03 15:20:40.001000             🧑  作者: Mango
torchvision.transforms
是PyTorch中专门用于处理图像数据的模块。它提供了一系列用于数据增强、裁剪、变换等操作的函数和类,可以帮助我们更方便地对图像数据进行预处理。
torchvision.transforms
是torchvision
的子模块,安装torchvision
即可使用。可以通过下面的命令安装:
pip install torchvision
Compose
Compose
是torchvision.transforms
中最常用的函数之一,它可以将一系列的变换操作组合在一起,用于组合多种预处理方法。下面是一个例子:
from torchvision import transforms
transform = transforms.Compose([transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
这段代码定义了一个transform
对象,通过Compose
函数将三个变换操作依次组合在一起,分别是将图像大小调整为$224\times224$,将图像转换为张量以便用于神经网络输入,并对每个通道进行归一化操作。
RandomCrop
RandomCrop
可以对输入的图像进行随机的裁剪操作,可用于增加数据的多样性:
transform = transforms.RandomCrop(size=(224,224), padding=4)
上述代码对图像进行了$224\times224$的随机裁剪,padding参数是可选的,当需要在边缘填充数据时可以使用。
RandomHorizontalFlip
RandomHorizontalFlip
可以随机地将图像水平翻转,用于增加数据的多样性:
transform = transforms.RandomHorizontalFlip(p=0.5)
上述代码以$p=0.5$的概率对图像进行水平翻转。
ColorJitter
ColorJitter
可以对输入的图像进行颜色抖动操作,可用于增加数据的多样性:
transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
上述代码对图像进行了亮度、对比度、饱和度、色相四个方面的随机变换,可用于模拟不同光照条件下的图像。
from PIL import Image
from torchvision import transforms
# 定义变换
transform = transforms.Compose([transforms.Resize((256, 256)),
transforms.RandomCrop(size=(224,224), padding=4),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 读取图像
img = Image.open('test.jpg')
# 应用变换
img_transformed = transform(img)
# 显示结果
import matplotlib.pyplot as plt
img_transformed = img_transformed.permute(1, 2, 0)
plt.imshow(img_transformed)
plt.show()
上述代码将读入一张图像,并对其应用上述的变换操作,最终展示结果图像。