📜  TypeError: default_collate: batch 必须包含张量、numpy 数组、数字、dicts 或列表;成立<class 'PIL.Image.Image'>- 打字稿(1)

📅  最后修改于: 2023-12-03 15:05:38.884000             🧑  作者: Mango

TypeError: default_collate: batch 必须包含张量、numpy 数组、数字、dicts 或列表;成立<class 'PIL.Image.Image'>- 打字稿

这个错误出现在 PyTorch 训练中,通常是由于使用 DataLoader 转换数据时输入了不支持的数据类型,具体是 PIL.Image.Image。

错误原因

DataLoader 是 PyTorch 提供的一个数据转换和加载的工具,用于将数据转换为张量形式。当使用 DataLoader 转换数据时,数据必须是支持的数据类型,包括:

  • 张量 (torch.Tensor)
  • numpy 数组 (numpy.ndarray)
  • 数字 (int,float)
  • 字典 (dict)
  • 列表 (list)

而 PIL.Image.Image 类型的数据并不在支持的数据类型之列。如果使用 default_collate 类转换 PIL.Image.Image 数据,就会出现上述错误。

解决方案

可以通过自定义 collate 函数来解决此问题。方法是,在定义 DataLoader 时,指定 collate_fn 参数,将 PIL.Image.Image 类型的数据转换成可以支持的数据类型(例如 ndarray)。

以下是一个示例 collate 函数,将 PIL.Image.Image 类型的数据转换成 numpy 数组:

import numpy as np
from PIL import Image

def collate_fn(batch):
    """
    batch 是一个列表,包含了每个样本的信息,每个元素是一个 tuple (image, label),
    其中 image 是 PIL.Image.Image 类型,label 是对应的标签。
    """
    images = []
    labels = []
    for image, label in batch:
        image = np.array(image)  # 将 PIL.Image.Image 转换成 numpy 数组
        images.append(image)
        labels.append(label)
    return np.array(images), np.array(labels)

在定义 DataLoader 时,将 collate_fn 参数设置为 collate_fn 函数即可:

from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_fn)
总结

使用 DataLoader 时,必须注意数据类型的支持。当使用 PIL.Image.Image 类型的数据时,可以通过自定义 collate 函数来将其转换成支持的数据类型,从而避免上述错误的发生。