📜  如何获取 Pytorch 张量的数据类型?(1)

📅  最后修改于: 2023-12-03 15:09:14.531000             🧑  作者: Mango

如何获取 Pytorch 张量的数据类型?

在 Pytorch 中,我们可以使用 dtype 属性来获取一个张量的数据类型。具体示例如下:

import torch

# 定义一个 float 类型的张量
x = torch.tensor([1, 2, 3], dtype=torch.float32)

# 获取张量的数据类型
print(x.dtype)

运行上述代码将输出以下结果:

torch.float32

我们也可以通过 type() 函数来获取张量的数据类型:

# 获取张量的数据类型
print(type(x))

运行上述代码将输出以下结果:

<class 'torch.Tensor'>

上述代码中,torch.Tensor 是 Pytorch 中张量的基础类。

值得注意的是,当我们创建一个张量时,如果不明确指定数据类型,Pytorch 将会默认使用 torch.float32 类型。

如果我们想修改一个张量的数据类型,可以使用 to() 方法,并传入目标数据类型:

# 修改张量的数据类型为 int 类型
x = x.to(torch.int)

# 获取修改后的张量的数据类型
print(x.dtype)

运行上述代码将输出以下结果:

torch.int32

除了常见的数据类型,Pytorch 还支持一些特殊的数据类型,例如布尔类型 torch.bool、半精度浮点数类型 torch.float16 等。我们可以通过 Pytorch 提供的 官方文档 来获取完整的数据类型列表。