该函数的作用为:收集指定索引位置的值。
先将函数原型写出:
torch.gather(input, dim, index, out=None) → Tensor
参数:
- input (Tensor) – 源张量
- dim (int) – 索引的轴
- index (LongTensor) – 聚合元素的下标
- out (Tensor, optional) – 目标张量
首先来用自己的语言解释该函数的各个参数。
第一个参数就是自己要收集值的Tensor,不用多解释。
第二个参数就是指你要收集值的轴(也可以理解为行或列),如果是0,则按照横轴收集。如果是1,则按照纵轴收集。
第三个参数就是对应于你要搜集的Tensor的下标。
第四个参数一般缺省,这里不做细致的讨论。
下面直接上例子:
首先,我们先创建一个2维的Tensor。
然后我们就采用torch.gather()函数来取这个Tensor里面的值。
这里我们的dim取0,就是按照行进行取值,后面我们跟上一个Tensor,注意这里的Tensor一定要和我们前面的test Tensor的维数相同,要不然会报错。【你当然也可以理解,从一个二维的Tensor取值,你当然也是进入一个二维的Tensor索引】
这里就是从上往下看,第一个取的索引值为0的值就是1,第二个取索引值为0的值就是2.【也就是第一列取了第一个,第二列取了第一个】
我们改成其他的值也可以很明显的看出来。
然后我们来看一下取得索引多一点:
这样行的就理解的差不多了,然后我们把那个dim改成1进行观察。
这里我们进行解释一下,首先,我们还是那个tensor。
这里我们将dim改成1,当Tensor[[0,1]]时,第一个取的还是我们test Tensor中第一个值1,第二个就是索引值为1的2.
当tensor[[0,0],[0,0]]也是同理。
但是这里我们发现了不同,我们写入3个值时,却无法正常的取值。而上面的当dim=0时却可以正常的取值。那是因为这里我们是按照列进行取值的,我们这里的列数只有两个,所以当我们写第三个值的时候,就是取第三行的值,因为我们的test内的Tensor只有两行,就出现了错误。而与按行取的不同,每一次取的都是按照行索引值进行取值,所以,无论我们写多少个也不会报错。