numpy模块中axis的理解——以np.argmax为例
np.argmax参数数量及其作用
np.argmax是用于取得数组中每一行或者每一列的的最大值。常用于机器学习中获取分类结果、计算精确度等。
函数如下:
np.argmax(
a,
axis=None,
out=None)
部分参数解释:
a:输入矩阵;
axis:对于二维向量而言,0代表对行进行最大值选取,此时对每一列进行操作;1代表对列进行最大值选取,此时对每一行进行操作。三维向量的情况更为复杂,需要结合例子说明。实际上axis的大小代表着进入到第axis+1个[ ]内,对其剩余的部分进行对比;
out:可以指定输出矩阵的变量
axis不同情况的示例
代码较长,许多都是注释,请大家耐心观看
import numpyas np# 一维向量测试# 取出x中元素最大值所对应的索引# 此时最大值为11,其对应的位置索引值为11
x= np.arange(12)
index= np.argmax(x)print("1 dimension test:",index)# 二维向量测试# 0代表对行进行最大值选取,此时对每一列进行操作
x= np.arange(12).reshape(3,4)
index= np.argmax(x,axis=0)# 结果为[2 2 2 2]print("2 dimension test, axis = 0:",index)# 二维向量测试# 1代表对列进行最大值选取,此时对每一行进行操作
x= np.arange(12).reshape(3,4)
index= np.argmax(x,axis=1)# 结果为[3 3 3]print("2 dimension test, axis = 1:",index)# 三维向量测试# 0代表进入第一个[]内进行对比
x= np.arange(24).reshape(2,3,4)
x[1,0,3]=1# x =# [[[ 0 1 2 3]# [ 4 5 6 7]# [ 8 9 10 11]]# [[12 13 14 1]# [16 17 18 19]# [20 21 22 23]]]
index= np.argmax(x,axis=0)print("3 dimension test, axis = 0:",index)# 当axis=0时,进入第一个[]内进行对比,此时x剩下两部分。# [[ 0 1 2 3]# [ 4 5 6 7]# [ 8 9 10 11]]# [[12 13 14 1]# [16 17 18 19]# [20 21 22 23]]# 两部分格式相同,将剩下的两部分每一个单位进行对比,对比结果为# [[1 1 1 0]# [1 1 1 1]# [1 1 1 1]]# 除去我设置的特殊位置外,其他位置均为第二部分大。# 三维向量测试# 1代表进入第二个[]内进行对比# x =# [[[ 0 1 2 3]# [ 4 5 6 7]# [ 8 9 10 11]]# [[12 13 14 1]# [16 17 18 19]# [20 21 22 23]]]
index= np.argmax(x,axis=1)print("3 dimension test, axis = 1:",index)# 当axis=1时,进入第二个[]内进行对比。# [ [ 0 1 2 3]# [ 4 5 6 7]# [ 8 9 10 11]# [12 13 14 1]# [16 17 18 19]# [20 21 22 23] ]# 对于第二个[]内的内容而言,均剩下三部分,我特意将两个第二个[]内的内容分开更容易辨认# 第一个是# [ 0 1 2 3]# [ 4 5 6 7]# [ 8 9 10 11]# 第二个是# [12 13 14 1]# [16 17 18 19]# [20 21 22 23]# 都是第三行的值最大,所以输出结果为# [[ 2 2 2 2]# [ 2 2 2 2]]# 三维向量测试# 2代表进入第三个[]内进行对比
x= np.arange(24).reshape(2,3,4)
x[1,0,3]=1# x =# [[[ 0 1 2 3]# [ 4 5 6 7]# [ 8 9 10 11]]# [[12 13 14 1]# [16 17 18 19]# [20 21 22 23]]]
index= np.argmax(x,axis=2)print("3 dimension test, axis = 2:",index)# 当axis=2时,进入第三个[]内进行对比。# [[ 0 1 2 3# 4 5 6 7# 8 9 10 11 ]# [ 12 13 14 1# 16 17 18 19# 20 21 22 23 ]]# 对于第三个[]内的内容而言,均剩下四部分,我特意将六个第三个[]内的内容分开更容易辨认# 第一个是# 0 1 2 3# 第二个是# 4 5 6 7# ……# 最后对比结果为# [[ 3 3 3 ]# [ 2 3 3 ]]
实际上axis的大小代表着进入到第axis+1个[ ]内,对其剩余的部分进行对比。
运行结果为:
1 dimension test:112 dimension test, axis=0:[2222]2 dimension test, axis=1:[333]3 dimension test, axis=0:[[1110][1111][1111]]3 dimension test, axis=1:[[2222][2222]]3 dimension test, axis=2:[[333][233]]