np.argmax&torch.max()对比

2022-10-08 14:07:27

argmax函数


通俗来说:在axis的增长方向上求最大值

np.argmax()

import numpyas np
a= np.array([[[1,5,5,2],[9,-6,2,8],[-3,7,-9,1]],[[-1,7,-5,2],[9,6,2,8],[3,7,9,1]],[[21,6,-5,2],[9,36,2,8],[3,7,79,1]]])
b=np.argmax(a, axis=0)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=0时,是在a[0]方向上找最大值,即三个矩阵做比较,具体[1,5,5,2],[-1,7,-5,2],[21,6,-5,2],
#一共有三个,所以最终得到的结果b就为34列矩阵print(b)[[2100][0200][1020]]
 
c=np.argmax(a, axis=1)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=1时,是在a[1]方向上找最大值,即在列方向比较,此时就是指在每个矩阵内部的列方向上进行比较[1,5,5,2],[9,-6,2,8],[-3,7,-9,1]
#一共有三个,所以最终得到的结果b就为34列矩阵print(c)[[1201][1021][0121]]
 
d=np.argmax(a, axis=2)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=2时,是在a[2]方向上找最大值,即在行方向比较,此时就是指在每个矩阵内部的行方向上进行比较[1,5,5,2],[9,-6,2,8],[-3,7,-9,1]
#寻找第一行的最大值,可以看出第一行[1,5,5,2]最大值为5,,索引值为1print(d)[[101][102][012]]
##################################################################
# 第一个矩阵,取最后一行的所有列
m=np.argmax(a[0,-1,:])print(m)
#1
 
# 第二个矩阵,取第三行的所有列
h=np.argmax(a[1,2,:])print(h)
#2

# 第二个矩阵,取所有行的第三列
g=np.argmax(a[1,:,2])print(g)
#2
import numpyas npimport numpyas np

arrary= np.array([[[1,5,5,2],[9,-6,2,8],[-3,7,-9,1]],[[-1,7,-5,2],[9,6,2,8],[3,7,9,1]],[[21,6,-5,2],[9,36,2,8],[3,7,79,1]]])print(arrary.shape)

a= np.argmax(arrary, axis=0)
b= np.argmax(arrary, axis=1)
c= np.argmax(arrary, axis=2)print('argmax axis = 0 is ', a)print('argmax axis = 1 is ', b)print('argmax axis = 2 is ', c)import torch
d= torch.from_numpy(arrary)print(arrary)
d= torch.argmax(d, dim=-0)print('torch argmax is :', d)
torch.argmax(arrary, dim=-1)
#dim可以有-3-2-1012
torch.max(a,0) 返回每一列中最大值的那个元素,且返回其索引(返回最大元素在这一列的行索引)
axis=0表示以行的维度为基准,行上的所有数据所在列上的最大值,通俗来说:在axis的增长方向上求最大值
torch.max(a,1) 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
axis=1表示以列的维度为基准,列上的所有数据所在行上的最大值,通俗来说:在axis的增长方向上求最大值
  • 作者:Deeachain
  • 原文链接:https://blog.csdn.net/Deeachain/article/details/106981467
    更新时间:2022-10-08 14:07:27