PyTorch DataLoader to NumPy array - Python code example

Last modified: 2022-03-11 14:46:41.786000    Author: Mango

Code example 1
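# index selects which element of each batch you want (e.g. 0 = inputs, 1 = labels)
import numpy as np

def transfer_dataloader(dataloader, index=0):
    # Collect the chosen field from every batch, detached from autograd and moved to CPU
    arrays = []
    for batch in dataloader:
        assert index < len(batch), "index is out of range for the batch"
        arrays.append(batch[index].detach().cpu().numpy())
    assert len(arrays) > 0, "dataloader yielded no batches"
    # Concatenate along the batch axis; unlike np.array(...), this also
    # works when the final batch is smaller than the others
    return np.concatenate(arrays, axis=0)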
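A minimal usage sketch, assuming a hypothetical in-memory `TensorDataset` of 10 samples loaded with `batch_size=4`. The loader yields batches of 4, 4 and 2 samples, so the per-batch arrays do not share one shape; concatenating them along axis 0 rebuilds one NumPy array per field instead of trying to stack the batches into a single rectangular array.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, plus integer labels
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=False)

# Each batch is [inputs, labels]; the last batch holds only 2 samples
feature_batches = [x.numpy() for x, _ in loader]
label_batches = [y.numpy() for _, y in loader]

# Join the batches back together along the sample axis
all_features = np.concatenate(feature_batches, axis=0)
all_labels = np.concatenate(label_batches, axis=0)

print(all_features.shape)  # (10, 3)
print(all_labels.shape)    # (10,)
```

With `shuffle=False` the concatenated arrays keep the dataset's original order; with shuffling enabled they contain the same samples in a different order on each pass.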