numpy三维batch矩阵相乘

2022-10-11 12:35:36

写代码时候遇到了带batch的矩阵乘法,需要用numpy实现。即A=[batch,height,width], B=[batch,width,height], C=AB=[batch,height,height]。在tensorflow中是可以直接使用tf.matmul得到结果的,但是在numpy中没有现成的三维矩阵乘法。

三维矩阵乘法的思路就是:C[0]=A[0]B[0], C[1]=A[1]B[1],也就是分别将A和B的batch的每个样本进行矩阵乘法,然后构成C。在numpy中可以使用np.einsum一行代码实现。

简单介绍:np.einsum('ij, jk', A, B)是矩阵A乘以矩阵B,等价于np.dot(A,B),这是二维的

对于三维的AB,则  np.einsum("ijk,ikn->ijn", A, B)  ijk表示A的索引,ikn表示B的索引,定义输出的维度是ijn

import numpy as np
import tensorflow as tf

if __name__ == "__main__":
  batch_size = 2
  height = 4
  width = 2
  a = np.random.rand(batch_size, height, width)
  b = np.random.rand(batch_size, width, height)
  print("*************** a 输入 {} ****************".format(a.shape))
  print(a)
  print("*************** b 输入 {} ****************".format(b.shape))
  print(b)

  aa = tf.placeholder(tf.float32, [batch_size, height, width])
  bb = tf.placeholder(tf.float32, [batch_size, width, height])
  cc = tf.matmul(a, b)
  with tf.Session() as sess:
    out = sess.run(cc, feed_dict={aa:a, bb:b})
  print("*************** tf 输出 {} ****************".format(out.shape))
  print(out)
  xx = np.einsum("ijk,ikn->ijn", a, b)
  print("\n*************** numpy 输出 {} ****************".format(xx.shape))

  print(xx)
  err_max2 = np.amax(np.absolute(np.subtract(out, xx)))
  print("\ntf与numpy误差:{}".format(err_max2))

# *************** a 输入 (2, 4, 2) ****************
# [[[0.48151815 0.59571173]
#   [0.54950679 0.07559809]
#   [0.54483139 0.49344093]
#   [0.66313407 0.7736222 ]]
# 
#  [[0.71144517 0.25567787]
#   [0.82224508 0.87165079]
#   [0.27935693 0.10498713]
#   [0.39752717 0.62073428]]]
# *************** b 输入 (2, 2, 4) ****************
# [[[0.43953543 0.51880854 0.35398745 0.59761315]
#   [0.1339994  0.46699152 0.9858384  0.67810861]]
# 
#  [[0.91692865 0.63337183 0.52427425 0.77657735]
#   [0.14262564 0.92203296 0.27971297 0.95416443]]]

# *************** tf 输出 (2, 4, 4) ****************
# [[[0.2914693  0.52800806 0.75772688 0.69171883]
#   [0.2516578  0.32039248 0.26904601 0.3796562 ]
#   [0.30559349 0.51309591 0.67931649 0.66020494]
#   [0.39513583 0.70531463 0.99740761 0.92089751]]
# 
#  [[0.68881068 0.68635275 0.4445088  0.79645093]
#   [0.87825983 1.32447763 0.67489395 1.47023509]
#   [0.27112423 0.2737384  0.1758259  0.31711724]
#   [0.45303667 0.82411997 0.38204068 0.90099316]]]
# 
# *************** numpy 输出 (2, 4, 4) ****************
# [[[0.2914693  0.52800806 0.75772688 0.69171883]
#   [0.2516578  0.32039248 0.26904601 0.3796562 ]
#   [0.30559349 0.51309591 0.67931649 0.66020494]
#   [0.39513583 0.70531463 0.99740761 0.92089751]]
# 
#  [[0.68881068 0.68635275 0.4445088  0.79645093]
#   [0.87825983 1.32447763 0.67489395 1.47023509]
#   [0.27112423 0.2737384  0.1758259  0.31711724]
#   [0.45303667 0.82411997 0.38204068 0.90099316]]]
# 
# tf与numpy误差:0.0
  • 作者:tafengtianya
  • 原文链接:https://blog.csdn.net/tafengtianya/article/details/107497063
    更新时间:2022-10-11 12:35:36