📜  如何在 PyTorch 中标准化图像?

📅  最后修改于: 2022-05-13 01:54:33.261000             🧑  作者: Mango

如何在 PyTorch 中标准化图像?

图像变换是将图像像素的原始值改变为一组新值的过程。我们对图像进行的一种转换是将图像转换为 PyTorch 张量。当图像转换为 PyTorch 张量时,像素值在 0.0 和 1.0 之间缩放。在 PyTorch 中,可以使用torchvision.transforms.ToTensor()完成这种转换。它将像素范围为 [0, 255] 的 PIL 图像转换为范围为 [0.0, 1.0] 的形状为 (C, H, W) 的 PyTorch FloatTensor。

当我们使用深度神经网络时,图像的归一化是一个非常好的做法。对图像进行归一化意味着将图像转换为图像的均值和标准差分别变为 0.0 和 1.0 的值。为此,首先从每个输入通道中减去通道均值,然后将结果除以通道标准偏差。

output[channel] = (input[channel] - mean[channel]) / std[channel]

为什么我们要标准化图像?

归一化有助于获得一定范围内的数据并减少偏度,这有助于更快更好地学习。归一化还可以解决梯度下降和爆炸的问题。

在 PyTorch 中标准化图像

PyTorch 中的标准化是使用torchvision.transforms.Normalize() 完成的。这使用均值和标准差对张量图像进行归一化。

方法:

我们将在 PyTorch 中规范化图像时执行以下步骤:

  • 加载和可视化图像并绘制像素值。
  • 使用torchvision.transforms.ToTensor()将图像转换为张量
  • 计算平均值和标准偏差 (std)
  • 使用torchvision.transforms.Normalize()规范化图像。
  • 可视化标准化图像。
  • 归一化后计算均值和标准差并验证它们。

示例:加载图像

输入图像:

使用 PIL 加载上述输入图像。我们在我们的程序中使用了上面的 Koala.jpg 图像。并绘制图像的像素值。

Python3
# python code to load and visualize 
# an image
  
# import necessary libraries
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
  
# load the image
img_path = 'Koala.jpg'
img = Image.open(img_path)
  
# convert PIL image to numpy array
img_np = np.array(img)
  
# plot the pixel values
plt.hist(img_np.ravel(), bins=50, density=True)
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels")


Python3
# Python code for converting PIL Image to
# PyTorch Tensor image and plot pixel values
  
# import necessary libraries
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
  
# define custom transform function
transform = transforms.Compose([
    transforms.ToTensor()
])
  
# transform the pIL image to tensor 
# image
img_tr = transform(img)
  
# Convert tensor image to numpy array
img_np = np.array(img_tr)
  
# plot the pixel values
plt.hist(img_np.ravel(), bins=50, density=True)
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels")


Python3
# Python code to calculate mean and std
# of image
  
# get tensor image
img_tr = transform(img)
  
# calculate mean and std
mean, std = img_tr.mean([1,2]), img_tr.std([1,2])
  
# print mean and std
print("mean and std before normalize:")
print("Mean of the image:", mean)
print("Std of the image:", std)


Python3
# python code to normalize the image
  
  
from torchvision import transforms
  
# define custom transform
# here we are using our calculated
# mean & std
transform_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
  
# get normalized image
img_normalized = transform_norm(img)
  
# convert normalized image to numpy
# array
img_np = np.array(img_normalized)
  
# plot the pixel values
plt.hist(img_np.ravel(), bins=50, density=True)
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels")


Python3
# Python Code to visualize normalized image
  
# get normalized image
img_normalized = transform_norm(img)
  
# convert tis image to numpy array
img_normalized = np.array(img_normalized)
  
# transpose from shape of (3,,) to shape of (,,3)
img_normalized = img_normalized.transpose(1, 2, 0)
  
# display the normalized image
plt.imshow(img_normalized)
plt.xticks([])
plt.yticks([])


Python3
# Python code to calculate mean and std
# of normalized image
  
# get normalized image
img_nor = transform_norm(img)
  
# cailculate mean and std
mean, std = img_nor.mean([1,2]), img_nor.std([1,2])
  
# print mean and std
print("Mean and Std of normalized image:")
print("Mean of the image:", mean)
print("Std of the image:", std)


输出:

我们发现 RGB 图像的像素值范围从 0 到 255。

使用torchvision.transforms.ToTensor()将图像转换为张量

使用ToTensor()将 PIL 图像转换为 PyTorch 张量并绘制此张量图像的像素值。我们定义了我们的变换函数来将 PIL 图像转换为 PyTorch 张量图像。

蟒蛇3

# Python code for converting PIL Image to
# PyTorch Tensor image and plot pixel values
  
# import necessary libraries
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
  
# define custom transform function
transform = transforms.Compose([
    transforms.ToTensor()
])
  
# transform the pIL image to tensor 
# image
img_tr = transform(img)
  
# Convert tensor image to numpy array
img_np = np.array(img_tr)
  
# plot the pixel values
plt.hist(img_np.ravel(), bins=50, density=True)
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels")

输出:

我们发现张量图像的像素值范围从 0.0 到 1.0。我们注意到 RBG 和张量图像的像素分布看起来相同,但像素值范围不同。

计算平均值和标准偏差 (std)

我们计算图像的均值和标准差。

蟒蛇3



# Python code to calculate mean and std
# of image
  
# get tensor image
img_tr = transform(img)
  
# calculate mean and std
mean, std = img_tr.mean([1,2]), img_tr.std([1,2])
  
# print mean and std
print("mean and std before normalize:")
print("Mean of the image:", mean)
print("Std of the image:", std)

输出:

在这里,我们计算了所有三个通道红色、绿色和蓝色的图像的均值和标准差。这些值在归一化之前。我们将使用这些值来规范化图像。我们将这些值与归一化后的值进行比较。

使用torchvision.transforms.Normalize()规范化图像

为了对图像进行归一化,这里我们使用上面计算出的图像的均值和标准差。如果图像与 ImageNet 图像相似,我们也可以使用 ImageNet 数据集的均值和标准差。 ImageNet 的均值和标准差为:均值 = [0.485, 0.456, 0.406] 和标准差 = [0.229, 0.224, 0.225]。如果图像与 ImageNet 不相似,比如医学图像,那么总是建议计算数据集的均值和标准差,并使用它们来规范化图像。

蟒蛇3

# python code to normalize the image
  
  
from torchvision import transforms
  
# define custom transform
# here we are using our calculated
# mean & std
transform_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
  
# get normalized image
img_normalized = transform_norm(img)
  
# convert normalized image to numpy
# array
img_np = np.array(img_normalized)
  
# plot the pixel values
plt.hist(img_np.ravel(), bins=50, density=True)
plt.xlabel("pixel values")
plt.ylabel("relative frequency")
plt.title("distribution of pixels")

输出:

我们已经使用我们计算的平均值和标准对图像进行了标准化。上面的输出显示了归一化图像的像素值分布。我们可以注意到张量图像(归一化之前)和归一化图像的像素分布之间的差异。

可视化标准化图像

现在可视化标准化图像。



蟒蛇3

# Python Code to visualize normalized image
  
# get normalized image
img_normalized = transform_norm(img)
  
# convert tis image to numpy array
img_normalized = np.array(img_normalized)
  
# transpose from shape of (3,,) to shape of (,,3)
img_normalized = img_normalized.transpose(1, 2, 0)
  
# display the normalized image
plt.imshow(img_normalized)
plt.xticks([])
plt.yticks([])

输出:

我们可以注意到,输入图像和归一化图像之间存在明显差异。

归一化后计算均值和标准差

我们再次计算归一化图像/数据集的均值和标准差。现在归一化后,均值应为 0.0,标准值为 1.0。

蟒蛇3

# Python code to calculate mean and std
# of normalized image
  
# get normalized image
img_nor = transform_norm(img)
  
# cailculate mean and std
mean, std = img_nor.mean([1,2]), img_nor.std([1,2])
  
# print mean and std
print("Mean and Std of normalized image:")
print("Mean of the image:", mean)
print("Std of the image:", std)

输出:

这里我们发现归一化后均值和标准差的值分别为 0.0 和 1.0。这验证了归一化后图像均值和标准差分别变为 0 和 1。