📜  UserWarning:给定的 NumPy 数组不可写,PyTorch 不支持不可写张量. (1)

📅  最后修改于: 2023-12-03 15:35:32.541000             🧑  作者: Mango

PyTorch不支持不可写张量

在使用PyTorch时,您可能会看到以下警告信息:

UserWarning: Given numpy array is not writeable, PyTorch doesn't support non-writeable tensors.

这个警告信息通常意味着您正在尝试将一个不可写的NumPy array转换为PyTorch张量。由于PyTorch张量需要可写性,因此不可写的张量会导致错误。

什么是NumPy array?

NumPy是Python中广泛使用的一个科学计算库。它允许您使用高效的数组和矩阵操作来进行数字计算、数据分析、机器学习等任务。NumPy中最重要的对象是多维数组对象(即ndarray),它是一种可变大小的数组,也是PyTorch中张量的底层实现。

为什么PyTorch需要可写性?

PyTorch张量实际上是计算图的一部分。当您创建一个张量时,PyTorch会在计算图中为该张量分配内存,并且该内存只能被该张量引用。该张量的值可以随着计算图的运行而改变,因此必须保证该值是可写的。

如何避免该警告信息?

要避免该警告信息,您需要确保将可写的张量传递给PyTorch,而不是不可写的NumPy数组。

一种简单的方法是使用torch.from_numpy()函数来将NumPy数组转换为张量。例如,以下代码将一个可写的NumPy数组转换为张量:

import numpy as np
import torch

# Create a writable NumPy array
arr = np.zeros((3, 3))
arr.flags.writeable = True  # Set the array to be writable

# Convert the NumPy array to a PyTorch tensor
tensor = torch.from_numpy(arr)

另一种方法是确保使用NumPy数组的可写副本。例如,以下代码将使用copy()函数创建NumPy数组的可写副本:

import numpy as np
import torch

# Create a non-writable NumPy array
arr = np.zeros((3, 3))
arr.flags.writeable = False

# Convert the NumPy array to a writable copy
writable_arr = arr.copy()
writable_arr.flags.writeable = True

# Convert the writable copy to a PyTorch tensor
tensor = torch.from_numpy(writable_arr)

无论您使用哪种方法,都应该避免在使用PyTorch时使用不可写的NumPy数组。

总之,PyTorch不支持不可写的 NumPy 数组。对于在PyTorch中使用NumPy数组,您需要确保始终使用可写的数组。