numpy模块中axis的理解——以np.argmax为例

2022-10-08 10:09:56

numpy模块中axis的理解——以np.argmax为例

np.argmax参数数量及其作用

np.argmax是用于取得数组中每一行或者每一列的的最大值。常用于机器学习中获取分类结果、计算精确度等。
函数如下:

np.argmax(
	a, 
	axis=None, 
	out=None)

部分参数解释:
a:输入矩阵;
axis:对于二维向量而言,0代表对行进行最大值选取,此时对每一列进行操作;1代表对列进行最大值选取,此时对每一行进行操作。三维向量的情况更为复杂,需要结合例子说明。实际上axis的大小代表着进入到第axis+1个[ ]内,对其剩余的部分进行对比;
out:可以指定输出矩阵的变量

axis不同情况的示例

代码较长,许多都是注释,请大家耐心观看

import numpyas np# 一维向量测试# 取出x中元素最大值所对应的索引# 此时最大值为11,其对应的位置索引值为11
x= np.arange(12)
index= np.argmax(x)print("1 dimension test:",index)# 二维向量测试# 0代表对行进行最大值选取,此时对每一列进行操作
x= np.arange(12).reshape(3,4)
index= np.argmax(x,axis=0)# 结果为[2 2 2 2]print("2 dimension test, axis = 0:",index)# 二维向量测试# 1代表对列进行最大值选取,此时对每一行进行操作
x= np.arange(12).reshape(3,4)
index= np.argmax(x,axis=1)# 结果为[3 3 3]print("2 dimension test, axis = 1:",index)# 三维向量测试# 0代表进入第一个[]内进行对比
x= np.arange(24).reshape(2,3,4)
x[1,0,3]=1# x =# [[[ 0  1  2  3]#   [ 4  5  6  7]#   [ 8  9 10 11]]#  [[12 13 14  1]#   [16 17 18 19]#   [20 21 22 23]]]
index= np.argmax(x,axis=0)print("3 dimension test, axis = 0:",index)# 当axis=0时,进入第一个[]内进行对比,此时x剩下两部分。#  [[ 0  1  2  3]#   [ 4  5  6  7]#   [ 8  9 10 11]]#  [[12 13 14  1]#   [16 17 18 19]#   [20 21 22 23]]# 两部分格式相同,将剩下的两部分每一个单位进行对比,对比结果为#  [[1  1  1  0]#   [1  1  1  1]#   [1  1  1  1]]# 除去我设置的特殊位置外,其他位置均为第二部分大。# 三维向量测试# 1代表进入第二个[]内进行对比# x =# [[[ 0  1  2  3]#   [ 4  5  6  7]#   [ 8  9 10 11]]#  [[12 13 14  1]#   [16 17 18 19]#   [20 21 22 23]]]
index= np.argmax(x,axis=1)print("3 dimension test, axis = 1:",index)# 当axis=1时,进入第二个[]内进行对比。# [ [ 0  1  2  3]#   [ 4  5  6  7]#   [ 8  9 10 11]#   [12 13 14  1]#   [16 17 18 19]#   [20 21 22 23] ]# 对于第二个[]内的内容而言,均剩下三部分,我特意将两个第二个[]内的内容分开更容易辨认# 第一个是#   [ 0  1  2  3]#   [ 4  5  6  7]#   [ 8  9 10 11]# 第二个是#   [12 13 14  1]#   [16 17 18 19]#   [20 21 22 23]# 都是第三行的值最大,所以输出结果为#  [[ 2  2  2  2]#   [ 2  2  2  2]]# 三维向量测试# 2代表进入第三个[]内进行对比
x= np.arange(24).reshape(2,3,4)
x[1,0,3]=1# x =# [[[ 0  1  2  3]#   [ 4  5  6  7]#   [ 8  9 10 11]]#  [[12 13 14  1]#   [16 17 18 19]#   [20 21 22 23]]]
index= np.argmax(x,axis=2)print("3 dimension test, axis = 2:",index)# 当axis=2时,进入第三个[]内进行对比。# [[  0  1  2  3#     4  5  6  7#     8  9 10 11 ]#  [ 12 13 14  1#    16 17 18 19#    20 21 22 23 ]]# 对于第三个[]内的内容而言,均剩下四部分,我特意将六个第三个[]内的内容分开更容易辨认# 第一个是# 0  1  2  3# 第二个是# 4  5  6  7# ……# 最后对比结果为#  [[ 3  3  3 ]#   [ 2  3  3 ]]

实际上axis的大小代表着进入到第axis+1个[ ]内,对其剩余的部分进行对比。
运行结果为:

1 dimension test:112 dimension test, axis=0:[2222]2 dimension test, axis=1:[333]3 dimension test, axis=0:[[1110][1111][1111]]3 dimension test, axis=1:[[2222][2222]]3 dimension test, axis=2:[[333][233]]
  • 作者:Bubbliiiing
  • 原文链接:https://blog.csdn.net/weixin_44791964/article/details/100017976
    更新时间:2022-10-08 10:09:56