📅  最后修改于: 2023-12-03 14:48:01.335000             🧑  作者: Mango
torch.utils.data.random_split(dataset, lengths)
是 PyTorch 中的一个函数,用于将一个数据集(Dataset)随机分成若干个子集。这一函数现在已在 PyTorch 的 1.7.0 版本中推出,可被广泛应用于深度学习项目中。
这个函数带有两个参数:
dataset
(必选参数):要拆分的数据集,类型为 PyTorch 的 Dataset
对象;lengths
(必选参数):自然数列表,用于定义每个分割出的子集的长度。比如,lengths=[2, 5]
就会将数据集分割成长度为 2 和 5 的两个子集。这个函数将返回一个由分割好的子集组成的新的数据集对象。每个子集都是 PyTorch 中的 Subset
对象。
下面是一个使用案例:
from torch.utils.data import Dataset, DataLoader, random_split
class MyDataset(Dataset):
def __init__(self):
self.samples = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
def __getitem__(self, index):
return self.samples[index]
def __len__(self):
return len(self.samples)
dataset = MyDataset()
# 将数据集随机拆分成长度分别为 6 和 4 的两个子集
train_set, test_set = random_split(dataset, [6, 4])
print(f'train_set: {len(train_set)}')
print(f'test_set: {len(test_set)}')
输出:
train_set: 6
test_set: 4
我们可以看到,数据集成功被随机拆分成了长度分别为 6 和 4 的两个子集。这使得我们可以在训练和测试深度学习模型时更好地管理数据集。
torch.utils.data.random_split(dataset, lengths)
是一个功能强大的函数,它可以帮助我们快速地将数据集拆分成若干个子集。它现在已在 PyTorch 1.7.0 版本中得以支持,我们可以用它进行深度学习项目开发。