理解keras中的batch_dot,dot方法和TensorFlow的matmul

2022-10-02 11:15:59

理解keras中的K.batch_dot和TensorFlow的tf.matmul

概述

在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如 A , B , C , D , E , F , G , H , I , J , K , L A,B,C,D,E,F,G,H,I, J,K,LA,B,C,D,E,F,G,HI,J,K,L都是二维矩阵,中间点表示矩阵乘法, A G AGAG表示矩阵 A AA G GG矩阵乘法( A AA的列维度等于 G GG行维度), W X = Z WX=ZWX=Z

import keras.backendas Kimport tensorflowas tfimport numpyas np

w= K.variable(np.random.randint(10,size=(10,12,4,5)))
k= K.variable(np.random.randint(10,size=(10,12,5,8)))
z= K.batch_dot(w,k)print(z.shape)#(10, 12, 4, 8)
import keras.backendas Kimport tensorflowas tfimport numpyas np

w= tf.Variable(np.random.randint(10,size=(10,12,4,5)),dtype=tf.float32)
k= tf.Variable(np.random.randint(10,size=(10,12,5,8)),dtype=tf.float32)
z= tf.matmul(w,k)print(z.shape)#(10, 12, 4, 8)

在这里插入图片描述

示例

from kerasimport backendas K
a= K.ones((3,4,5,2))
b= K.ones((2,5,3,7))
c= K.dot(a, b)print(c.shape)

会输出:
ValueError: Dimensions must be equal, but are 2 and 3 for ‘MatMul’ (op: ‘MatMul’) with input shapes: [60,2], [3,70].

from kerasimport backendas K
a= K.ones((3,4))
b= K.ones((4,5))
c= K.dot(a, b)print(c.shape)#(3,5)

或者

import tensorflowas tf
a= tf.ones((3,4))
b= tf.ones((4,5))
c= tf.matmul(a, b)print(c.shape)#(3,5)

如果增加维度:

from kerasimport backendas K
a= K.ones((2,3,4))
b= K.ones((7,4,5))
c= K.dot(a, b)print(c.shape)#(2, 3, 7, 5)

这个矩阵乘法会沿着两个矩阵最后两个维度进行乘法,不是element-wise矩阵乘法

from kerasimport backendas K
a= K.ones((1,2,3,4))
b= K.ones((8,7,4,5))
c= K.dot(a, b)print(c.shape)#(1, 2, 3, 8, 7, 5)

c a , b , c , i , j , k = ∑ r w a , b , c , r x i , j , r , k c_{a,b,c,i,j,k}=\sum_rw_{a,b,c,r}x_{i,j,r,k}ca,b,c,i,j,k=rwa,b,c,rxi,j,r,k

keras的dot方法是Theano中的复制

from kerasimport backendas K
a= K.ones((1,2,4))
b= K.ones((8,7,4,5))
c= K.dot(a, b)print(c.shape)# (1, 2, 8, 7, 5).
from kerasimport backendas K
a= K.ones((9,8,7,4,2))
b= K.ones((9,8,7,2,5))
c= K.batch_dot(a, b)print(c.shape)#(9, 8, 7, 4, 5)

或者

import tensorflowas tf
a= tf.ones((9,8,7,4,2))
b= tf.ones((9,8,7,2,5))
c= tf.matmul(a, b)print(c.shape)#(9, 8, 7, 4, 5)

参考

[1]: tf.keras.backend.batch_dot函数
[2]:keras batch_dot
[3]:Understand batch matrix multiplication
[4]:batch_dot

  • 作者:huml126
  • 原文链接:https://blog.csdn.net/huml126/article/details/88739846
    更新时间:2022-10-02 11:15:59