📅  最后修改于: 2023-12-03 15:09:14.531000             🧑  作者: Mango
在 Pytorch 中,我们可以使用 dtype
属性来获取一个张量的数据类型。具体示例如下:
import torch
# 定义一个 float 类型的张量
x = torch.tensor([1, 2, 3], dtype=torch.float32)
# 获取张量的数据类型
print(x.dtype)
运行上述代码将输出以下结果:
torch.float32
我们也可以通过 type()
函数来获取张量的数据类型:
# 获取张量的数据类型
print(type(x))
运行上述代码将输出以下结果:
<class 'torch.Tensor'>
上述代码中,torch.Tensor
是 Pytorch 中张量的基础类。
值得注意的是,当我们创建一个张量时,如果不明确指定数据类型,Pytorch 将会默认使用 torch.float32
类型。
如果我们想修改一个张量的数据类型,可以使用 to()
方法,并传入目标数据类型:
# 修改张量的数据类型为 int 类型
x = x.to(torch.int)
# 获取修改后的张量的数据类型
print(x.dtype)
运行上述代码将输出以下结果:
torch.int32
除了常见的数据类型,Pytorch 还支持一些特殊的数据类型,例如布尔类型 torch.bool
、半精度浮点数类型 torch.float16
等。我们可以通过 Pytorch 提供的 官方文档 来获取完整的数据类型列表。