问题背景
使用cython编译keras工程,编译完成后运行时报错
问题原因
神经网络中存在keras.layers.Lambda层,对于构造参数function的lambda表示,cython无法读取为正常属性
问题代码及解决
# RNN layer
lstm_1= LSTM(32, return_sequences=True, kernel_initializer='he_normal', name='lstm1')(inner)
lstm_1b= LSTM(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal',name='lstm1_b')(inner)# (None, 64, 32)# 此处为原代码,编译为pyd后执行报错# reversed_lstm_1b = Lambda(function = lambda inputTensor: K.reverse(inputTensor, axes = 1))(lstm_1b)defrev(inputTensor):return K.reverse(inputTensor, axes=1)
reversed_lstm_1b= Lambda(function= rev)(lstm_1b)
lstm1_merged= concatenate([lstm_1, reversed_lstm_1b])# (None, 64, 64)
lstm1_merged= BatchNormalization()(lstm1_merged)
代码及解决
我是编译为.so后,调用出错
原码
KL.Lambda(lambda t: tf.reshape(t,[tf.shape(t)[0],-1,2]))(x)
改为
defrpn_class_lambda(t):return tf.reshape(t,[tf.shape(t)[0],-1,2])
KL.Lambda(
function= rpn_class_lambda)(x)