栏目分类:
子分类:
返回
终身学习网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
终身学习网 > IT > 软件开发 > 后端开发 > Python

torch.gather()使用解析

Python 更新时间:发布时间: 百科书网 趣学号
前言

这个本来是没打算写的,因为看了官方的解释以及在网上看了好几个教程都没理解什么意思,所以把自己理解的东西整理分享一下。

官方的解释

官网链接:torch.gather()
给个截图如下

常用的参数有3个,第一个input表示要从中选取元素,第二个dim表示操作的维度,第三个index表示选取元素的索引。
按照官方的解释我是没看懂的,后面去找教程也一知半解,所以自己琢磨了一下,终于悟了。

使用详解

结合着例子,直接看代码把:

import torch

a = torch.arange(3, 12).view(3, 3)
print(a)
# tensor([[ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
# tensor([[9, 7, 5]])
# 1、将index中的各个元素的索引明确,获得具体坐标:
#    index = torch.tensor([[2, 1, 0]])中,
#    2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)
# 2、将具体坐标中对应的维度替换成index中的值:
#    2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
#    1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
#    0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0
# 3、按照新的坐标取输入中的值:
# tensor([[ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].

index = torch.tensor([[2, 1, 0]])
c = torch.gather(a, dim=1, index=index)
print(c)    # tensor([[5, 4, 3]])
# 1、获取具体坐标:(0,0),(0,1),(0,2)
# 2、第1维度替换坐标:(0,2),(0,1),(0,0)
# 3、找元素:[5,4,3]

# 二维的情况也一样
index = torch.tensor([[0, 2],
                      [1, 2]])
d = torch.gather(a, dim=1, index=index)
print(d)
# tensor([[3, 5],
#         [7, 8]])
# 1、获取具体坐标:(0,0),(0,1),(1,0),(1,1)
# 2、第1维度替换坐标:(0,0),(0,2),(1,1),(1,2)
# 3、找元素:[[3, 5],[7, 8]]

怕在代码里面太暗了看不清楚,在这里再贴一次:
以第一个为例:

创建张量
a = torch.arange(3, 12).view(3, 3)
print(a)
index = torch.tensor([[2, 1, 0]])
b = torch.gather(a, dim=0, index=index)
print(b)
a 的值如下:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
b 的值为tensor([[9, 7, 5]])
具体过程:

1、将index中的各个元素的索引明确,获得具体坐标:
index = torch.tensor([[2, 1, 0]])中,
2的索引(坐标)为(0,0),1的索引(坐标)为(0,1),0的索引(坐标)为(0,2)

2、将具体坐标中对应的维度替换成index中的值:
2的索引(坐标)为(0,0),将第0个维度的索引替换后的新坐标为(2, 0),用2替换掉0
1的索引(坐标)为(0,1),将第0个维度的索引替换后的新坐标为(1, 1),用1替换掉0
0的索引(坐标)为(0,2),将第0个维度的索引替换后的新坐标为(0, 2),用0替换掉0

3、按照新的坐标取输入中的值:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]]),坐标(2,0)值为9,坐标(1,1)值为7,坐标(0,2)值为5,得到最后的结果[9,7,5].

参考链接

图解PyTorch中的torch.gather函数
Pytorch系列(1):torch.gather()
pytorch之torch.gather方法

结束语

文章为分享、记录、整理自己的经历情况,水平有限,如有错误之处敬请指出。

转载请注明:文章转载自 www.051e.com
本文地址:http://www.051e.com/it/1033385.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 ©2023-2025 051e.com

ICP备案号:京ICP备12030808号