【API使用总结】np.argmax()和tf.argmax()的辨析

2022-09-28 12:58:51

提示:如果本文对您有帮助,请点赞支持!

前言

在写AI算法的Demo时,偶然间出现了一个bug,发现是我不小心将tf.argmax()写成了np.argmax(),正好闲来无事,辨析下两个API的使用


一、np.argmax()的使用

np.argmax()是Python的第三方库numpy中的一个常见API,经常用来获取数组中的最大值所在的索引。所以使用该API要先导入该库:

import numpy as np

该API的完全定义如下:

def argmax(a, axis=None, out=None)# 第1个参数是输入的np数组;第2个参数是所获取的轴,取值为整数0,1等;第3个参数是输出的np数组,一般用不到

首先定义一个一维数组来进行测试:

    # 定义一个一维数组
    y1 = np.array([1, 2, 3, 7, 8, 9])
    print("result: {}".format(np.argmax(y1,axis=None))) #result: 5
    print("result: {}".format(np.argmax(y1, axis=0)))  # result: 5
    print("result: {}".format(np.argmax(y1, axis=1)))  # numpy.AxisError: axis 1 is out of bounds for array of dimension 1

接下来定义一个二维数组来进行测试:

    # 定义一个二维数组
    y2 = np.array([[1, 9, 3], [7, 8, 9]])
    print("result: {}".format(np.argmax(y2,axis=None))) # result: 1
    print("result: {}".format(np.argmax(y2,axis=0)))  # result: [1 0 1]
    print("result: {}".format(np.argmax(y2,axis=1)))  # result: [1 2]

最后定义一个三维数组来进行测试:

# 定义一个三维数组
    y3 = np.array([[[1, 9, 3],
                    [7, 8, 9]],
                   [[2, 6, 3],
                    [7, 6, 9]],
                   ])
    print("result: {}".format(np.argmax(y3, axis=None)))  # result: 1返回最大值所在的索引,类型为整型,如果有多个相同的最大值,则返回第一个
    print("result: {}".format(np.argmax(y3, axis=0)))  # result:  [[1 0 0][0 0 0]] z最大的
    print("result: {}".format(np.argmax(y3, axis=1)))  # result: [[1 0 1] [1 0 1]]y最大的
    print("result: {}".format(np.argmax(y3, axis=2)))  # result: [[1 2 ] [1 2]] x最大的

总结上述实例,我们可以总结出如下规律:

当axis=None时,我们将n维数组降为一维数组,取该数组里面最大值的索引,若存在多个最大值则返回第一个最大值所在的索引,所以返回的是一个0维的整型数字;

例如上述的三维数组被看成[1, 9, 3,7, 8, 9,2, 6, 3,7, 6, 9],第一个最大值所在索引是1

当axis=0时,取第0维中的每个元素的对应位置取最大值索引,若存在多个最大值则返回第一个最大值所在的索引,返回的是n-1维数组;

例如上述的三维数组在第0维方向上,我们的作用对象是[[1, 9, 3],[7, 8, 9]]和[[2, 6, 3],[7, 6, 9]],此时作用后则变成了 [[1 0 0][0 0 0]]

当axis=1时,取第1维中的每个元素的对应位置取最大值索引,若存在多个最大值则返回第一个最大值所在的索引,返回的是n-1维数组;

例如上述的三维数组在第1维方向的,我们的第1个作用对象是[1, 9, 3]和[7, 8, 9],其作用后是[1 0 1];第2个作用对象是[2, 6, 3],[7, 6, 9],其作用后是[1 0 1],则最终结果变成了 [[1 0 1][0 0 1]]

当axis=2时,取第2维中的每个元素的对应位置取最大值索引,若存在多个最大值则返回第一个最大值所在的索引,返回的是n-1维数组;

例如上述的三维数组在第2维方向的,则我们的第1个作用对象是[1, 9, 3],其作用结果是1;第1个作用对象是[7, 8, 9],其作用结果是2,此时合起来是[1,2];第3个作用对象是[2, 6, 3],其作用结果是1;第4个作用对象是[7, 6, 9],其作用结果是2,此时合起来是[1,2];;则最终结果变成了 [[1 2][1 2]]

当axis>多维数组的秩-1时,则报错:numpy.AxisError: axis 1 is out of bounds for array of dimension 1

二、tf.argmax()的使用

tf.argmax()是TensorFlow的一个常见API,也是经常用来获取数组中的最大值所在的索引,其内部也是用numpy来实现的。所以使用该API要先导入该库:

import tensorflow as tf

该API的完全定义如下:

argmax(input,axis=None,name=None,dimension=None,output_type=dtypes.int64)# 第1个参数是输入的tf张量;第2个参数是所获取的轴,取值为整数0,1等;第3个参数是返回的tf张量的名字;剩下的参数一般不常用

该API返回的是tf张量,这一点和numpy返回np数组不同。

因为返回的是tf张量,所以直接输出tf张量,只会查看该张量对象的一些基本信息,例如:

print("result: {}".format(tf.argmax(y1)))  # result: Tensor("ArgMax_2:0", shape=(), dtype=int64)

在TensorFlow中要用会话Session来输出tf张量,所以正确的打印如下:

    with tf.Session() as sess:
        result = sess.run(tf.argmax(y1, 0))
        print("result: {}".format(result))# result: 5

其他用法和上述的np.argmax()相同。

  • 作者:魔法攻城狮MRL
  • 原文链接:https://blog.csdn.net/qq_41959920/article/details/115872288
    更新时间:2022-09-28 12:58:51