📜  torch.utils.data.random_split(dataset, lengths) - Python (1)

📅  最后修改于: 2023-12-03 14:48:01.335000             🧑  作者: Mango

torch.utils.data.random_split(dataset, lengths) - Python

torch.utils.data.random_split(dataset, lengths) 是 PyTorch 中的一个函数,用于将一个数据集(Dataset)随机分成若干个子集。这一函数现在已在 PyTorch 的 1.7.0 版本中推出,可被广泛应用于深度学习项目中。

用法

这个函数带有两个参数:

  • dataset(必选参数):要拆分的数据集,类型为 PyTorch 的 Dataset 对象;
  • lengths(必选参数):自然数列表,用于定义每个分割出的子集的长度。比如,lengths=[2, 5] 就会将数据集分割成长度为 2 和 5 的两个子集。

这个函数将返回一个由分割好的子集组成的新的数据集对象。每个子集都是 PyTorch 中的 Subset 对象。

下面是一个使用案例:

from torch.utils.data import Dataset, DataLoader, random_split

class MyDataset(Dataset):
    def __init__(self):
        self.samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

dataset = MyDataset()

# 将数据集随机拆分成长度分别为 6 和 4 的两个子集
train_set, test_set = random_split(dataset, [6, 4])

print(f'train_set: {len(train_set)}')
print(f'test_set: {len(test_set)}')

输出:

train_set: 6
test_set: 4

我们可以看到,数据集成功被随机拆分成了长度分别为 6 和 4 的两个子集。这使得我们可以在训练和测试深度学习模型时更好地管理数据集。

总结

torch.utils.data.random_split(dataset, lengths) 是一个功能强大的函数,它可以帮助我们快速地将数据集拆分成若干个子集。它现在已在 PyTorch 1.7.0 版本中得以支持,我们可以用它进行深度学习项目开发。