
这个本来是没打算写的,因为看了官方的解释以及在网上看了好几个教程都没理解什么意思,所以把自己理解的东西整理分享一下。
官方的解释官网链接:torch.gather()
给个截图如下
常用的参数有3个,第一个input表示要从中选取元素,第二个dim表示操作的维度,第三个index表示选取元素的索引。
按照官方的解释我是没看懂的,后面去找教程也一知半解,所以自己琢磨了一下,终于悟了。
结合着例子,直接看代码把:
import torch
a = torch.arange(3, 12).view(3, 3)
print(a)
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
# tensor([[9, 7, 5]])
# 1、将index中的各个元素的索引明确,获得具体坐标:
# index = torch.tensor([[2, 1, 0]])中,
# 2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)
# 2、将具体坐标中对应的维度替换成index中的值:
# 2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
# 1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
# 0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0
# 3、按照新的坐标取输入中的值:
# tensor([[ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].
index = torch.tensor([[2, 1, 0]])
c = torch.gather(a, dim=1, index=index)
print(c) # tensor([[5, 4, 3]])
# 1、获取具体坐标:(0,0),(0,1),(0,2)
# 2、第1维度替换坐标:(0,2),(0,1),(0,0)
# 3、找元素:[5,4,3]
# 二维的情况也一样
index = torch.tensor([[0, 2],
[1, 2]])
d = torch.gather(a, dim=1, index=index)
print(d)
# tensor([[3, 5],
# [7, 8]])
# 1、获取具体坐标:(0,0),(0,1),(1,0),(1,1)
# 2、第1维度替换坐标:(0,0),(0,2),(1,1),(1,2)
# 3、找元素:[[3, 5],[7, 8]]
怕在代码里面太暗了看不清楚,在这里再贴一次:
以第一个为例:
创建张量
a = torch.arange(3, 12).view(3, 3)
print(a)
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
a 的值如下:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
b 的值为tensor([[9, 7, 5]])
具体过程:
1、将index中的各个元素的索引明确,获得具体坐标:
index = torch.tensor([[2, 1, 0]])中,
2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)
2、将具体坐标中对应的维度替换成index中的值:
2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0
3、按照新的坐标取输入中的值:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].
图解PyTorch中的torch.gather函数
Pytorch系列(1):torch.gather()
pytorch之torch.gather方法
文章为分享、记录、整理自己的经历情况,水平有限,如有错误之处敬请指出。