理解keras中的K.batch_dot和TensorFlow的tf.matmul
概述
在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如 A , B , C , D , E , F , G , H , I , J , K , L A,B,C,D,E,F,G,H,I, J,K,LA,B,C,D,E,F,G,H,I,J,K,L都是二维矩阵,中间点表示矩阵乘法, A G AGAG表示矩阵 A AA和 G GG矩阵乘法( A AA的列维度等于 G GG行维度), W X = Z WX=ZWX=Z
import keras.backendas Kimport tensorflowas tfimport numpyas np
w= K.variable(np.random.randint(10,size=(10,12,4,5)))
k= K.variable(np.random.randint(10,size=(10,12,5,8)))
z= K.batch_dot(w,k)print(z.shape)#(10, 12, 4, 8)
import keras.backendas Kimport tensorflowas tfimport numpyas np
w= tf.Variable(np.random.randint(10,size=(10,12,4,5)),dtype=tf.float32)
k= tf.Variable(np.random.randint(10,size=(10,12,5,8)),dtype=tf.float32)
z= tf.matmul(w,k)print(z.shape)#(10, 12, 4, 8)
示例
from kerasimport backendas K
a= K.ones((3,4,5,2))
b= K.ones((2,5,3,7))
c= K.dot(a, b)print(c.shape)
会输出:
ValueError: Dimensions must be equal, but are 2 and 3 for ‘MatMul’ (op: ‘MatMul’) with input shapes: [60,2], [3,70].
from kerasimport backendas K
a= K.ones((3,4))
b= K.ones((4,5))
c= K.dot(a, b)print(c.shape)#(3,5)
或者
import tensorflowas tf
a= tf.ones((3,4))
b= tf.ones((4,5))
c= tf.matmul(a, b)print(c.shape)#(3,5)
如果增加维度:
from kerasimport backendas K
a= K.ones((2,3,4))
b= K.ones((7,4,5))
c= K.dot(a, b)print(c.shape)#(2, 3, 7, 5)
这个矩阵乘法会沿着两个矩阵最后两个维度进行乘法,不是element-wise矩阵乘法
from kerasimport backendas K
a= K.ones((1,2,3,4))
b= K.ones((8,7,4,5))
c= K.dot(a, b)print(c.shape)#(1, 2, 3, 8, 7, 5)
c a , b , c , i , j , k = ∑ r w a , b , c , r x i , j , r , k c_{a,b,c,i,j,k}=\sum_rw_{a,b,c,r}x_{i,j,r,k}ca,b,c,i,j,k=∑rwa,b,c,rxi,j,r,k
keras的dot方法是Theano中的复制
from kerasimport backendas K
a= K.ones((1,2,4))
b= K.ones((8,7,4,5))
c= K.dot(a, b)print(c.shape)# (1, 2, 8, 7, 5).
from kerasimport backendas K
a= K.ones((9,8,7,4,2))
b= K.ones((9,8,7,2,5))
c= K.batch_dot(a, b)print(c.shape)#(9, 8, 7, 4, 5)
或者
import tensorflowas tf
a= tf.ones((9,8,7,4,2))
b= tf.ones((9,8,7,2,5))
c= tf.matmul(a, b)print(c.shape)#(9, 8, 7, 4, 5)
参考
[1]: tf.keras.backend.batch_dot函数
[2]:keras batch_dot
[3]:Understand batch matrix multiplication
[4]:batch_dot