tf2.0 中multiply、matmul、dot、batch_dot注意区别

2022-10-19 11:18:43

在tf和keras中上面这4个函数经常用到,需要注意相互之间的区别。

multiply:矩阵的逐元素点乘,需要输入矩阵x和y的shape相同或者可broadcast。

matmul:标准的矩阵乘法,要求第一个矩阵x.shape[-1]等于第二个矩阵y.shape[-2],同时要求  x.shape[:-2]和y.shape[:-2]必须相同。

(1)x和y是2D矩阵时,简单的矩阵乘法。

(2)当x和y不是2D矩阵时,则x和y除最后2个维度之外,其他维度必须相同,而且被认为是batch维度。

batch_dot:明确指定x和y的第一个维度为batch维度x.shape[1:]y.shape[1:]之间执行tf.keras.backend.dot运算。

dot:非标准的矩阵乘法,第一个矩阵x.shape[-1]要等于第二个矩阵y.shape[-2],但不需要x.shape[:-2]和y.shape[:-2]相同。

说:功能类似tf.matmul,区别在于当rank大于2时,除最后2个维度之外的前面维度是否需要相同。

 举例说明

(一)multiply

(二)matmul

(三)dot

解释:只要求x.shape[-1]=y.shape[-2]即可,计算时这2个维度sum over执行点乘运算。

(四)batch_dot

解释:x.shape[0]等于y.shape[0],为batch维度。x.shape[1:]和y.shape[1:]指定dot运算。

  • 作者:Terry_dong
  • 原文链接:https://nlplearning.blog.csdn.net/article/details/117793437
    更新时间:2022-10-19 11:18:43