官方类型如下:
方法一
import torch
my_tensor = torch.randn(2, 4) # 默认为float32类型
print("原始:", my_tensor)
print('____________________________________________________________')
# torch.long() 将tensor投射为long类型
long_tensor = my_tensor.long()
print(long_tensor)
print('____________________________________________________________')
# torch.half()将tensor投射为半精度浮点类型
half_tensor = my_tensor.half()
print(half_tensor)
print('____________________________________________________________')
# torch.int()将该tensor投射为int类型
int_tensor = my_tensor.int()
print(int_tensor)
print('____________________________________________________________')
# torch.double()将该tensor投射为double类型
double_tensor = my_tensor.double()
print(double_tensor)
print('____________________________________________________________')
# torch.float()将该tensor投射为float类型
float_tensor = my_tensor.float()
print(float_tensor)
print('____________________________________________________________')
# torch.char()将该tensor投射为char类型
char_tensor = my_tensor.char()
print(char_tensor)
print('____________________________________________________________')
# torch.byte()将该tensor投射为byte类型
byte_tensor = my_tensor.byte()
print(byte_tensor)
print('____________________________________________________________')
# torch.short()将该tensor投射为short类型
short_tensor = my_tensor.short()
print(short_tensor)
print('____________________________________________________________')
方法二
my_tensor = torch.randn(2, 4) # 默认为float32类型
my_tensor.type(torch.float16)
print('____________________________________________________________')
print(my_tensor.type(torch.float16))
print('____________________________________________________________')
print(my_tensor.type(torch.float32))
print('____________________________________________________________')
print(my_tensor.type(torch.int32))
print('____________________________________________________________')
print(my_tensor.type(torch.long))
print('____________________________________________________________')